batch_template_controls.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. """
  2. 批量模板应用控件
  3. """
  4. import streamlit as st
  5. import json
  6. from pathlib import Path
  7. from PIL import Image
  8. from typing import Dict, List
  9. import sys
  10. # 添加父目录到路径
  11. sys.path.insert(0, str(Path(__file__).parent.parent))
  12. from table_template_applier import TableTemplateApplier
  13. from table_line_generator import TableLineGenerator
  14. def create_batch_template_section(current_line_width: int, current_line_color: str):
  15. """
  16. 创建批量应用模板的控制区域
  17. Args:
  18. current_line_width: 当前页使用的线条宽度
  19. current_line_color: 当前页使用的线条颜色名称
  20. 要求:
  21. - 当前在目录模式
  22. - 已有标注(edit 模式)
  23. - 有可用的目录清单
  24. """
  25. # 检查前置条件
  26. if 'structure' not in st.session_state or not st.session_state.structure:
  27. return
  28. if 'current_catalog' not in st.session_state:
  29. return
  30. if 'current_output_config' not in st.session_state:
  31. return
  32. # 🔑 检查当前页是否有保存的结构文件
  33. if 'loaded_config_name' not in st.session_state or not st.session_state.loaded_config_name:
  34. st.warning("⚠️ 当前页未保存结构文件,请先保存后再批量应用")
  35. return
  36. st.divider()
  37. st.subheader("🔄 批量应用模板")
  38. catalog = st.session_state.current_catalog
  39. current_index = st.session_state.get('current_catalog_index', 0)
  40. current_entry = catalog[current_index]
  41. # 统计信息
  42. total_files = len(catalog)
  43. current_page = current_entry["index"]
  44. # 找出哪些页面还没有标注
  45. output_config = st.session_state.current_output_config
  46. output_dir = Path(output_config.get("directory", "output/table_structures"))
  47. structure_suffix = output_config.get("structure_suffix", "_structure.json")
  48. # 🔑 获取当前页的结构文件路径
  49. current_base_name = current_entry["json"].stem
  50. current_structure_file = output_dir / f"{current_base_name}{structure_suffix}"
  51. if not current_structure_file.exists():
  52. st.error("❌ 未找到当前页的结构文件,请先保存")
  53. st.info(f"期望文件: {current_structure_file}")
  54. return
  55. unlabeled_pages = []
  56. for entry in catalog:
  57. if entry["index"] == current_page:
  58. continue # 跳过当前页
  59. structure_file = output_dir / f"{entry['json'].stem}{structure_suffix}"
  60. if not structure_file.exists():
  61. unlabeled_pages.append(entry)
  62. st.info(
  63. f"📊 当前页: {current_page}/{total_files}\n\n"
  64. f"📄 模板文件: {current_structure_file.name}\n\n"
  65. f"✅ 已标注: {total_files - len(unlabeled_pages)} 页\n\n"
  66. f"⏳ 待处理: {len(unlabeled_pages)} 页"
  67. )
  68. if len(unlabeled_pages) == 0:
  69. st.success("🎉 所有页面都已标注!")
  70. return
  71. # 🔑 使用当前页的设置
  72. st.info(
  73. f"🎨 将使用当前页设置:\n\n"
  74. f"• 线条宽度: {current_line_width}px\n\n"
  75. f"• 线条颜色: {current_line_color}"
  76. )
  77. # 获取颜色配置
  78. line_colors = output_config.get("line_colors") or [
  79. {"name": "黑色", "rgb": [0, 0, 0]},
  80. {"name": "蓝色", "rgb": [0, 0, 255]},
  81. {"name": "红色", "rgb": [255, 0, 0]},
  82. ]
  83. # 🔑 从颜色名称映射到 RGB
  84. color_map = {c["name"]: tuple(c["rgb"]) for c in line_colors}
  85. line_color = color_map.get(current_line_color, (0, 0, 0))
  86. # 应用按钮
  87. if st.button("🚀 批量应用到所有未标注页面", type="primary"):
  88. _apply_template_batch(
  89. current_structure_file, # 🔑 直接使用保存的结构文件
  90. current_entry,
  91. unlabeled_pages,
  92. output_dir,
  93. structure_suffix,
  94. current_line_width,
  95. line_color
  96. )
  97. def _apply_template_batch(
  98. template_file: Path, # 🔑 改为直接传入模板文件路径
  99. template_entry: Dict,
  100. target_entries: List[Dict],
  101. output_dir: Path,
  102. structure_suffix: str,
  103. line_width: int,
  104. line_color: tuple
  105. ):
  106. """
  107. 执行批量应用模板
  108. Args:
  109. template_file: 模板结构文件路径
  110. template_entry: 模板页面条目
  111. target_entries: 目标页面列表
  112. output_dir: 输出目录
  113. structure_suffix: 结构文件后缀
  114. line_width: 线条宽度
  115. line_color: 线条颜色 (r, g, b)
  116. """
  117. try:
  118. # 🔑 直接使用保存的结构文件创建模板应用器
  119. applier = TableTemplateApplier(str(template_file))
  120. st.info(f"📋 使用模板: {template_file.name}")
  121. # 进度条
  122. progress_bar = st.progress(0)
  123. status_text = st.empty()
  124. success_count = 0
  125. failed_count = 0
  126. results = []
  127. for idx, entry in enumerate(target_entries):
  128. # 更新进度
  129. progress = (idx + 1) / len(target_entries)
  130. progress_bar.progress(progress)
  131. status_text.text(f"处理中: {entry['display']} ({idx + 1}/{len(target_entries)})")
  132. try:
  133. # 加载 OCR 数据
  134. with open(entry["json"], "r", encoding="utf-8") as fp:
  135. raw = json.load(fp)
  136. # 解析 OCR 数据
  137. if 'parsing_res_list' in raw and 'overall_ocr_res' in raw:
  138. table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw)
  139. else:
  140. raise ValueError("不支持的 OCR 格式")
  141. # 加载图片
  142. if entry["image"] and entry["image"].exists():
  143. image = Image.open(entry["image"])
  144. else:
  145. st.warning(f"⚠️ 跳过 {entry['display']}: 未找到图片")
  146. failed_count += 1
  147. results.append({
  148. 'page': entry['index'],
  149. 'status': 'skipped',
  150. 'reason': 'no_image'
  151. })
  152. continue
  153. # 应用模板生成图片
  154. img_with_lines = applier.apply_to_image(
  155. image,
  156. ocr_data,
  157. line_width=line_width,
  158. line_color=line_color
  159. )
  160. # 生成结构配置
  161. structure = applier.generate_structure_for_image(ocr_data)
  162. # 保存图片
  163. base_name = entry["json"].stem
  164. image_suffix = st.session_state.current_output_config.get("image_suffix", ".png")
  165. output_image_path = output_dir / f"{base_name}{image_suffix}"
  166. img_with_lines.save(output_image_path)
  167. # 🔑 保存结构(确保 set 转为 list)
  168. structure_path = output_dir / f"{base_name}{structure_suffix}"
  169. with open(structure_path, 'w', encoding='utf-8') as f:
  170. json.dump(structure, f, indent=2, ensure_ascii=False)
  171. success_count += 1
  172. results.append({
  173. 'page': entry['index'],
  174. 'status': 'success',
  175. 'image': str(output_image_path),
  176. 'structure': str(structure_path)
  177. })
  178. except Exception as e:
  179. failed_count += 1
  180. results.append({
  181. 'page': entry['index'],
  182. 'status': 'error',
  183. 'error': str(e)
  184. })
  185. st.error(f"❌ 处理失败 {entry['display']}: {e}")
  186. # 完成
  187. progress_bar.progress(1.0)
  188. status_text.empty()
  189. # 保存批处理结果
  190. batch_result_path = output_dir / "batch_results.json"
  191. with open(batch_result_path, 'w', encoding='utf-8') as f:
  192. json.dump({
  193. 'template': template_entry['display'],
  194. 'template_file': str(template_file),
  195. 'total': len(target_entries),
  196. 'success': success_count,
  197. 'failed': failed_count,
  198. 'line_width': line_width,
  199. 'line_color': line_color,
  200. 'results': results
  201. }, f, indent=2, ensure_ascii=False)
  202. # 显示结果
  203. if success_count > 0:
  204. st.success(
  205. f"✅ 批量应用完成!\n\n"
  206. f"成功: {success_count} 页\n\n"
  207. f"失败: {failed_count} 页"
  208. )
  209. # 🔑 提供下载批处理结果
  210. with open(batch_result_path, 'r', encoding='utf-8') as f:
  211. st.download_button(
  212. "📥 下载批处理报告",
  213. f.read(),
  214. file_name="batch_results.json",
  215. mime="application/json"
  216. )
  217. else:
  218. st.error("❌ 批量应用失败,没有成功处理任何页面")
  219. # 显示详细结果
  220. with st.expander("📋 详细结果"):
  221. for result in results:
  222. if result['status'] == 'success':
  223. st.success(f"✅ 第 {result['page']} 页")
  224. elif result['status'] == 'error':
  225. st.error(f"❌ 第 {result['page']} 页: {result.get('error', '未知错误')}")
  226. else:
  227. st.warning(f"⚠️ 第 {result['page']} 页: {result.get('reason', '跳过')}")
  228. except Exception as e:
  229. st.error(f"❌ 批量应用过程中发生错误: {e}")
  230. import traceback
  231. with st.expander("🔍 详细错误信息"):
  232. st.code(traceback.format_exc())