# save_controls.py
"""
Save-feature controls (保存功能控件).
"""
import io
from pathlib import Path
from typing import Any, Dict, Mapping, Optional

import streamlit as st

from .config_loader import save_structure_to_config
from .drawing import draw_clean_table_lines
  10. def create_save_section(work_mode: str, structure: Dict, image, line_width: int, output_config: Dict):
  11. """
  12. 保存设置(目录/命名来自配置)
  13. Args:
  14. work_mode: 工作模式
  15. structure: 表格结构
  16. image: 图片对象
  17. line_width: 线条宽度
  18. output_config: 输出配置(兜底用)
  19. """
  20. st.divider()
  21. # 🔑 优先使用当前数据源的输出配置
  22. if 'current_output_config' in st.session_state:
  23. active_output_config = st.session_state.current_output_config
  24. st.info(f"📂 保存位置:{active_output_config.get('directory', 'N/A')}")
  25. else:
  26. active_output_config = output_config
  27. defaults = active_output_config.get("defaults", {})
  28. line_colors = active_output_config.get("line_colors") or [
  29. {"name": "黑色", "rgb": [0, 0, 0]},
  30. {"name": "蓝色", "rgb": [0, 0, 255]},
  31. {"name": "红色", "rgb": [255, 0, 0]},
  32. ]
  33. save_col1, save_col2, save_col3 = st.columns(3)
  34. with save_col1:
  35. save_structure = st.checkbox(
  36. "保存表格结构配置",
  37. value=bool(defaults.get("save_structure", True)),
  38. )
  39. with save_col2:
  40. save_image = st.checkbox(
  41. "保存表格线图片",
  42. value=bool(defaults.get("save_image", True)),
  43. )
  44. color_names = [c["name"] for c in line_colors]
  45. default_color = defaults.get("line_color", color_names[0])
  46. default_index = (
  47. color_names.index(default_color)
  48. if default_color in color_names
  49. else 0
  50. )
  51. with save_col3:
  52. line_color_option = st.selectbox(
  53. "线条颜色",
  54. color_names,
  55. index=default_index,
  56. label_visibility="collapsed",
  57. key="save_line_color"
  58. )
  59. if st.button("💾 保存", type="primary"):
  60. output_dir = Path(active_output_config.get("directory", "output/table_structures"))
  61. output_dir.mkdir(parents=True, exist_ok=True)
  62. structure_suffix = active_output_config.get("structure_suffix", "_structure.json")
  63. image_suffix = active_output_config.get("image_suffix", "_with_lines.png")
  64. # 确定文件名
  65. base_name = _determine_base_name(work_mode)
  66. saved_files = []
  67. if save_structure:
  68. _save_structure_file(
  69. structure,
  70. output_dir,
  71. base_name,
  72. structure_suffix,
  73. saved_files
  74. )
  75. if save_image:
  76. _save_image_file(
  77. image,
  78. structure,
  79. line_width,
  80. line_color_option,
  81. line_colors,
  82. output_dir,
  83. base_name,
  84. image_suffix,
  85. saved_files
  86. )
  87. if saved_files:
  88. st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
  89. for file_type, file_path in saved_files:
  90. st.info(f" • {file_type}: {file_path}")
  91. # 显示当前数据源信息
  92. if 'current_data_source' in st.session_state:
  93. ds = st.session_state.current_data_source
  94. with st.expander("📋 数据源信息"):
  95. st.json({
  96. "名称": ds.get("name"),
  97. "JSON目录": str(ds.get("json_dir")),
  98. "图片目录": str(ds.get("image_dir")),
  99. "输出目录": str(output_dir),
  100. })
  101. def _determine_base_name(work_mode: str) -> str:
  102. """确定保存文件的基础名称"""
  103. if work_mode == "🆕 新建标注" or work_mode == "new":
  104. if st.session_state.loaded_json_name:
  105. return Path(st.session_state.loaded_json_name).stem
  106. else:
  107. return "table_structure"
  108. else:
  109. if st.session_state.loaded_config_name:
  110. base_name = Path(st.session_state.loaded_config_name).stem
  111. if base_name.endswith('_structure'):
  112. base_name = base_name[:-10]
  113. return base_name
  114. elif st.session_state.loaded_image_name:
  115. return Path(st.session_state.loaded_image_name).stem
  116. else:
  117. return "table_structure"
  118. def _save_structure_file(structure, output_dir, base_name, suffix, saved_files):
  119. """保存结构配置文件"""
  120. structure_filename = f"{base_name}{suffix}"
  121. structure_path = output_dir / structure_filename
  122. save_structure_to_config(structure, structure_path)
  123. saved_files.append(("配置文件", structure_path))
  124. with open(structure_path, 'r') as f:
  125. st.download_button(
  126. "📥 下载配置文件",
  127. f.read(),
  128. file_name=f"{base_name}_structure.json",
  129. mime="application/json"
  130. )
  131. def _save_image_file(image, structure, line_width, color_option, line_colors,
  132. output_dir, base_name, suffix, saved_files):
  133. """保存表格线图片"""
  134. if image is None:
  135. st.warning("⚠️ 无法保存图片:未加载图片文件")
  136. return
  137. selected_color_rgb = next(
  138. (tuple(c["rgb"]) for c in line_colors if c["name"] == color_option),
  139. (0, 0, 0),
  140. )
  141. clean_img = draw_clean_table_lines(
  142. image,
  143. structure,
  144. line_width=line_width,
  145. line_color=selected_color_rgb,
  146. )
  147. image_filename = f"{base_name}{suffix}"
  148. output_image_path = output_dir / image_filename
  149. clean_img.save(output_image_path)
  150. saved_files.append(("表格线图片", output_image_path))
  151. buf = io.BytesIO()
  152. clean_img.save(buf, format='PNG')
  153. buf.seek(0)
  154. st.download_button(
  155. "📥 下载表格线图片",
  156. buf,
  157. file_name=f"{base_name}_with_lines.png",
  158. mime="image/png"
  159. )