table_template_applier.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878
  1. """
  2. 表格模板应用器
  3. 将人工标注的表格结构应用到其他页面
  4. """
  5. import json
  6. from pathlib import Path
  7. from PIL import Image, ImageDraw
  8. from typing import Dict, List, Tuple
  9. import numpy as np
  10. import argparse
  11. import sys
  12. # 添加父目录到路径
  13. sys.path.insert(0, str(Path(__file__).parent))
  14. try:
  15. from editor.data_processor import get_structure_from_ocr
  16. from table_line_generator import TableLineGenerator
  17. except ImportError:
  18. from .editor.data_processor import get_structure_from_ocr
  19. from .table_line_generator import TableLineGenerator
  20. class TableTemplateApplier:
  21. """表格模板应用器(混合模式)"""
  22. def __init__(self, template_config_path: str):
  23. """初始化时只提取列信息和表头信息"""
  24. with open(template_config_path, 'r', encoding='utf-8') as f:
  25. self.template = json.load(f)
  26. # ✅ 只提取列宽(固定)
  27. self.col_widths = self.template['col_widths']
  28. # ✅ 计算列的相对位置
  29. self.col_offsets = [0]
  30. for width in self.col_widths:
  31. self.col_offsets.append(self.col_offsets[-1] + width)
  32. # ✅ 提取表头高度(通常固定)
  33. rows = self.template['rows']
  34. if rows:
  35. self.header_height = rows[0]['y_end'] - rows[0]['y_start']
  36. else:
  37. self.header_height = 40
  38. # ✅ 计算数据行高度(用于固定行高模式)
  39. if len(rows) > 1:
  40. data_row_heights = [row['y_end'] - row['y_start'] for row in rows[1:]]
  41. # 使用中位数作为典型行高
  42. self.row_height = int(np.median(data_row_heights)) if data_row_heights else 40
  43. # 兜底行高(同样使用中位数)
  44. self.fallback_row_height = self.row_height
  45. else:
  46. # 如果只有表头,使用默认值
  47. self.row_height = 40
  48. self.fallback_row_height = 40
  49. print(f"\n✅ 加载模板配置:")
  50. print(f" 列数: {len(self.col_widths)}")
  51. print(f" 列宽: {self.col_widths}")
  52. print(f" 表头高度: {self.header_height}px")
  53. print(f" 数据行高: {self.row_height}px (用于固定行高模式)")
  54. print(f" 兜底行高: {self.fallback_row_height}px (OCR失败时使用)")
  55. def detect_table_anchor(self, ocr_data: List[Dict]) -> Tuple[int, int]:
  56. """
  57. 检测表格的锚点位置(表头左上角)
  58. 策略:
  59. 1. 找到Y坐标最小的文本框(表头第一行)
  60. 2. 找到X坐标最小的文本框(第一列)
  61. Args:
  62. ocr_data: OCR识别结果
  63. Returns:
  64. (anchor_x, anchor_y): 表格左上角坐标
  65. """
  66. if not ocr_data:
  67. return (0, 0)
  68. # 找到最小的X和Y坐标
  69. min_x = min(item['bbox'][0] for item in ocr_data)
  70. min_y = min(item['bbox'][1] for item in ocr_data)
  71. return (min_x, min_y)
  72. def detect_table_rows(self, ocr_data: List[Dict], header_y: int) -> int:
  73. """
  74. 检测表格的行数(包括表头)
  75. 策略:
  76. 1. 找到Y坐标最大的文本框
  77. 2. 根据数据行高计算行数
  78. 3. 加上表头行
  79. Args:
  80. ocr_data: OCR识别结果
  81. header_y: 表头起始Y坐标
  82. Returns:
  83. 总行数(包括表头)
  84. """
  85. if not ocr_data:
  86. return 1 # 至少有表头
  87. max_y = max(item['bbox'][3] for item in ocr_data)
  88. # 🔧 计算数据区的高度(排除表头)
  89. data_start_y = header_y + self.header_height
  90. data_height = max_y - data_start_y
  91. # 计算数据行数
  92. num_data_rows = max(int(data_height / self.row_height), 0)
  93. # 总行数 = 1行表头 + n行数据
  94. total_rows = 1 + num_data_rows
  95. print(f"📊 行数计算:")
  96. print(f" 表头Y: {header_y}, 数据区起始Y: {data_start_y}")
  97. print(f" 最大Y: {max_y}, 数据区高度: {data_height}px")
  98. print(f" 数据行数: {num_data_rows}, 总行数: {total_rows}")
  99. return total_rows
  100. def apply_template_fixed(self,
  101. image: Image.Image,
  102. ocr_data: List[Dict],
  103. anchor_x: int = None,
  104. anchor_y: int = None,
  105. num_rows: int = None,
  106. line_width: int = 2,
  107. line_color: Tuple[int, int, int] = (0, 0, 0)) -> Tuple[Image.Image, Dict]:
  108. """
  109. 将模板应用到图片
  110. Args:
  111. image: 目标图片
  112. ocr_data: OCR识别结果(用于自动检测锚点)
  113. anchor_x: 表格起始X坐标(None=自动检测)
  114. anchor_y: 表头起始Y坐标(None=自动检测)
  115. num_rows: 总行数(None=自动检测)
  116. line_width: 线条宽度
  117. line_color: 线条颜色
  118. Returns:
  119. 绘制了表格线的图片
  120. """
  121. img_with_lines = image.copy()
  122. draw = ImageDraw.Draw(img_with_lines)
  123. # 🔍 自动检测锚点
  124. if anchor_x is None or anchor_y is None:
  125. detected_x, detected_y = self.detect_table_anchor(ocr_data)
  126. anchor_x = anchor_x or detected_x
  127. anchor_y = anchor_y or detected_y
  128. # 🔍 自动检测行数
  129. if num_rows is None:
  130. num_rows = self.detect_table_rows(ocr_data, anchor_y)
  131. print(f"\n📍 表格锚点: ({anchor_x}, {anchor_y})")
  132. print(f"📊 总行数: {num_rows} (1表头 + {num_rows-1}数据)")
  133. # 🎨 生成横线坐标
  134. horizontal_lines = []
  135. # 第1条线:表头顶部
  136. horizontal_lines.append(anchor_y)
  137. # 第2条线:表头底部/数据区顶部
  138. horizontal_lines.append(anchor_y + self.header_height)
  139. # 后续横线:数据行分隔线
  140. current_y = anchor_y + self.header_height
  141. for i in range(num_rows - 1): # 减1因为表头已经占了1行
  142. current_y += self.row_height
  143. horizontal_lines.append(current_y)
  144. # 🎨 生成竖线坐标
  145. vertical_lines = []
  146. for offset in self.col_offsets:
  147. x = anchor_x + offset
  148. vertical_lines.append(x)
  149. print(f"📏 横线坐标: {horizontal_lines[:3]}... (共{len(horizontal_lines)}条)")
  150. print(f"📏 竖线坐标: {vertical_lines[:3]}... (共{len(vertical_lines)}条)")
  151. # 🖊️ 绘制横线
  152. x_start = vertical_lines[0]
  153. x_end = vertical_lines[-1]
  154. for y in horizontal_lines:
  155. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  156. # 🖊️ 绘制竖线
  157. y_start = horizontal_lines[0]
  158. y_end = horizontal_lines[-1]
  159. for x in vertical_lines:
  160. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  161. print(f"✅ 表格绘制完成: {len(horizontal_lines)}行 × {len(vertical_lines)-1}列")
  162. # 🔑 生成结构信息
  163. structure = self._build_structure(
  164. horizontal_lines,
  165. vertical_lines,
  166. anchor_x,
  167. anchor_y,
  168. mode='fixed'
  169. )
  170. return img_with_lines, structure
  171. def apply_template_hybrid(self,
  172. image: Image.Image,
  173. ocr_data_dict: Dict,
  174. use_ocr_rows: bool = True,
  175. anchor_x: int = None,
  176. anchor_y: int = None,
  177. y_tolerance: int = 5,
  178. line_width: int = 2,
  179. line_color: Tuple[int, int, int] = (0, 0, 0)) -> Tuple[Image.Image, Dict]:
  180. """
  181. 混合模式:使用模板的列 + OCR的行
  182. Args:
  183. image: 目标图片
  184. ocr_data: OCR识别结果(用于检测行)
  185. use_ocr_rows: 是否使用OCR检测的行(True=自适应行高)
  186. anchor_x: 表格起始X坐标(None=自动检测)
  187. anchor_y: 表头起始Y坐标(None=自动检测)
  188. y_tolerance: Y轴聚类容差(像素)
  189. line_width: 线条宽度
  190. line_color: 线条颜色
  191. Returns:
  192. 绘制了表格线的图片, 结构信息
  193. """
  194. img_with_lines = image.copy()
  195. draw = ImageDraw.Draw(img_with_lines)
  196. ocr_data = ocr_data_dict.get('text_boxes', [])
  197. # 🔍 自动检测锚点
  198. if anchor_x is None or anchor_y is None:
  199. detected_x, detected_y = self.detect_table_anchor(ocr_data)
  200. anchor_x = anchor_x or detected_x
  201. anchor_y = anchor_y or detected_y
  202. print(f"\n📍 表格锚点: ({anchor_x}, {anchor_y})")
  203. # ✅ 竖线:使用模板的列宽(固定)
  204. vertical_lines = [anchor_x + offset for offset in self.col_offsets]
  205. print(f"📏 竖线坐标: {vertical_lines} (使用模板,共{len(vertical_lines)}条)")
  206. # ✅ 横线:根据模式选择
  207. if use_ocr_rows and ocr_data:
  208. horizontal_lines = self._detect_rows_from_ocr(
  209. ocr_data, anchor_y, y_tolerance
  210. )
  211. print(f"📏 横线坐标: 使用OCR检测 (共{len(horizontal_lines)}条,自适应行高)")
  212. else:
  213. num_rows = self.detect_table_rows(ocr_data, anchor_y) if ocr_data else 10
  214. horizontal_lines = self._generate_fixed_rows(anchor_y, num_rows)
  215. print(f"📏 横线坐标: 使用固定行高 (共{len(horizontal_lines)}条)")
  216. # 🖊️ 绘制横线
  217. x_start = vertical_lines[0]
  218. x_end = vertical_lines[-1]
  219. for y in horizontal_lines:
  220. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  221. # 🖊️ 绘制竖线
  222. y_start = horizontal_lines[0]
  223. y_end = horizontal_lines[-1]
  224. for x in vertical_lines:
  225. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  226. print(f"✅ 表格绘制完成: {len(horizontal_lines)}行 × {len(vertical_lines)-1}列")
  227. # 🔑 生成结构信息
  228. structure = self._build_structure(
  229. horizontal_lines,
  230. vertical_lines,
  231. anchor_x,
  232. anchor_y,
  233. mode='hybrid'
  234. )
  235. return img_with_lines, structure
  236. def _detect_rows_from_ocr(self,
  237. ocr_data: List[Dict],
  238. anchor_y: int,
  239. y_tolerance: int = 5) -> List[int]:
  240. """
  241. 从OCR结果中检测行(自适应行高)
  242. 复用 get_structure_from_ocr 统一接口
  243. Args:
  244. ocr_data: OCR识别结果(MinerU 格式的 text_boxes)
  245. anchor_y: 表格起始Y坐标
  246. y_tolerance: Y轴聚类容差(未使用,保留参数兼容性)
  247. Returns:
  248. 横线 y 坐标列表
  249. """
  250. if not ocr_data:
  251. return [anchor_y, anchor_y + self.header_height]
  252. print(f"\n🔍 OCR行检测 (使用 MinerU 算法):")
  253. print(f" 有效文本框数: {len(ocr_data)}")
  254. # 🔑 验证是否为 MinerU 格式
  255. has_cell_index = any('row' in item and 'col' in item for item in ocr_data)
  256. if not has_cell_index:
  257. print(" ⚠️ 警告: OCR数据不包含 row/col 索引,可能不是 MinerU 格式")
  258. print(" ⚠️ 混合模式需要 MinerU 格式的 JSON 文件")
  259. return [anchor_y, anchor_y + self.header_height]
  260. # 🔑 重构原始数据格式(MinerU 需要完整的 table 结构)
  261. raw_data = {
  262. 'type': 'table',
  263. 'table_cells': ocr_data
  264. }
  265. try:
  266. # ✅ 使用统一接口解析和分析(无需 dummy_image)
  267. table_bbox, structure = get_structure_from_ocr(
  268. raw_data,
  269. tool="mineru"
  270. )
  271. if not structure or 'horizontal_lines' not in structure:
  272. print(" ⚠️ MinerU 分析失败,使用兜底方案")
  273. return [anchor_y, anchor_y + self.header_height]
  274. # 🔑 获取横线坐标
  275. horizontal_lines = structure['horizontal_lines']
  276. # 🔑 调整第一条线到 anchor_y(表头顶部)
  277. if horizontal_lines:
  278. offset = anchor_y - horizontal_lines[0]
  279. horizontal_lines = [y + offset for y in horizontal_lines]
  280. print(f" 检测到行数: {len(horizontal_lines) - 1}")
  281. # 🔑 分析行高分布
  282. if len(horizontal_lines) > 1:
  283. row_heights = []
  284. for i in range(len(horizontal_lines) - 1):
  285. h = horizontal_lines[i+1] - horizontal_lines[i]
  286. row_heights.append(h)
  287. if len(row_heights) > 1:
  288. import numpy as np
  289. print(f" 行高分布: min={min(row_heights)}, "
  290. f"median={int(np.median(row_heights))}, "
  291. f"max={max(row_heights)}")
  292. return horizontal_lines
  293. except Exception as e:
  294. print(f" ⚠️ 解析失败: {e}")
  295. import traceback
  296. traceback.print_exc()
  297. return [anchor_y, anchor_y + self.header_height]
  298. def _generate_fixed_rows(self, anchor_y: int, num_rows: int) -> List[int]:
  299. """生成固定行高的横线(兜底方案)"""
  300. horizontal_lines = [anchor_y]
  301. # 表头
  302. horizontal_lines.append(anchor_y + self.header_height)
  303. # 数据行
  304. current_y = anchor_y + self.header_height
  305. for i in range(num_rows - 1):
  306. current_y += self.fallback_row_height
  307. horizontal_lines.append(current_y)
  308. return horizontal_lines
  309. def _build_structure(self,
  310. horizontal_lines: List[int],
  311. vertical_lines: List[int],
  312. anchor_x: int,
  313. anchor_y: int,
  314. mode: str = 'fixed') -> Dict:
  315. """构建表格结构信息(统一)"""
  316. # 生成行区间
  317. rows = []
  318. for i in range(len(horizontal_lines) - 1):
  319. rows.append({
  320. 'y_start': horizontal_lines[i],
  321. 'y_end': horizontal_lines[i + 1],
  322. 'bboxes': []
  323. })
  324. # 生成列区间
  325. columns = []
  326. for i in range(len(vertical_lines) - 1):
  327. columns.append({
  328. 'x_start': vertical_lines[i],
  329. 'x_end': vertical_lines[i + 1]
  330. })
  331. # ✅ 根据模式设置正确的 mode 值
  332. if mode == 'hybrid':
  333. mode_value = 'hybrid'
  334. elif mode == 'fixed':
  335. mode_value = 'fixed'
  336. else:
  337. mode_value = mode # 保留原始值
  338. return {
  339. 'rows': rows,
  340. 'columns': columns,
  341. 'horizontal_lines': horizontal_lines,
  342. 'vertical_lines': vertical_lines,
  343. 'col_widths': self.col_widths,
  344. 'row_height': self.row_height if mode == 'fixed' else None,
  345. 'table_bbox': [
  346. vertical_lines[0],
  347. horizontal_lines[0],
  348. vertical_lines[-1],
  349. horizontal_lines[-1]
  350. ],
  351. 'mode': mode_value, # ✅ 确保有 mode 字段
  352. 'anchor': {'x': anchor_x, 'y': anchor_y},
  353. 'modified_h_lines': [], # ✅ 添加修改记录字段
  354. 'modified_v_lines': [] # ✅ 添加修改记录字段
  355. }
  356. def apply_template_to_single_file(
  357. applier: TableTemplateApplier,
  358. image_file: Path,
  359. json_file: Path,
  360. output_dir: Path,
  361. structure_suffix: str = "_structure.json",
  362. use_hybrid_mode: bool = True,
  363. line_width: int = 2,
  364. line_color: Tuple[int, int, int] = (0, 0, 0)
  365. ) -> bool:
  366. """
  367. 应用模板到单个文件
  368. Args:
  369. applier: 模板应用器实例
  370. image_file: 图片文件路径
  371. json_file: OCR JSON文件路径
  372. output_dir: 输出目录
  373. use_hybrid_mode: 是否使用混合模式(需要 MinerU 格式)
  374. line_width: 线条宽度
  375. line_color: 线条颜色
  376. Returns:
  377. 是否成功
  378. """
  379. print(f"📄 处理: {image_file.name}")
  380. try:
  381. # 加载OCR数据
  382. with open(json_file, 'r', encoding='utf-8') as f:
  383. raw_data = json.load(f)
  384. # 🔑 自动检测 OCR 格式
  385. ocr_format = None
  386. if 'parsing_res_list' in raw_data and 'overall_ocr_res' in raw_data:
  387. # PPStructure 格式
  388. ocr_format = 'ppstructure'
  389. elif isinstance(raw_data, (list, dict)):
  390. # 尝试提取 MinerU 格式
  391. table_data = None
  392. if isinstance(raw_data, list):
  393. for item in raw_data:
  394. if isinstance(item, dict) and item.get('type') == 'table':
  395. table_data = item
  396. break
  397. elif isinstance(raw_data, dict) and raw_data.get('type') == 'table':
  398. table_data = raw_data
  399. if table_data and 'table_cells' in table_data:
  400. ocr_format = 'mineru'
  401. else:
  402. raise ValueError("未识别的 OCR 格式")
  403. else:
  404. raise ValueError("未识别的 OCR 格式(仅支持 PPStructure 或 MinerU)")
  405. table_bbox, ocr_data = TableLineGenerator.parse_ocr_data(
  406. raw_data,
  407. tool=ocr_format
  408. )
  409. text_boxes = ocr_data.get('text_boxes', [])
  410. print(f" ✅ 加载OCR数据: {len(text_boxes)} 个文本框")
  411. print(f" 📋 OCR格式: {ocr_format}")
  412. # 加载图片
  413. image = Image.open(image_file)
  414. print(f" ✅ 加载图片: {image.size}")
  415. # 🔑 验证混合模式的格式要求
  416. if use_hybrid_mode and ocr_format != 'mineru':
  417. print(f" ⚠️ 警告: 混合模式需要 MinerU 格式,当前格式为 {ocr_format}")
  418. print(f" ℹ️ 自动切换到完全模板模式")
  419. use_hybrid_mode = False
  420. # 🆕 根据模式选择处理方式
  421. if use_hybrid_mode:
  422. print(f" 🔧 使用混合模式 (模板列 + MinerU 行)")
  423. img_with_lines, structure = applier.apply_template_hybrid(
  424. image,
  425. ocr_data,
  426. use_ocr_rows=True,
  427. line_width=line_width,
  428. line_color=line_color
  429. )
  430. else:
  431. print(f" 🔧 使用完全模板模式 (固定行高)")
  432. img_with_lines, structure = applier.apply_template_fixed(
  433. image,
  434. text_boxes,
  435. line_width=line_width,
  436. line_color=line_color
  437. )
  438. # 保存图片
  439. output_file = output_dir / f"{image_file.stem}.png"
  440. img_with_lines.save(output_file)
  441. # 保存结构配置
  442. structure_file = output_dir / f"{image_file.stem}{structure_suffix}"
  443. with open(structure_file, 'w', encoding='utf-8') as f:
  444. json.dump(structure, f, indent=2, ensure_ascii=False)
  445. print(f" ✅ 保存图片: {output_file.name}")
  446. print(f" ✅ 保存配置: {structure_file.name}")
  447. print(f" 📊 表格: {len(structure['rows'])}行 x {len(structure['columns'])}列")
  448. return True
  449. except Exception as e:
  450. print(f" ❌ 处理失败: {e}")
  451. import traceback
  452. traceback.print_exc()
  453. return False
  454. def apply_template_batch(
  455. template_config_path: str,
  456. image_dir: str,
  457. json_dir: str,
  458. output_dir: str,
  459. structure_suffix: str = "_structure.json",
  460. use_hybrid_mode: bool = False,
  461. line_width: int = 2,
  462. line_color: Tuple[int, int, int] = (0, 0, 0)
  463. ):
  464. """
  465. 批量应用模板到所有图片
  466. Args:
  467. template_config_path: 模板配置路径
  468. image_dir: 图片目录
  469. json_dir: OCR JSON目录
  470. output_dir: 输出目录
  471. line_width: 线条宽度
  472. line_color: 线条颜色
  473. """
  474. applier = TableTemplateApplier(template_config_path)
  475. image_path = Path(image_dir)
  476. json_path = Path(json_dir)
  477. output_path = Path(output_dir)
  478. output_path.mkdir(parents=True, exist_ok=True)
  479. # 查找所有图片
  480. image_files = list(image_path.glob("*.jpg")) + list(image_path.glob("*.png"))
  481. image_files.sort()
  482. print(f"\n🔍 找到 {len(image_files)} 个图片文件")
  483. print(f"📂 图片目录: {image_dir}")
  484. print(f"📂 JSON目录: {json_dir}")
  485. print(f"📂 输出目录: {output_dir}\n")
  486. results = []
  487. success_count = 0
  488. failed_count = 0
  489. for idx, image_file in enumerate(image_files, 1):
  490. print(f"\n{'='*60}")
  491. print(f"[{idx}/{len(image_files)}] 处理: {image_file.name}")
  492. print(f"{'='*60}")
  493. # 查找对应的JSON文件
  494. json_file = json_path / f"{image_file.stem}.json"
  495. if not json_file.exists():
  496. print(f"⚠️ 找不到OCR结果: {json_file.name}")
  497. results.append({
  498. 'source': str(image_file),
  499. 'status': 'skipped',
  500. 'reason': 'no_json'
  501. })
  502. failed_count += 1
  503. continue
  504. if apply_template_to_single_file(
  505. applier, image_file, json_file, output_path, structure_suffix, use_hybrid_mode,
  506. line_width, line_color
  507. ):
  508. results.append({
  509. 'source': str(image_file),
  510. 'json': str(json_file),
  511. 'status': 'success'
  512. })
  513. success_count += 1
  514. else:
  515. results.append({
  516. 'source': str(image_file),
  517. 'json': str(json_file),
  518. 'status': 'error'
  519. })
  520. failed_count += 1
  521. print()
  522. # 保存批处理结果
  523. result_file = output_path / "batch_results.json"
  524. with open(result_file, 'w', encoding='utf-8') as f:
  525. json.dump(results, f, indent=2, ensure_ascii=False)
  526. # 统计
  527. skipped_count = sum(1 for r in results if r['status'] == 'skipped')
  528. print(f"\n{'='*60}")
  529. print(f"🎉 批处理完成!")
  530. print(f"{'='*60}")
  531. print(f"✅ 成功: {success_count}")
  532. print(f"❌ 失败: {failed_count}")
  533. print(f"⚠️ 跳过: {skipped_count}")
  534. print(f"📊 总计: {len(results)}")
  535. print(f"📄 结果保存: {result_file}")
  536. def main():
  537. """主函数"""
  538. parser = argparse.ArgumentParser(
  539. description='应用表格模板到其他页面(支持混合模式)',
  540. formatter_class=argparse.RawDescriptionHelpFormatter,
  541. epilog="""
  542. 示例用法:
  543. 1. 混合模式(推荐,自适应行高):
  544. python table_template_applier.py \\
  545. --template template.json \\
  546. --image-dir /path/to/images \\
  547. --json-dir /path/to/jsons \\
  548. --output-dir /path/to/output \\
  549. --structure-suffix _structure.json \\
  550. --hybrid
  551. 2. 完全模板模式(固定行高):
  552. python table_template_applier.py \\
  553. --template template.json \\
  554. --image-file page.png \\
  555. --json-file page.json \\
  556. --output-dir /path/to/output \\
  557. --structure-suffix _structure.json \\
  558. 模式说明:
  559. - 混合模式(--hybrid): 列宽使用模板,行高根据OCR自适应
  560. - 完全模板模式: 列宽和行高都使用模板(适合固定格式表格)
  561. """
  562. )
  563. # 模板参数
  564. parser.add_argument(
  565. '-t', '--template',
  566. type=str,
  567. required=True,
  568. help='模板配置文件路径(人工标注的第一页结构)'
  569. )
  570. # 文件参数组
  571. file_group = parser.add_argument_group('文件参数(单文件模式)')
  572. file_group.add_argument(
  573. '--image-file',
  574. type=str,
  575. help='图片文件路径'
  576. )
  577. file_group.add_argument(
  578. '--json-file',
  579. type=str,
  580. help='OCR JSON文件路径'
  581. )
  582. # 目录参数组
  583. dir_group = parser.add_argument_group('目录参数(批量模式)')
  584. dir_group.add_argument(
  585. '--image-dir',
  586. type=str,
  587. help='图片目录'
  588. )
  589. dir_group.add_argument(
  590. '--json-dir',
  591. type=str,
  592. help='OCR JSON目录'
  593. )
  594. # 输出参数组
  595. output_group = parser.add_argument_group('输出参数')
  596. output_group.add_argument(
  597. '-o', '--output-dir',
  598. type=str,
  599. required=True,
  600. help='输出目录(必需)'
  601. )
  602. output_group.add_argument(
  603. '--structure-suffix',
  604. type=str,
  605. default='_structure.json',
  606. help='输出结构配置文件后缀(默认: _structure.json)'
  607. )
  608. # 绘图参数组
  609. draw_group = parser.add_argument_group('绘图参数')
  610. draw_group.add_argument(
  611. '-w', '--width',
  612. type=int,
  613. default=2,
  614. help='线条宽度(默认: 2)'
  615. )
  616. draw_group.add_argument(
  617. '-c', '--color',
  618. default='black',
  619. choices=['black', 'blue', 'red'],
  620. help='线条颜色(默认: black)'
  621. )
  622. # 🆕 新增模式参数
  623. mode_group = parser.add_argument_group('模式参数')
  624. mode_group.add_argument(
  625. '--hybrid',
  626. action='store_true',
  627. help='使用混合模式(模板列 + OCR行,自适应行高,推荐)'
  628. )
  629. args = parser.parse_args()
  630. # 颜色映射
  631. color_map = {
  632. 'black': (0, 0, 0),
  633. 'blue': (0, 0, 255),
  634. 'red': (255, 0, 0)
  635. }
  636. line_color = color_map[args.color]
  637. # 验证模板文件
  638. template_path = Path(args.template)
  639. if not template_path.exists():
  640. print(f"❌ 错误: 模板文件不存在: {template_path}")
  641. return
  642. output_path = Path(args.output_dir)
  643. output_path.mkdir(parents=True, exist_ok=True)
  644. # 判断模式
  645. if args.image_file and args.json_file:
  646. # 单文件模式
  647. image_file = Path(args.image_file)
  648. json_file = Path(args.json_file)
  649. if not image_file.exists():
  650. print(f"❌ 错误: 图片文件不存在: {image_file}")
  651. return
  652. if not json_file.exists():
  653. print(f"❌ 错误: JSON文件不存在: {json_file}")
  654. return
  655. print("\n🔧 单文件处理模式")
  656. print(f"📄 模板: {template_path.name}")
  657. print(f"📄 图片: {image_file.name}")
  658. print(f"📄 JSON: {json_file.name}")
  659. print(f"📂 输出: {output_path}\n")
  660. applier = TableTemplateApplier(str(template_path))
  661. success = apply_template_to_single_file(
  662. applier, image_file, json_file, output_path,
  663. use_hybrid_mode=args.hybrid, # 🆕 传递混合模式参数
  664. line_width=args.width,
  665. line_color=line_color
  666. )
  667. if success:
  668. print("\n✅ 处理完成!")
  669. else:
  670. print("\n❌ 处理失败!")
  671. elif args.image_dir and args.json_dir:
  672. # 批量模式
  673. image_dir = Path(args.image_dir)
  674. json_dir = Path(args.json_dir)
  675. if not image_dir.exists():
  676. print(f"❌ 错误: 图片目录不存在: {image_dir}")
  677. return
  678. if not json_dir.exists():
  679. print(f"❌ 错误: JSON目录不存在: {json_dir}")
  680. return
  681. print("\n🔧 批量处理模式")
  682. print(f"📄 模板: {template_path.name}")
  683. apply_template_batch(
  684. str(template_path),
  685. str(image_dir),
  686. str(json_dir),
  687. str(output_path),
  688. structure_suffix=args.structure_suffix,
  689. use_hybrid_mode=args.hybrid, # 🆕 传递混合模式参数
  690. line_width=args.width,
  691. line_color=line_color,
  692. )
  693. else:
  694. parser.print_help()
  695. print("\n❌ 错误: 请指定单文件模式或批量模式的参数")
  696. print("\n提示:")
  697. print(" 单文件模式: --image-file + --json-file")
  698. print(" 批量模式: --image-dir + --json-dir")
  699. if __name__ == "__main__":
  700. print("🚀 启动表格模板批量应用程序...")
  701. import sys
  702. if len(sys.argv) == 1:
  703. # 如果没有命令行参数,使用默认配置运行
  704. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  705. # 默认配置
  706. default_config = {
  707. "template": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行.wiredtable/康强_北京农村商业银行_page_001_structure.json",
  708. "image-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行/康强_北京农村商业银行_page_002.png",
  709. "json-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行_page_002.json",
  710. "output-dir": "output/batch_results",
  711. "width": "2",
  712. "color": "black"
  713. }
  714. # default_config = {
  715. # "template": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.wiredtable/B用户_扫描流水_page_001_structure.json",
  716. # "image-file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/mineru_vllm_results/B用户_扫描流水/B用户_扫描流水_page_002.png",
  717. # "json-file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/mineru_vllm_results_cell_bbox/B用户_扫描流水_page_002.json",
  718. # "output-dir": "output/batch_results",
  719. # "width": "2",
  720. # "color": "black"
  721. # }
  722. print("⚙️ 默认参数:")
  723. for key, value in default_config.items():
  724. print(f" --{key}: {value}")
  725. # 构造参数
  726. sys.argv = [sys.argv[0]]
  727. for key, value in default_config.items():
  728. sys.argv.extend([f"--{key}", str(value)])
  729. sys.argv.append("--hybrid") # 使用混合模式
  730. sys.exit(main())