save_controls.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. """
  2. 保存功能控件
  3. """
  4. import streamlit as st
  5. import io
  6. import json
  7. from pathlib import Path
  8. from typing import Dict
  9. from .drawing import draw_clean_table_lines
  10. def create_save_section(work_mode: str, structure: Dict, image, line_width: int, output_config: Dict):
  11. """
  12. 保存设置(目录/命名来自配置)
  13. Args:
  14. work_mode: 工作模式
  15. structure: 表格结构
  16. image: 图片对象
  17. line_width: 线条宽度
  18. output_config: 输出配置(兜底用)
  19. """
  20. st.divider()
  21. # 🔑 优先使用当前数据源的输出配置
  22. if 'current_output_config' in st.session_state:
  23. active_output_config = st.session_state.current_output_config
  24. st.info(f"📂 保存位置:{active_output_config.get('directory', 'N/A')}")
  25. else:
  26. active_output_config = output_config
  27. defaults = active_output_config.get("defaults", {})
  28. line_colors = active_output_config.get("line_colors") or [
  29. {"name": "黑色", "rgb": [0, 0, 0]},
  30. {"name": "蓝色", "rgb": [0, 0, 255]},
  31. {"name": "红色", "rgb": [255, 0, 0]},
  32. ]
  33. save_col1, save_col2, save_col3 = st.columns(3)
  34. with save_col1:
  35. save_structure = st.checkbox(
  36. "保存表格结构配置",
  37. value=bool(defaults.get("save_structure", True)),
  38. )
  39. with save_col2:
  40. save_image = st.checkbox(
  41. "保存表格线图片",
  42. value=bool(defaults.get("save_image", True)),
  43. )
  44. color_names = [c["name"] for c in line_colors]
  45. default_color = defaults.get("line_color", color_names[0])
  46. default_index = (
  47. color_names.index(default_color)
  48. if default_color in color_names
  49. else 0
  50. )
  51. with save_col3:
  52. line_color_option = st.selectbox(
  53. "线条颜色",
  54. color_names,
  55. index=default_index,
  56. label_visibility="collapsed",
  57. key="save_line_color"
  58. )
  59. if st.button("💾 保存", type="primary"):
  60. output_dir = Path(active_output_config.get("directory", "output/table_structures"))
  61. output_dir.mkdir(parents=True, exist_ok=True)
  62. structure_suffix = active_output_config.get("structure_suffix", "_structure.json")
  63. image_suffix = active_output_config.get("image_suffix", "_with_lines.png")
  64. # 确定文件名
  65. base_name = _determine_base_name(work_mode)
  66. saved_files = []
  67. if save_structure:
  68. _save_structure_file(
  69. structure,
  70. output_dir,
  71. base_name,
  72. structure_suffix,
  73. saved_files
  74. )
  75. if save_image:
  76. _save_image_file(
  77. image,
  78. structure,
  79. line_width,
  80. line_color_option,
  81. line_colors,
  82. output_dir,
  83. base_name,
  84. image_suffix,
  85. saved_files
  86. )
  87. if saved_files:
  88. st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
  89. for file_type, file_path in saved_files:
  90. st.info(f" • {file_type}: {file_path}")
  91. # 显示当前数据源信息
  92. if 'current_data_source' in st.session_state:
  93. ds = st.session_state.current_data_source
  94. with st.expander("📋 数据源信息"):
  95. st.json({
  96. "名称": ds.get("name"),
  97. "JSON目录": str(ds.get("json_dir")),
  98. "图片目录": str(ds.get("image_dir")),
  99. "输出目录": str(output_dir),
  100. })
  101. def _determine_base_name(work_mode: str) -> str:
  102. """确定保存文件的基础名称"""
  103. if work_mode == "🆕 新建标注" or work_mode == "new":
  104. if st.session_state.loaded_json_name:
  105. return Path(st.session_state.loaded_json_name).stem
  106. else:
  107. return "table_structure"
  108. else:
  109. if st.session_state.loaded_config_name:
  110. base_name = Path(st.session_state.loaded_config_name).stem
  111. if base_name.endswith('_structure'):
  112. base_name = base_name[:-10]
  113. return base_name
  114. elif st.session_state.loaded_image_name:
  115. return Path(st.session_state.loaded_image_name).stem
  116. else:
  117. return "table_structure"
  118. def _save_structure_file(structure, output_dir, base_name, suffix, saved_files):
  119. """保存结构配置文件"""
  120. structure_filename = f"{base_name}{suffix}"
  121. structure_path = output_dir / structure_filename
  122. # save_structure_to_config(structure, structure_path)
  123. with open(structure_path, 'w', encoding='utf-8') as f:
  124. json.dump(structure, f, indent=2, ensure_ascii=False)
  125. saved_files.append(("配置文件", structure_path))
  126. with open(structure_path, 'r') as f:
  127. st.download_button(
  128. "📥 下载配置文件",
  129. f.read(),
  130. file_name=f"{base_name}_structure.json",
  131. mime="application/json"
  132. )
  133. def _save_image_file(image, structure, line_width, color_option, line_colors,
  134. output_dir, base_name, suffix, saved_files):
  135. """保存表格线图片"""
  136. if image is None:
  137. st.warning("⚠️ 无法保存图片:未加载图片文件")
  138. return
  139. selected_color_rgb = next(
  140. (tuple(c["rgb"]) for c in line_colors if c["name"] == color_option),
  141. (0, 0, 0),
  142. )
  143. clean_img = draw_clean_table_lines(
  144. image,
  145. structure,
  146. line_width=line_width,
  147. line_color=selected_color_rgb,
  148. )
  149. image_filename = f"{base_name}{suffix}"
  150. output_image_path = output_dir / image_filename
  151. clean_img.save(output_image_path)
  152. saved_files.append(("表格线图片", output_image_path))
  153. buf = io.BytesIO()
  154. clean_img.save(buf, format='PNG')
  155. buf.seek(0)
  156. st.download_button(
  157. "📥 下载表格线图片",
  158. buf,
  159. file_name=f"{base_name}_with_lines.png",
  160. mime="image/png"
  161. )