ui_components_v1.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634
  1. """
  2. UI 组件
  3. """
  4. import streamlit as st
  5. import json
  6. from pathlib import Path
  7. from PIL import Image
  8. import tempfile
  9. from typing import Dict, List
  10. try:
  11. from ..table_line_generator import TableLineGenerator
  12. except ImportError:
  13. from table_line_generator import TableLineGenerator
  14. from .config_loader import load_structure_from_config, build_data_source_catalog
  15. from .drawing import clear_table_image_cache
  16. def create_file_uploader_section(work_mode: str):
  17. """
  18. 创建文件上传区域
  19. Args:
  20. work_mode: 工作模式("🆕 新建标注" 或 "📂 加载已有标注")
  21. """
  22. if work_mode == "🆕 新建标注":
  23. st.sidebar.subheader("上传文件")
  24. uploaded_json = st.sidebar.file_uploader("上传OCR结果JSON", type=['json'], key="new_json")
  25. uploaded_image = st.sidebar.file_uploader("上传对应图片", type=['jpg', 'png'], key="new_image")
  26. # 处理 JSON 上传
  27. if uploaded_json is not None:
  28. if st.session_state.loaded_json_name != uploaded_json.name:
  29. try:
  30. raw_data = json.load(uploaded_json)
  31. with st.expander("🔍 原始数据结构"):
  32. if isinstance(raw_data, dict):
  33. st.json({k: f"<{type(v).__name__}>" if not isinstance(v, (str, int, float, bool, type(None))) else v
  34. for k, v in list(raw_data.items())[:5]})
  35. else:
  36. st.json(raw_data[:3] if len(raw_data) > 3 else raw_data)
  37. ocr_data = parse_ocr_data(raw_data)
  38. if not ocr_data:
  39. st.error("❌ 无法解析 OCR 数据,请检查 JSON 格式")
  40. st.stop()
  41. st.session_state.ocr_data = ocr_data
  42. st.session_state.loaded_json_name = uploaded_json.name
  43. st.session_state.loaded_config_name = None
  44. # 清除旧数据
  45. if 'structure' in st.session_state:
  46. del st.session_state.structure
  47. if 'generator' in st.session_state:
  48. del st.session_state.generator
  49. st.session_state.undo_stack = []
  50. st.session_state.redo_stack = []
  51. clear_table_image_cache()
  52. st.success(f"✅ 成功加载 {len(ocr_data)} 条 OCR 记录")
  53. except Exception as e:
  54. st.error(f"❌ 加载数据失败: {e}")
  55. st.stop()
  56. # 处理图片上传
  57. if uploaded_image is not None:
  58. if st.session_state.loaded_image_name != uploaded_image.name:
  59. try:
  60. image = Image.open(uploaded_image)
  61. st.session_state.image = image
  62. st.session_state.loaded_image_name = uploaded_image.name
  63. # 清除旧数据
  64. if 'structure' in st.session_state:
  65. del st.session_state.structure
  66. if 'generator' in st.session_state:
  67. del st.session_state.generator
  68. st.session_state.undo_stack = []
  69. st.session_state.redo_stack = []
  70. clear_table_image_cache()
  71. st.success(f"✅ 成功加载图片: {uploaded_image.name}")
  72. except Exception as e:
  73. st.error(f"❌ 加载图片失败: {e}")
  74. st.stop()
  75. else: # 加载已有标注
  76. st.sidebar.subheader("加载已保存的标注")
  77. uploaded_config = st.sidebar.file_uploader(
  78. "上传配置文件 (*_structure.json)",
  79. type=['json'],
  80. key="load_config"
  81. )
  82. uploaded_image_for_config = st.sidebar.file_uploader(
  83. "上传对应图片(可选)",
  84. type=['jpg', 'png'],
  85. key="load_image"
  86. )
  87. # 处理配置文件加载
  88. if uploaded_config is not None:
  89. if st.session_state.loaded_config_name != uploaded_config.name:
  90. try:
  91. # 创建临时文件
  92. with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as tmp:
  93. tmp.write(uploaded_config.getvalue().decode('utf-8'))
  94. tmp_path = tmp.name
  95. # 加载结构
  96. structure = load_structure_from_config(Path(tmp_path))
  97. # 清理临时文件
  98. Path(tmp_path).unlink()
  99. st.session_state.structure = structure
  100. st.session_state.loaded_config_name = uploaded_config.name
  101. # 清除历史记录和缓存
  102. st.session_state.undo_stack = []
  103. st.session_state.redo_stack = []
  104. clear_table_image_cache()
  105. st.success(f"✅ 成功加载配置: {uploaded_config.name}")
  106. st.info(
  107. f"📊 表格结构: {len(structure['rows'])}行 x {len(structure['columns'])}列\n\n"
  108. f"📏 横线数: {len(structure.get('horizontal_lines', []))}\n\n"
  109. f"📏 竖线数: {len(structure.get('vertical_lines', []))}"
  110. )
  111. # 显示配置文件详情
  112. with st.expander("📋 配置详情"):
  113. st.json({
  114. "行数": len(structure['rows']),
  115. "列数": len(structure['columns']),
  116. "横线数": len(structure.get('horizontal_lines', [])),
  117. "竖线数": len(structure.get('vertical_lines', [])),
  118. "行高": structure.get('row_height'),
  119. "列宽": structure.get('col_widths'),
  120. "已修改的横线": list(structure.get('modified_h_lines', set())),
  121. "已修改的竖线": list(structure.get('modified_v_lines', set()))
  122. })
  123. except Exception as e:
  124. st.error(f"❌ 加载配置失败: {e}")
  125. import traceback
  126. st.code(traceback.format_exc())
  127. st.stop()
  128. # 处理图片加载
  129. if uploaded_image_for_config is not None:
  130. if st.session_state.loaded_image_name != uploaded_image_for_config.name:
  131. try:
  132. image = Image.open(uploaded_image_for_config)
  133. st.session_state.image = image
  134. st.session_state.loaded_image_name = uploaded_image_for_config.name
  135. clear_table_image_cache()
  136. st.success(f"✅ 成功加载图片: {uploaded_image_for_config.name}")
  137. except Exception as e:
  138. st.error(f"❌ 加载图片失败: {e}")
  139. st.stop()
  140. # 提示信息
  141. if 'structure' in st.session_state and st.session_state.image is None:
  142. st.warning("⚠️ 已加载配置,但未加载对应图片。请上传图片以查看效果。")
  143. st.info("💡 提示:配置文件已加载,您可以:\n1. 上传对应图片查看效果\n2. 直接编辑配置并保存")
  144. def create_display_settings_section(display_config: Dict):
  145. """显示设置(由配置驱动)"""
  146. st.sidebar.divider()
  147. st.sidebar.subheader("🖼️ 显示设置")
  148. line_width = st.sidebar.slider(
  149. "线条宽度",
  150. int(display_config.get("line_width_min", 1)),
  151. int(display_config.get("line_width_max", 5)),
  152. int(display_config.get("default_line_width", 2)),
  153. )
  154. display_mode = st.sidebar.radio(
  155. "显示模式",
  156. ["对比显示", "仅显示划线图", "仅显示原图"],
  157. index=1,
  158. )
  159. zoom_level = st.sidebar.slider(
  160. "图片缩放",
  161. float(display_config.get("zoom_min", 0.25)),
  162. float(display_config.get("zoom_max", 2.0)),
  163. float(display_config.get("default_zoom", 1.0)),
  164. float(display_config.get("zoom_step", 0.25)),
  165. )
  166. show_line_numbers = st.sidebar.checkbox(
  167. "显示线条编号",
  168. value=bool(display_config.get("show_line_numbers", True)),
  169. )
  170. return line_width, display_mode, zoom_level, show_line_numbers
  171. def create_undo_redo_section():
  172. """创建撤销/重做区域"""
  173. from .state_manager import undo_last_action, redo_last_action
  174. from .drawing import clear_table_image_cache
  175. st.sidebar.divider()
  176. st.sidebar.subheader("↩️ 撤销/重做")
  177. col1, col2 = st.sidebar.columns(2)
  178. with col1:
  179. if st.button("↩️ 撤销", disabled=len(st.session_state.undo_stack) == 0):
  180. if undo_last_action():
  181. clear_table_image_cache()
  182. st.success("✅ 已撤销")
  183. st.rerun()
  184. with col2:
  185. if st.button("↪️ 重做", disabled=len(st.session_state.redo_stack) == 0):
  186. if redo_last_action():
  187. clear_table_image_cache()
  188. st.success("✅ 已重做")
  189. st.rerun()
  190. st.sidebar.info(f"📚 历史记录: {len(st.session_state.undo_stack)} 条")
  191. def create_analysis_section(y_tolerance, x_tolerance, min_row_height):
  192. """
  193. 创建分析区域
  194. Args:
  195. y_tolerance: Y轴聚类容差
  196. x_tolerance: X轴聚类容差
  197. min_row_height: 最小行高
  198. """
  199. if st.button("🔍 分析表格结构"):
  200. with st.spinner("分析中..."):
  201. try:
  202. generator = st.session_state.generator
  203. structure = generator.analyze_table_structure(
  204. y_tolerance=y_tolerance,
  205. x_tolerance=x_tolerance,
  206. min_row_height=min_row_height
  207. )
  208. if not structure:
  209. st.warning("⚠️ 未检测到表格结构")
  210. st.stop()
  211. structure['modified_h_lines'] = set()
  212. structure['modified_v_lines'] = set()
  213. st.session_state.structure = structure
  214. st.session_state.undo_stack = []
  215. st.session_state.redo_stack = []
  216. clear_table_image_cache()
  217. st.success(
  218. f"✅ 检测到 {len(structure['rows'])} 行({len(structure['horizontal_lines'])} 条横线),"
  219. f"{len(structure['columns'])} 列({len(structure['vertical_lines'])} 条竖线)"
  220. )
  221. col1, col2, col3, col4 = st.columns(4)
  222. with col1:
  223. st.metric("行数", len(structure['rows']))
  224. with col2:
  225. st.metric("横线数", len(structure['horizontal_lines']))
  226. with col3:
  227. st.metric("列数", len(structure['columns']))
  228. with col4:
  229. st.metric("竖线数", len(structure['vertical_lines']))
  230. except Exception as e:
  231. st.error(f"❌ 分析失败: {e}")
  232. import traceback
  233. st.code(traceback.format_exc())
  234. st.stop()
  235. def create_save_section(work_mode, structure, image, line_width, output_config: Dict):
  236. """
  237. 保存设置(目录/命名来自配置)
  238. """
  239. from .config_loader import save_structure_to_config
  240. from .drawing import draw_clean_table_lines
  241. import io
  242. st.divider()
  243. defaults = output_config.get("defaults", {})
  244. line_colors = output_config.get("line_colors") or [
  245. {"name": "黑色", "rgb": [0, 0, 0]},
  246. {"name": "蓝色", "rgb": [0, 0, 255]},
  247. {"name": "红色", "rgb": [255, 0, 0]},
  248. ]
  249. save_col1, save_col2, save_col3 = st.columns(3)
  250. with save_col1:
  251. save_structure = st.checkbox(
  252. "保存表格结构配置",
  253. value=bool(defaults.get("save_structure", True)),
  254. )
  255. with save_col2:
  256. save_image = st.checkbox(
  257. "保存表格线图片",
  258. value=bool(defaults.get("save_image", True)),
  259. )
  260. color_names = [c["name"] for c in line_colors]
  261. default_color = defaults.get("line_color", color_names[0])
  262. default_index = color_names.index(default_color) if default_color in color_names else 0
  263. with save_col3:
  264. line_color_option = st.selectbox(
  265. "保存时线条颜色",
  266. color_names,
  267. label_visibility="collapsed",
  268. index=default_index,
  269. )
  270. if st.button("💾 保存", type="primary"):
  271. output_dir = Path(output_config.get("directory", "output/table_structures"))
  272. output_dir.mkdir(parents=True, exist_ok=True)
  273. structure_suffix = output_config.get("structure_suffix", "_structure.json")
  274. image_suffix = output_config.get("image_suffix", "_with_lines.png")
  275. # 确定文件名
  276. if work_mode == "🆕 新建标注":
  277. if st.session_state.loaded_json_name:
  278. base_name = Path(st.session_state.loaded_json_name).stem
  279. else:
  280. base_name = "table_structure"
  281. else:
  282. if st.session_state.loaded_config_name:
  283. base_name = Path(st.session_state.loaded_config_name).stem
  284. if base_name.endswith('_structure'):
  285. base_name = base_name[:-10]
  286. elif st.session_state.loaded_image_name:
  287. base_name = Path(st.session_state.loaded_image_name).stem
  288. else:
  289. base_name = "table_structure"
  290. saved_files = []
  291. if save_structure:
  292. structure_filename = f"{base_name}{structure_suffix}"
  293. structure_path = output_dir / structure_filename
  294. save_structure_to_config(structure, structure_path)
  295. saved_files.append(("配置文件", structure_path))
  296. with open(structure_path, 'r') as f:
  297. st.download_button(
  298. "📥 下载配置文件",
  299. f.read(),
  300. file_name=f"{base_name}_structure.json",
  301. mime="application/json"
  302. )
  303. if save_image:
  304. if st.session_state.image is None:
  305. st.warning("⚠️ 无法保存图片:未加载图片文件")
  306. else:
  307. selected_color_rgb = next(
  308. (tuple(c["rgb"]) for c in line_colors if c["name"] == line_color_option),
  309. (0, 0, 0),
  310. )
  311. clean_img = draw_clean_table_lines(
  312. st.session_state.image,
  313. structure,
  314. line_width=line_width,
  315. line_color=selected_color_rgb,
  316. )
  317. image_filename = f"{base_name}{image_suffix}"
  318. output_image_path = output_dir / image_filename
  319. clean_img.save(output_image_path)
  320. saved_files.append(("表格线图片", output_image_path))
  321. buf = io.BytesIO()
  322. clean_img.save(buf, format='PNG')
  323. buf.seek(0)
  324. st.download_button(
  325. "📥 下载表格线图片",
  326. buf,
  327. file_name=f"{base_name}_with_lines.png",
  328. mime="image/png"
  329. )
  330. if saved_files:
  331. st.success(f"✅ 已保存 {len(saved_files)} 个文件:")
  332. for file_type, file_path in saved_files:
  333. st.info(f" • {file_type}: {file_path}")
  334. def setup_new_annotation_mode(ocr_data, image, config: Dict):
  335. """
  336. 设置新建标注模式的通用逻辑
  337. Args:
  338. ocr_data: OCR 数据
  339. image: 图片对象
  340. config: 显示配置
  341. Returns:
  342. tuple: (y_tolerance, x_tolerance, min_row_height, line_width, display_mode, zoom_level, show_line_numbers)
  343. """
  344. # 参数调整
  345. st.sidebar.header("🔧 参数调整")
  346. y_tolerance = st.sidebar.slider("Y轴聚类容差(像素)", 1, 20, 5, key="new_y_tol")
  347. x_tolerance = st.sidebar.slider("X轴聚类容差(像素)", 5, 50, 10, key="new_x_tol")
  348. min_row_height = st.sidebar.slider("最小行高(像素)", 10, 100, 20, key="new_min_h")
  349. # 显示设置
  350. line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(config)
  351. create_undo_redo_section()
  352. # 初始化生成器
  353. if 'generator' not in st.session_state or st.session_state.generator is None:
  354. try:
  355. generator = TableLineGenerator(image, ocr_data)
  356. st.session_state.generator = generator
  357. except Exception as e:
  358. st.error(f"❌ 初始化生成器失败: {e}")
  359. st.stop()
  360. # 分析按钮
  361. create_analysis_section(y_tolerance, x_tolerance, min_row_height)
  362. return y_tolerance, x_tolerance, min_row_height, line_width, display_mode, zoom_level, show_line_numbers
  363. def setup_edit_annotation_mode(structure, image, config: Dict):
  364. """
  365. 设置编辑标注模式的通用逻辑
  366. Args:
  367. structure: 表格结构
  368. image: 图片对象(可为 None)
  369. config: 显示配置
  370. Returns:
  371. tuple: (image, line_width, display_mode, zoom_level, show_line_numbers)
  372. """
  373. # 如果没有图片,创建虚拟画布
  374. if image is None:
  375. if 'table_bbox' in structure:
  376. bbox = structure['table_bbox']
  377. dummy_width = bbox[2] + 100
  378. dummy_height = bbox[3] + 100
  379. else:
  380. dummy_width = 2000
  381. dummy_height = 2000
  382. image = Image.new('RGB', (dummy_width, dummy_height), color='white')
  383. st.info(f"💡 使用虚拟画布 ({dummy_width}x{dummy_height})")
  384. # 显示设置
  385. line_width, display_mode, zoom_level, show_line_numbers = create_display_settings_section(config)
  386. create_undo_redo_section()
  387. return image, line_width, display_mode, zoom_level, show_line_numbers
  388. def render_table_structure_view(structure, image, line_width, display_mode, zoom_level, show_line_numbers,
  389. viewport_width, viewport_height):
  390. """
  391. 渲染表格结构视图(统一三种模式的显示逻辑)
  392. Args:
  393. structure: 表格结构
  394. image: 图片对象
  395. line_width: 线条宽度
  396. display_mode: 显示模式
  397. zoom_level: 缩放级别
  398. show_line_numbers: 是否显示线条编号
  399. viewport_width: 视口宽度
  400. viewport_height: 视口高度
  401. """
  402. # 绘制表格线
  403. img_with_lines = get_cached_table_lines_image(
  404. image, structure, line_width=line_width, show_numbers=show_line_numbers
  405. )
  406. # 根据显示模式显示图片
  407. if display_mode == "对比显示":
  408. col1, col2 = st.columns(2)
  409. with col1:
  410. show_image_with_scroll(image, "原图", viewport_width, viewport_height, zoom_level)
  411. with col2:
  412. show_image_with_scroll(img_with_lines, "表格线", viewport_width, viewport_height, zoom_level)
  413. elif display_mode == "仅显示划线图":
  414. show_image_with_scroll(
  415. img_with_lines,
  416. f"表格线图 (缩放: {zoom_level:.0%})",
  417. viewport_width,
  418. viewport_height,
  419. zoom_level
  420. )
  421. else:
  422. show_image_with_scroll(
  423. image,
  424. f"原图 (缩放: {zoom_level:.0%})",
  425. viewport_width,
  426. viewport_height,
  427. zoom_level
  428. )
  429. # 手动调整区域
  430. create_adjustment_section(structure)
  431. # 显示详细信息
  432. with st.expander("📊 表格结构详情"):
  433. st.json({
  434. "行数": len(structure['rows']),
  435. "列数": len(structure['columns']),
  436. "横线数": len(structure.get('horizontal_lines', [])),
  437. "竖线数": len(structure.get('vertical_lines', [])),
  438. "横线坐标": structure.get('horizontal_lines', []),
  439. "竖线坐标": structure.get('vertical_lines', []),
  440. "标准行高": structure.get('row_height'),
  441. "列宽度": structure.get('col_widths'),
  442. "修改的横线": list(structure.get('modified_h_lines', set())),
  443. "修改的竖线": list(structure.get('modified_v_lines', set()))
  444. })
  445. def create_directory_selector(data_sources: List[Dict], global_output_config: Dict):
  446. """目录模式选择器(优化:避免重复加载)"""
  447. st.sidebar.subheader("目录模式")
  448. source_names = [src["name"] for src in data_sources]
  449. selected_name = st.sidebar.selectbox("选择数据源", source_names, key="dir_mode_source")
  450. source_cfg = next(src for src in data_sources if src["name"] == selected_name)
  451. output_cfg = source_cfg.get("output", global_output_config)
  452. output_dir = Path(output_cfg.get("directory", "output/table_structures"))
  453. structure_suffix = output_cfg.get("structure_suffix", "_structure.json")
  454. catalog_key = f"catalog::{selected_name}"
  455. if catalog_key not in st.session_state:
  456. st.session_state[catalog_key] = build_data_source_catalog(source_cfg)
  457. catalog = st.session_state[catalog_key]
  458. if not catalog:
  459. st.sidebar.warning("目录中没有 JSON 文件")
  460. return
  461. if 'dir_selected_index' not in st.session_state:
  462. st.session_state.dir_selected_index = 0
  463. selected = st.sidebar.selectbox(
  464. "选择文件",
  465. range(len(catalog)),
  466. format_func=lambda i: catalog[i]["display"],
  467. index=st.session_state.dir_selected_index,
  468. key="dir_select_box"
  469. )
  470. page_input = st.sidebar.number_input(
  471. "页码跳转",
  472. min_value=1,
  473. max_value=len(catalog),
  474. value=catalog[selected]["index"],
  475. step=1,
  476. key="dir_page_input"
  477. )
  478. # 🔑 关键优化:只在切换文件时才重新加载
  479. current_entry_key = f"{selected_name}::{catalog[selected]['json']}"
  480. if 'last_loaded_entry' not in st.session_state or st.session_state.last_loaded_entry != current_entry_key:
  481. # 文件切换,重新加载
  482. entry = catalog[selected]
  483. base_name = entry["json"].stem
  484. structure_file = output_dir / f"{base_name}{structure_suffix}"
  485. has_structure = structure_file.exists()
  486. # 📂 加载 JSON
  487. with open(entry["json"], "r", encoding="utf-8") as fp:
  488. raw = json.load(fp)
  489. st.session_state.ocr_data = parse_ocr_data(raw)
  490. st.session_state.loaded_json_name = entry["json"].name
  491. # 🖼️ 加载图片
  492. if entry["image"] and entry["image"].exists():
  493. st.session_state.image = Image.open(entry["image"])
  494. st.session_state.loaded_image_name = entry["image"].name
  495. else:
  496. st.session_state.image = None
  497. # 🎯 自动模式判断
  498. if has_structure:
  499. st.session_state.dir_auto_mode = "edit"
  500. st.session_state.loaded_config_name = structure_file.name
  501. try:
  502. structure = load_structure_from_config(structure_file)
  503. st.session_state.structure = structure
  504. st.session_state.undo_stack = []
  505. st.session_state.redo_stack = []
  506. clear_table_image_cache()
  507. st.sidebar.success(f"✅ 编辑模式")
  508. except Exception as e:
  509. st.error(f"❌ 加载标注失败: {e}")
  510. st.session_state.dir_auto_mode = "new"
  511. else:
  512. st.session_state.dir_auto_mode = "new"
  513. if 'structure' in st.session_state:
  514. del st.session_state.structure
  515. if 'generator' in st.session_state:
  516. del st.session_state.generator
  517. st.sidebar.info(f"🆕 新建模式")
  518. # 标记已加载
  519. st.session_state.last_loaded_entry = current_entry_key
  520. st.info(f"📂 已加载: {entry['json'].name}")
  521. # 页码跳转处理
  522. if page_input != catalog[selected]["index"]:
  523. target = next((i for i, item in enumerate(catalog) if item["index"] == page_input), None)
  524. if target is not None:
  525. st.session_state.dir_selected_index = target
  526. st.rerun()
  527. return st.session_state.get('dir_auto_mode', 'new')