file_handlers.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. """
  2. 文件上传和加载处理
  3. """
  4. import streamlit as st
  5. import json
  6. import tempfile
  7. from pathlib import Path
  8. from PIL import Image
  9. from .config_loader import load_structure_from_config
  10. from .drawing import clear_table_image_cache
  11. try:
  12. from table_line_generator import TableLineGenerator
  13. except ImportError:
  14. from ..table_line_generator import TableLineGenerator
  15. def handle_json_upload(uploaded_json):
  16. """处理 JSON 文件上传"""
  17. if uploaded_json is None:
  18. return
  19. if st.session_state.loaded_json_name == uploaded_json.name:
  20. return
  21. try:
  22. raw_data = json.load(uploaded_json)
  23. with st.expander("🔍 原始数据结构"):
  24. if isinstance(raw_data, dict):
  25. st.json({
  26. k: f"<{type(v).__name__}>"
  27. if not isinstance(v, (str, int, float, bool, type(None)))
  28. else v
  29. for k, v in list(raw_data.items())[:5]
  30. })
  31. else:
  32. st.json(raw_data[:3] if len(raw_data) > 3 else raw_data)
  33. ocr_data = TableLineGenerator.parse_ocr_data(raw_data, tool="ppstructv3")
  34. if not ocr_data:
  35. st.error("❌ 无法解析 OCR 数据,请检查 JSON 格式")
  36. st.stop()
  37. st.session_state.ocr_data = ocr_data
  38. st.session_state.loaded_json_name = uploaded_json.name
  39. st.session_state.loaded_config_name = None
  40. # 清除旧数据
  41. if 'structure' in st.session_state:
  42. del st.session_state.structure
  43. if 'generator' in st.session_state:
  44. del st.session_state.generator
  45. st.session_state.undo_stack = []
  46. st.session_state.redo_stack = []
  47. clear_table_image_cache()
  48. st.success(f"✅ 成功加载 {len(ocr_data)} 条 OCR 记录")
  49. except Exception as e:
  50. st.error(f"❌ 加载数据失败: {e}")
  51. st.stop()
  52. def handle_image_upload(uploaded_image):
  53. """处理图片文件上传"""
  54. if uploaded_image is None:
  55. return
  56. if st.session_state.loaded_image_name == uploaded_image.name:
  57. return
  58. try:
  59. image = Image.open(uploaded_image)
  60. st.session_state.image = image
  61. st.session_state.loaded_image_name = uploaded_image.name
  62. # 清除旧数据
  63. if 'structure' in st.session_state:
  64. del st.session_state.structure
  65. if 'generator' in st.session_state:
  66. del st.session_state.generator
  67. st.session_state.undo_stack = []
  68. st.session_state.redo_stack = []
  69. clear_table_image_cache()
  70. st.success(f"✅ 成功加载图片: {uploaded_image.name}")
  71. except Exception as e:
  72. st.error(f"❌ 加载图片失败: {e}")
  73. st.stop()
  74. def handle_config_upload(uploaded_config):
  75. """处理配置文件上传"""
  76. if uploaded_config is None:
  77. return
  78. if st.session_state.loaded_config_name == uploaded_config.name:
  79. return
  80. try:
  81. # 创建临时文件
  82. with tempfile.NamedTemporaryFile(
  83. mode='w',
  84. suffix='.json',
  85. delete=False,
  86. encoding='utf-8'
  87. ) as tmp:
  88. tmp.write(uploaded_config.getvalue().decode('utf-8'))
  89. tmp_path = tmp.name
  90. # 加载结构
  91. structure = load_structure_from_config(Path(tmp_path))
  92. # 清理临时文件
  93. Path(tmp_path).unlink()
  94. st.session_state.structure = structure
  95. st.session_state.loaded_config_name = uploaded_config.name
  96. # 清除历史记录和缓存
  97. st.session_state.undo_stack = []
  98. st.session_state.redo_stack = []
  99. clear_table_image_cache()
  100. st.success(f"✅ 成功加载配置: {uploaded_config.name}")
  101. st.info(
  102. f"📊 表格结构: {len(structure['rows'])}行 x {len(structure['columns'])}列\n\n"
  103. f"📏 横线数: {len(structure.get('horizontal_lines', []))}\n\n"
  104. f"📏 竖线数: {len(structure.get('vertical_lines', []))}"
  105. )
  106. # 显示配置文件详情
  107. with st.expander("📋 配置详情"):
  108. st.json({
  109. "行数": len(structure['rows']),
  110. "列数": len(structure['columns']),
  111. "横线数": len(structure.get('horizontal_lines', [])),
  112. "竖线数": len(structure.get('vertical_lines', [])),
  113. "行高": structure.get('row_height'),
  114. "列宽": structure.get('col_widths'),
  115. "已修改的横线": list(structure.get('modified_h_lines', set())),
  116. "已修改的竖线": list(structure.get('modified_v_lines', set()))
  117. })
  118. except Exception as e:
  119. st.error(f"❌ 加载配置失败: {e}")
  120. import traceback
  121. st.code(traceback.format_exc())
  122. st.stop()
  123. def create_file_uploader_section(work_mode: str):
  124. """
  125. 创建文件上传区域
  126. Args:
  127. work_mode: 工作模式("🆕 新建标注" 或 "📂 加载已有标注")
  128. """
  129. if work_mode == "🆕 新建标注":
  130. st.sidebar.subheader("上传文件")
  131. uploaded_json = st.sidebar.file_uploader(
  132. "上传OCR结果JSON",
  133. type=['json'],
  134. key="new_json"
  135. )
  136. uploaded_image = st.sidebar.file_uploader(
  137. "上传对应图片",
  138. type=['jpg', 'png'],
  139. key="new_image"
  140. )
  141. handle_json_upload(uploaded_json)
  142. handle_image_upload(uploaded_image)
  143. else: # 加载已有标注
  144. st.sidebar.subheader("加载已保存的标注")
  145. uploaded_config = st.sidebar.file_uploader(
  146. "上传配置文件 (*_structure.json)",
  147. type=['json'],
  148. key="load_config"
  149. )
  150. uploaded_image_for_config = st.sidebar.file_uploader(
  151. "上传对应图片(可选)",
  152. type=['jpg', 'png'],
  153. key="load_image"
  154. )
  155. handle_config_upload(uploaded_config)
  156. handle_image_upload(uploaded_image_for_config)
  157. # 提示信息
  158. if 'structure' in st.session_state and st.session_state.image is None:
  159. st.warning("⚠️ 已加载配置,但未加载对应图片。请上传图片以查看效果。")
  160. st.info(
  161. "💡 提示:配置文件已加载,您可以:\n"
  162. "1. 上传对应图片查看效果\n"
  163. "2. 直接编辑配置并保存"
  164. )