ui_components.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. """
  2. UI 组件
  3. """
  4. import streamlit as st
  5. import json
  6. from pathlib import Path
  7. from PIL import Image
  8. import tempfile
  9. try:
  10. from ..table_line_generator import TableLineGenerator
  11. except ImportError:
  12. from table_line_generator import TableLineGenerator
  13. from .config_loader import load_structure_from_config
  14. from .drawing import clear_table_image_cache
  15. def parse_ocr_data(ocr_data):
  16. """解析OCR数据,支持多种格式"""
  17. # 如果是字符串,尝试解析
  18. if isinstance(ocr_data, str):
  19. try:
  20. ocr_data = json.loads(ocr_data)
  21. except json.JSONDecodeError:
  22. st.error("❌ JSON 格式错误,无法解析")
  23. return []
  24. # 检查是否为 PPStructure V3 格式
  25. if isinstance(ocr_data, dict) and 'parsing_res_list' in ocr_data and 'overall_ocr_res' in ocr_data:
  26. st.info("🔍 检测到 PPStructure V3 格式")
  27. try:
  28. table_bbox, text_boxes = TableLineGenerator.parse_ppstructure_result(ocr_data)
  29. st.success(f"✅ 表格区域: {table_bbox}")
  30. st.success(f"✅ 表格内文本框: {len(text_boxes)} 个")
  31. return text_boxes
  32. except Exception as e:
  33. st.error(f"❌ 解析 PPStructure 结果失败: {e}")
  34. return []
  35. # 确保是列表
  36. if not isinstance(ocr_data, list):
  37. st.error(f"❌ OCR 数据应该是列表,实际类型: {type(ocr_data)}")
  38. return []
  39. if not ocr_data:
  40. st.warning("⚠️ OCR 数据为空")
  41. return []
  42. first_item = ocr_data[0]
  43. if not isinstance(first_item, dict):
  44. st.error(f"❌ OCR 数据项应该是字典,实际类型: {type(first_item)}")
  45. return []
  46. if 'bbox' not in first_item:
  47. st.error("❌ OCR 数据缺少 'bbox' 字段")
  48. st.info("💡 支持的格式示例:\n```json\n[\n {\n \"text\": \"文本\",\n \"bbox\": [x1, y1, x2, y2]\n }\n]\n```")
  49. return []
  50. return ocr_data
  51. def create_file_uploader_section(work_mode: str):
  52. """
  53. 创建文件上传区域
  54. Args:
  55. work_mode: 工作模式("🆕 新建标注" 或 "📂 加载已有标注")
  56. """
  57. if work_mode == "🆕 新建标注":
  58. st.sidebar.subheader("上传文件")
  59. uploaded_json = st.sidebar.file_uploader("上传OCR结果JSON", type=['json'], key="new_json")
  60. uploaded_image = st.sidebar.file_uploader("上传对应图片", type=['jpg', 'png'], key="new_image")
  61. # 处理 JSON 上传
  62. if uploaded_json is not None:
  63. if st.session_state.loaded_json_name != uploaded_json.name:
  64. try:
  65. raw_data = json.load(uploaded_json)
  66. with st.expander("🔍 原始数据结构"):
  67. if isinstance(raw_data, dict):
  68. st.json({k: f"<{type(v).__name__}>" if not isinstance(v, (str, int, float, bool, type(None))) else v
  69. for k, v in list(raw_data.items())[:5]})
  70. else:
  71. st.json(raw_data[:3] if len(raw_data) > 3 else raw_data)
  72. ocr_data = parse_ocr_data(raw_data)
  73. if not ocr_data:
  74. st.error("❌ 无法解析 OCR 数据,请检查 JSON 格式")
  75. st.stop()
  76. st.session_state.ocr_data = ocr_data
  77. st.session_state.loaded_json_name = uploaded_json.name
  78. st.session_state.loaded_config_name = None
  79. # 清除旧数据
  80. if 'structure' in st.session_state:
  81. del st.session_state.structure
  82. if 'generator' in st.session_state:
  83. del st.session_state.generator
  84. st.session_state.undo_stack = []
  85. st.session_state.redo_stack = []
  86. clear_table_image_cache()
  87. st.success(f"✅ 成功加载 {len(ocr_data)} 条 OCR 记录")
  88. except Exception as e:
  89. st.error(f"❌ 加载数据失败: {e}")
  90. st.stop()
  91. # 处理图片上传
  92. if uploaded_image is not None:
  93. if st.session_state.loaded_image_name != uploaded_image.name:
  94. try:
  95. image = Image.open(uploaded_image)
  96. st.session_state.image = image
  97. st.session_state.loaded_image_name = uploaded_image.name
  98. # 清除旧数据
  99. if 'structure' in st.session_state:
  100. del st.session_state.structure
  101. if 'generator' in st.session_state:
  102. del st.session_state.generator
  103. st.session_state.undo_stack = []
  104. st.session_state.redo_stack = []
  105. clear_table_image_cache()
  106. st.success(f"✅ 成功加载图片: {uploaded_image.name}")
  107. except Exception as e:
  108. st.error(f"❌ 加载图片失败: {e}")
  109. st.stop()
  110. else: # 加载已有标注
  111. st.sidebar.subheader("加载已保存的标注")
  112. uploaded_config = st.sidebar.file_uploader(
  113. "上传配置文件 (*_structure.json)",
  114. type=['json'],
  115. key="load_config"
  116. )
  117. uploaded_image_for_config = st.sidebar.file_uploader(
  118. "上传对应图片(可选)",
  119. type=['jpg', 'png'],
  120. key="load_image"
  121. )
  122. # 处理配置文件加载
  123. if uploaded_config is not None:
  124. if st.session_state.loaded_config_name != uploaded_config.name:
  125. try:
  126. # 创建临时文件
  127. with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as tmp:
  128. tmp.write(uploaded_config.getvalue().decode('utf-8'))
  129. tmp_path = tmp.name
  130. # 加载结构
  131. structure = load_structure_from_config(Path(tmp_path))
  132. # 清理临时文件
  133. Path(tmp_path).unlink()
  134. st.session_state.structure = structure
  135. st.session_state.loaded_config_name = uploaded_config.name
  136. # 清除历史记录和缓存
  137. st.session_state.undo_stack = []
  138. st.session_state.redo_stack = []
  139. clear_table_image_cache()
  140. st.success(f"✅ 成功加载配置: {uploaded_config.name}")
  141. st.info(
  142. f"📊 表格结构: {len(structure['rows'])}行 x {len(structure['columns'])}列\n\n"
  143. f"📏 横线数: {len(structure.get('horizontal_lines', []))}\n\n"
  144. f"📏 竖线数: {len(structure.get('vertical_lines', []))}"
  145. )
  146. # 显示配置文件详情
  147. with st.expander("📋 配置详情"):
  148. st.json({
  149. "行数": len(structure['rows']),
  150. "列数": len(structure['columns']),
  151. "横线数": len(structure.get('horizontal_lines', [])),
  152. "竖线数": len(structure.get('vertical_lines', [])),
  153. "行高": structure.get('row_height'),
  154. "列宽": structure.get('col_widths'),
  155. "已修改的横线": list(structure.get('modified_h_lines', set())),
  156. "已修改的竖线": list(structure.get('modified_v_lines', set()))
  157. })
  158. except Exception as e:
  159. st.error(f"❌ 加载配置失败: {e}")
  160. import traceback
  161. st.code(traceback.format_exc())
  162. st.stop()
  163. # 处理图片加载
  164. if uploaded_image_for_config is not None:
  165. if st.session_state.loaded_image_name != uploaded_image_for_config.name:
  166. try:
  167. image = Image.open(uploaded_image_for_config)
  168. st.session_state.image = image
  169. st.session_state.loaded_image_name = uploaded_image_for_config.name
  170. clear_table_image_cache()
  171. st.success(f"✅ 成功加载图片: {uploaded_image_for_config.name}")
  172. except Exception as e:
  173. st.error(f"❌ 加载图片失败: {e}")
  174. st.stop()
  175. # 提示信息
  176. if 'structure' in st.session_state and st.session_state.image is None:
  177. st.warning("⚠️ 已加载配置,但未加载对应图片。请上传图片以查看效果。")
  178. st.info("💡 提示:配置文件已加载,您可以:\n1. 上传对应图片查看效果\n2. 直接编辑配置并保存")
  179. def create_display_settings_section():
  180. """创建显示设置区域"""
  181. st.sidebar.divider()
  182. st.sidebar.subheader("🖼️ 显示设置")
  183. line_width = st.sidebar.slider("线条宽度", 1, 5, 2)
  184. display_mode = st.sidebar.radio("显示模式", ["对比显示", "仅显示划线图", "仅显示原图"], index=1)
  185. zoom_level = st.sidebar.slider("图片缩放", 0.25, 2.0, 1.0, 0.25)
  186. show_line_numbers = st.sidebar.checkbox("显示线条编号", value=True)
  187. return line_width, display_mode, zoom_level, show_line_numbers
  188. def create_undo_redo_section():
  189. """创建撤销/重做区域"""
  190. from .state_manager import undo_last_action, redo_last_action
  191. from .drawing import clear_table_image_cache
  192. st.sidebar.divider()
  193. st.sidebar.subheader("↩️ 撤销/重做")
  194. col1, col2 = st.sidebar.columns(2)
  195. with col1:
  196. if st.button("↩️ 撤销", disabled=len(st.session_state.undo_stack) == 0):
  197. if undo_last_action():
  198. clear_table_image_cache()
  199. st.success("✅ 已撤销")
  200. st.rerun()
  201. with col2:
  202. if st.button("↪️ 重做", disabled=len(st.session_state.redo_stack) == 0):
  203. if redo_last_action():
  204. clear_table_image_cache()
  205. st.success("✅ 已重做")
  206. st.rerun()
  207. st.sidebar.info(f"📚 历史记录: {len(st.session_state.undo_stack)} 条")
  208. def create_analysis_section(y_tolerance, x_tolerance, min_row_height):
  209. """
  210. 创建分析区域
  211. Args:
  212. y_tolerance: Y轴聚类容差
  213. x_tolerance: X轴聚类容差
  214. min_row_height: 最小行高
  215. """
  216. if st.button("🔍 分析表格结构"):
  217. with st.spinner("分析中..."):
  218. try:
  219. generator = st.session_state.generator
  220. structure = generator.analyze_table_structure(
  221. y_tolerance=y_tolerance,
  222. x_tolerance=x_tolerance,
  223. min_row_height=min_row_height
  224. )
  225. if not structure:
  226. st.warning("⚠️ 未检测到表格结构")
  227. st.stop()
  228. structure['modified_h_lines'] = set()
  229. structure['modified_v_lines'] = set()
  230. st.session_state.structure = structure
  231. st.session_state.undo_stack = []
  232. st.session_state.redo_stack = []
  233. clear_table_image_cache()
  234. st.success(
  235. f"✅ 检测到 {len(structure['rows'])} 行({len(structure['horizontal_lines'])} 条横线),"
  236. f"{len(structure['columns'])} 列({len(structure['vertical_lines'])} 条竖线)"
  237. )
  238. col1, col2, col3, col4 = st.columns(4)
  239. with col1:
  240. st.metric("行数", len(structure['rows']))
  241. with col2:
  242. st.metric("横线数", len(structure['horizontal_lines']))
  243. with col3:
  244. st.metric("列数", len(structure['columns']))
  245. with col4:
  246. st.metric("竖线数", len(structure['vertical_lines']))
  247. except Exception as e:
  248. st.error(f"❌ 分析失败: {e}")
  249. import traceback
  250. st.code(traceback.format_exc())
  251. st.stop()
  252. def create_save_section(work_mode, structure, image, line_width):
  253. """
  254. 创建保存区域
  255. Args:
  256. work_mode: 工作模式
  257. structure: 表格结构
  258. image: 图片
  259. line_width: 线条宽度
  260. """
  261. from .config_loader import save_structure_to_config
  262. from .drawing import draw_clean_table_lines
  263. import io
  264. st.divider()
  265. save_col1, save_col2, save_col3 = st.columns(3)
  266. with save_col1:
  267. save_structure = st.checkbox("保存表格结构配置", value=True)
  268. with save_col2:
  269. save_image = st.checkbox("保存表格线图片", value=True)
  270. with save_col3:
  271. line_color_option = st.selectbox(
  272. "保存时线条颜色",
  273. ["黑色", "蓝色", "红色"],
  274. index=0
  275. )
  276. if st.button("💾 保存", type="primary"):
  277. output_dir = Path("output/table_structures")
  278. output_dir.mkdir(parents=True, exist_ok=True)
  279. # 确定文件名
  280. if work_mode == "🆕 新建标注":
  281. if st.session_state.loaded_json_name:
  282. base_name = Path(st.session_state.loaded_json_name).stem
  283. else:
  284. base_name = "table_structure"
  285. else:
  286. if st.session_state.loaded_config_name:
  287. base_name = Path(st.session_state.loaded_config_name).stem
  288. if base_name.endswith('_structure'):
  289. base_name = base_name[:-10]
  290. elif st.session_state.loaded_image_name:
  291. base_name = Path(st.session_state.loaded_image_name).stem
  292. else:
  293. base_name = "table_structure"
  294. saved_files = []
  295. if save_structure:
  296. structure_path = output_dir / f"{base_name}_structure.json"
  297. save_structure_to_config(structure, structure_path)
  298. saved_files.append(("配置文件", structure_path))
  299. with open(structure_path, 'r') as f:
  300. st.download_button(
  301. "📥 下载配置文件",
  302. f.read(),
  303. file_name=f"{base_name}_structure.json",
  304. mime="application/json"
  305. )
  306. if save_image:
  307. if st.session_state.image is None:
  308. st.warning("⚠️ 无法保存图片:未加载图片文件")
  309. else:
  310. color_map = {
  311. "黑色": (0, 0, 0),
  312. "蓝色": (0, 0, 255),
  313. "红色": (255, 0, 0)
  314. }
  315. selected_color = color_map[line_color_option]
  316. clean_img = draw_clean_table_lines(
  317. st.session_state.image,
  318. structure,
  319. line_width=line_width,
  320. line_color=selected_color
  321. )
  322. output_image_path = output_dir / f"{base_name}_with_lines.png"
  323. clean_img.save(output_image_path)
  324. saved_files.append(("表格线图片", output_image_path))
  325. buf = io.BytesIO()
  326. clean_img.save(buf, format='PNG')
  327. buf.seek(0)
  328. st.download_button(
  329. "📥 下载表格线图片",
  330. buf,
  331. file_name=f"{base_name}_with_lines.png",
  332. mime="image/png"
  333. )
  334. if saved_files:
  335. st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
  336. for file_type, file_path in saved_files:
  337. st.info(f" • {file_type}: {file_path}")