table_template_applier.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. """
  2. 表格模板应用器
  3. 将人工标注的表格结构应用到其他页面
  4. """
  5. import json
  6. from pathlib import Path
  7. from PIL import Image, ImageDraw
  8. from typing import Dict, List, Tuple, Union, Optional
  9. import numpy as np
  10. import argparse
  11. import sys
  12. # 使用相对导入
  13. from .editor.data_processor import get_structure_from_ocr
  14. from .table_line_generator import TableLineGenerator
  15. class TableTemplateApplier:
  16. """表格模板应用器(混合模式)"""
  17. def __init__(self, template_config_path: str):
  18. """初始化时只提取列信息和表头信息"""
  19. with open(template_config_path, 'r', encoding='utf-8') as f:
  20. self.template = json.load(f)
  21. # ✅ 只提取列宽(固定)
  22. self.col_widths = self.template['col_widths']
  23. # ✅ 计算列的相对位置
  24. self.col_offsets = [0]
  25. for width in self.col_widths:
  26. self.col_offsets.append(self.col_offsets[-1] + width)
  27. # ✅ 提取表头高度(通常固定)
  28. rows = self.template['rows']
  29. if rows:
  30. self.header_height = rows[0]['y_end'] - rows[0]['y_start']
  31. else:
  32. self.header_height = 40
  33. # ✅ 计算数据行高度(用于固定行高模式)
  34. if len(rows) > 1:
  35. data_row_heights = [row['y_end'] - row['y_start'] for row in rows[1:]]
  36. # 使用中位数作为典型行高
  37. self.row_height = int(np.median(data_row_heights)) if data_row_heights else 40
  38. # 兜底行高(同样使用中位数)
  39. self.fallback_row_height = self.row_height
  40. else:
  41. # 如果只有表头,使用默认值
  42. self.row_height = 40
  43. self.fallback_row_height = 40
  44. print(f"\n✅ 加载模板配置:")
  45. print(f" 列数: {len(self.col_widths)}")
  46. print(f" 列宽: {self.col_widths}")
  47. print(f" 表头高度: {self.header_height}px")
  48. print(f" 数据行高: {self.row_height}px (用于固定行高模式)")
  49. print(f" 兜底行高: {self.fallback_row_height}px (OCR失败时使用)")
  50. def detect_table_anchor(self, ocr_data: List[Dict]) -> Tuple[int, int]:
  51. """
  52. 检测表格的锚点位置(表头左上角)
  53. 策略:
  54. 1. 找到Y坐标最小的文本框(表头第一行)
  55. 2. 找到X坐标最小的文本框(第一列)
  56. Args:
  57. ocr_data: OCR识别结果
  58. Returns:
  59. (anchor_x, anchor_y): 表格左上角坐标
  60. """
  61. if not ocr_data:
  62. return (0, 0)
  63. # 找到最小的X和Y坐标
  64. min_x = min(item['bbox'][0] for item in ocr_data)
  65. min_y = min(item['bbox'][1] for item in ocr_data)
  66. return (min_x, min_y)
  67. def detect_table_rows(self, ocr_data: List[Dict], header_y: int) -> int:
  68. """
  69. 检测表格的行数(包括表头)
  70. 策略:
  71. 1. 找到Y坐标最大的文本框
  72. 2. 根据数据行高计算行数
  73. 3. 加上表头行
  74. Args:
  75. ocr_data: OCR识别结果
  76. header_y: 表头起始Y坐标
  77. Returns:
  78. 总行数(包括表头)
  79. """
  80. if not ocr_data:
  81. return 1 # 至少有表头
  82. max_y = max(item['bbox'][3] for item in ocr_data)
  83. # 🔧 计算数据区的高度(排除表头)
  84. data_start_y = header_y + self.header_height
  85. data_height = max_y - data_start_y
  86. # 计算数据行数
  87. num_data_rows = max(int(data_height / self.row_height), 0)
  88. # 总行数 = 1行表头 + n行数据
  89. total_rows = 1 + num_data_rows
  90. print(f"📊 行数计算:")
  91. print(f" 表头Y: {header_y}, 数据区起始Y: {data_start_y}")
  92. print(f" 最大Y: {max_y}, 数据区高度: {data_height}px")
  93. print(f" 数据行数: {num_data_rows}, 总行数: {total_rows}")
  94. return total_rows
  95. def apply_template_fixed(self,
  96. image: Image.Image,
  97. ocr_data: Union[List[Dict], Dict], # 🆕 支持 Dict
  98. anchor_x: int = None,
  99. anchor_y: int = None,
  100. num_rows: int = None,
  101. line_width: int = 2,
  102. line_color: Tuple[int, int, int] = (0, 0, 0)) -> Tuple[Image.Image, Dict]:
  103. """
  104. 将模板应用到图片
  105. Args:
  106. image: 目标图片
  107. ocr_data: OCR识别结果(用于自动检测锚点),可以是列表或完整字典
  108. anchor_x: 表格起始X坐标(None=自动检测)
  109. anchor_y: 表头起始Y坐标(None=自动检测)
  110. num_rows: 总行数(None=自动检测)
  111. line_width: 线条宽度
  112. line_color: 线条颜色
  113. Returns:
  114. 绘制了表格线的图片
  115. """
  116. # 🆕 1. 实例化生成器并进行倾斜校正
  117. ocr_data_dict = {'text_boxes': ocr_data}
  118. # 尝试从 ocr_data 列表中获取角度信息(如果它是从 ocr_data 字典中提取出来的 list)
  119. # 但通常 ocr_data 这里只是 text_boxes 列表。
  120. # 我们需要传递包含 image_rotation_angle 和 skew_angle 的字典。
  121. # 由于调用者可能会传入 list,我们需要检查是否有更多信息。
  122. # 这里假设调用者会在传入 list 前处理好,或者我们在这里无法获取。
  123. # 不过,如果是从 parse_ocr_data 获取的 ocr_data,它应该是个 dict。
  124. # apply_template_fixed 的签名是 ocr_data: List[Dict],这意味着它只接收 text_boxes。
  125. # 这可能是一个问题。我们需要修改调用处或者在这里处理。
  126. # 看看 apply_template_to_single_file 是怎么调用的。
  127. # apply_template_to_single_file:
  128. # text_boxes = ocr_data.get('text_boxes', [])
  129. # applier.apply_template_fixed(image, text_boxes, ...)
  130. # 这样我们就丢失了角度信息。
  131. # 我应该修改 apply_template_fixed 的签名,让它接收 Dict 类型的 ocr_data,或者单独传递角度。
  132. # 为了保持兼容性,我可以修改 apply_template_fixed 内部处理。
  133. # 但最好的方式是让它接收整个 ocr_data 字典,就像 apply_template_hybrid 一样。
  134. # 不过,为了最小化修改,我可以在 apply_template_to_single_file 里把角度传进来?
  135. # 不,那得改很多。
  136. # 让我们看看能不能在 apply_template_fixed 里重新构造 ocr_data_dict。
  137. # 如果传入的 ocr_data 是 list,那我们确实没法知道角度。
  138. # 除非我们改变 apply_template_to_single_file 的调用方式。
  139. # 让我们先修改 apply_template_to_single_file 的调用方式,传整个 ocr_data 进去。
  140. # 但是 apply_template_fixed 的签名明确写了 ocr_data: List[Dict]。
  141. # 既然我正在修改这个文件,我可以改变它的签名。
  142. # 或者,我可以像 apply_template_hybrid 一样,增加一个参数 ocr_data_full: Dict = None
  143. # 实际上,apply_template_hybrid 已经接收 ocr_data_dict: Dict。
  144. # apply_template_fixed 接收 List[Dict]。
  145. # 这是一个不一致的地方。
  146. # 我决定修改 apply_template_fixed 的参数,让它也能利用 TableLineGenerator 进行校正。
  147. # 但是 TableLineGenerator 需要完整的 ocr_data 字典才能读取角度。
  148. # 方案:修改 apply_template_fixed 接收 ocr_data_dict。
  149. # 为了兼容旧代码,如果传入的是 list,就包装一下。
  150. # 但是 Python 类型提示 List[Dict] 和 Dict 是不一样的。
  151. # 我可以把参数名改成 ocr_input,类型 Union[List[Dict], Dict]。
  152. # 或者,既然这是内部使用的工具,我直接修改签名,让它接收 Dict。
  153. # 检查一下是否有其他地方调用这个方法。
  154. # 只在 apply_template_to_single_file 调用了。
  155. # 所以我将修改 apply_template_fixed 接收 ocr_data_dict: Dict。
  156. generator = TableLineGenerator(image, {'text_boxes': ocr_data} if isinstance(ocr_data, list) else ocr_data)
  157. corrected_image, angle = generator.correct_skew()
  158. # 获取角度信息
  159. image_rotation_angle = generator.ocr_data.get('image_rotation_angle', 0.0)
  160. skew_angle = generator.ocr_data.get('skew_angle', 0.0)
  161. if abs(angle) > 0.1 or image_rotation_angle != 0:
  162. print(f"🔄 [TemplateApplier] 自动校正: 旋转={image_rotation_angle}°, 倾斜={skew_angle:.2f}°")
  163. # 更新 OCR 数据(generator 内部已经更新了)
  164. ocr_data = generator.ocr_data.get('text_boxes', [])
  165. # 使用校正后的图片
  166. img_with_lines = corrected_image.copy()
  167. else:
  168. img_with_lines = image.copy()
  169. # 如果是字典,提取 list
  170. if isinstance(ocr_data, dict):
  171. ocr_data = ocr_data.get('text_boxes', [])
  172. draw = ImageDraw.Draw(img_with_lines)
  173. # 🔍 自动检测锚点
  174. if anchor_x is None or anchor_y is None:
  175. detected_x, detected_y = self.detect_table_anchor(ocr_data)
  176. anchor_x = anchor_x or detected_x
  177. anchor_y = anchor_y or detected_y
  178. # 🔍 自动检测行数
  179. if num_rows is None:
  180. num_rows = self.detect_table_rows(ocr_data, anchor_y)
  181. print(f"\n📍 表格锚点: ({anchor_x}, {anchor_y})")
  182. print(f"📊 总行数: {num_rows} (1表头 + {num_rows-1}数据)")
  183. # 🎨 生成横线坐标
  184. horizontal_lines = []
  185. # 第1条线:表头顶部
  186. horizontal_lines.append(anchor_y)
  187. # 第2条线:表头底部/数据区顶部
  188. horizontal_lines.append(anchor_y + self.header_height)
  189. # 后续横线:数据行分隔线
  190. current_y = anchor_y + self.header_height
  191. for i in range(num_rows - 1): # 减1因为表头已经占了1行
  192. current_y += self.row_height
  193. horizontal_lines.append(current_y)
  194. # 🎨 生成竖线坐标
  195. vertical_lines = []
  196. for offset in self.col_offsets:
  197. x = anchor_x + offset
  198. vertical_lines.append(x)
  199. print(f"📏 横线坐标: {horizontal_lines[:3]}... (共{len(horizontal_lines)}条)")
  200. print(f"📏 竖线坐标: {vertical_lines[:3]}... (共{len(vertical_lines)}条)")
  201. # 🖊️ 绘制横线
  202. x_start = vertical_lines[0]
  203. x_end = vertical_lines[-1]
  204. for y in horizontal_lines:
  205. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  206. # 🖊️ 绘制竖线
  207. y_start = horizontal_lines[0]
  208. y_end = horizontal_lines[-1]
  209. for x in vertical_lines:
  210. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  211. print(f"✅ 表格绘制完成: {len(horizontal_lines)}行 × {len(vertical_lines)-1}列")
  212. # 🔑 生成结构信息
  213. structure = self._build_structure(
  214. horizontal_lines,
  215. vertical_lines,
  216. anchor_x,
  217. anchor_y,
  218. mode='fixed',
  219. image_rotation_angle=image_rotation_angle,
  220. skew_angle=skew_angle
  221. )
  222. return img_with_lines, structure
  223. def apply_template_hybrid(self,
  224. image: Image.Image,
  225. ocr_data_dict: Dict,
  226. use_ocr_rows: bool = True,
  227. anchor_x: int = None,
  228. anchor_y: int = None,
  229. y_tolerance: int = 5,
  230. line_width: int = 2,
  231. line_color: Tuple[int, int, int] = (0, 0, 0)) -> Tuple[Image.Image, Dict]:
  232. """
  233. 混合模式:使用模板的列 + OCR的行
  234. Args:
  235. image: 目标图片
  236. ocr_data: OCR识别结果(用于检测行)
  237. use_ocr_rows: 是否使用OCR检测的行(True=自适应行高)
  238. anchor_x: 表格起始X坐标(None=自动检测)
  239. anchor_y: 表头起始Y坐标(None=自动检测)
  240. y_tolerance: Y轴聚类容差(像素)
  241. line_width: 线条宽度
  242. line_color: 线条颜色
  243. Returns:
  244. 绘制了表格线的图片, 结构信息
  245. """
  246. # 🆕 1. 实例化生成器并进行倾斜校正
  247. generator = TableLineGenerator(image, ocr_data_dict)
  248. corrected_image, angle = generator.correct_skew()
  249. # 🆕 获取图片旋转角度
  250. image_rotation_angle = ocr_data_dict.get('image_rotation_angle', 0.0)
  251. skew_angle = ocr_data_dict.get('skew_angle', 0.0)
  252. if abs(angle) > 0.1 or image_rotation_angle != 0:
  253. print(f"🔄 [TemplateApplier] 自动校正: 旋转={image_rotation_angle}°, 倾斜={skew_angle:.2f}°")
  254. # 更新 OCR 数据
  255. ocr_data_dict = generator.ocr_data
  256. # 使用校正后的图片
  257. img_with_lines = corrected_image.copy()
  258. else:
  259. img_with_lines = image.copy()
  260. draw = ImageDraw.Draw(img_with_lines)
  261. ocr_data = ocr_data_dict.get('text_boxes', [])
  262. # 🔍 自动检测锚点
  263. if anchor_x is None or anchor_y is None:
  264. detected_x, detected_y = self.detect_table_anchor(ocr_data)
  265. anchor_x = anchor_x or detected_x
  266. anchor_y = anchor_y or detected_y
  267. print(f"\n📍 表格锚点: ({anchor_x}, {anchor_y})")
  268. # ✅ 竖线:使用模板的列宽(固定)
  269. vertical_lines = [anchor_x + offset for offset in self.col_offsets]
  270. print(f"📏 竖线坐标: {vertical_lines} (使用模板,共{len(vertical_lines)}条)")
  271. # ✅ 横线:根据模式选择
  272. if use_ocr_rows and ocr_data:
  273. horizontal_lines = self._detect_rows_from_ocr(
  274. ocr_data, anchor_y, y_tolerance
  275. )
  276. print(f"📏 横线坐标: 使用OCR检测 (共{len(horizontal_lines)}条,自适应行高)")
  277. else:
  278. num_rows = self.detect_table_rows(ocr_data, anchor_y) if ocr_data else 10
  279. horizontal_lines = self._generate_fixed_rows(anchor_y, num_rows)
  280. print(f"📏 横线坐标: 使用固定行高 (共{len(horizontal_lines)}条)")
  281. # 🖊️ 绘制横线
  282. x_start = vertical_lines[0]
  283. x_end = vertical_lines[-1]
  284. for y in horizontal_lines:
  285. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  286. # 🖊️ 绘制竖线
  287. y_start = horizontal_lines[0]
  288. y_end = horizontal_lines[-1]
  289. for x in vertical_lines:
  290. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  291. print(f"✅ 表格绘制完成: {len(horizontal_lines)}行 × {len(vertical_lines)-1}列")
  292. # 🔑 生成结构信息
  293. structure = self._build_structure(
  294. horizontal_lines,
  295. vertical_lines,
  296. anchor_x,
  297. anchor_y,
  298. mode='hybrid',
  299. image_rotation_angle=image_rotation_angle,
  300. skew_angle=skew_angle
  301. )
  302. return img_with_lines, structure
  303. def _detect_rows_from_ocr(self,
  304. ocr_data: List[Dict],
  305. anchor_y: int,
  306. y_tolerance: int = 5) -> List[int]:
  307. """
  308. 从OCR结果中检测行(自适应行高)
  309. 复用 get_structure_from_ocr 统一接口
  310. Args:
  311. ocr_data: OCR识别结果(MinerU 格式的 text_boxes)
  312. anchor_y: 表格起始Y坐标
  313. y_tolerance: Y轴聚类容差(未使用,保留参数兼容性)
  314. Returns:
  315. 横线 y 坐标列表
  316. """
  317. if not ocr_data:
  318. return [anchor_y, anchor_y + self.header_height]
  319. print(f"\n🔍 OCR行检测 (使用 MinerU 算法):")
  320. print(f" 有效文本框数: {len(ocr_data)}")
  321. # 🔑 验证是否为 MinerU 格式
  322. has_cell_index = any('row' in item and 'col' in item for item in ocr_data)
  323. if not has_cell_index:
  324. print(" ⚠️ 警告: OCR数据不包含 row/col 索引,可能不是 MinerU 格式")
  325. print(" ⚠️ 混合模式需要 MinerU 格式的 JSON 文件")
  326. return [anchor_y, anchor_y + self.header_height]
  327. # 🔑 重构原始数据格式(MinerU 需要完整的 table 结构)
  328. raw_data = {
  329. 'type': 'table',
  330. 'table_cells': ocr_data
  331. }
  332. try:
  333. # ✅ 使用统一接口解析和分析(无需 dummy_image)
  334. table_bbox, structure = get_structure_from_ocr(
  335. raw_data,
  336. json_format="mineru"
  337. )
  338. if not structure or 'horizontal_lines' not in structure:
  339. print(" ⚠️ MinerU 分析失败,使用兜底方案")
  340. return [anchor_y, anchor_y + self.header_height]
  341. # 🔑 获取横线坐标
  342. horizontal_lines = structure['horizontal_lines']
  343. # 🔑 调整第一条线到 anchor_y(表头顶部)
  344. if horizontal_lines:
  345. offset = anchor_y - horizontal_lines[0]
  346. horizontal_lines = [y + offset for y in horizontal_lines]
  347. print(f" 检测到行数: {len(horizontal_lines) - 1}")
  348. # 🔑 分析行高分布
  349. if len(horizontal_lines) > 1:
  350. row_heights = []
  351. for i in range(len(horizontal_lines) - 1):
  352. h = horizontal_lines[i+1] - horizontal_lines[i]
  353. row_heights.append(h)
  354. if len(row_heights) > 1:
  355. import numpy as np
  356. print(f" 行高分布: min={min(row_heights)}, "
  357. f"median={int(np.median(row_heights))}, "
  358. f"max={max(row_heights)}")
  359. return horizontal_lines
  360. except Exception as e:
  361. print(f" ⚠️ 解析失败: {e}")
  362. import traceback
  363. traceback.print_exc()
  364. return [anchor_y, anchor_y + self.header_height]
  365. def _generate_fixed_rows(self, anchor_y: int, num_rows: int) -> List[int]:
  366. """生成固定行高的横线(兜底方案)"""
  367. horizontal_lines = [anchor_y]
  368. # 表头
  369. horizontal_lines.append(anchor_y + self.header_height)
  370. # 数据行
  371. current_y = anchor_y + self.header_height
  372. for i in range(num_rows - 1):
  373. current_y += self.fallback_row_height
  374. horizontal_lines.append(current_y)
  375. return horizontal_lines
  376. def _build_structure(self,
  377. horizontal_lines: List[int],
  378. vertical_lines: List[int],
  379. anchor_x: int,
  380. anchor_y: int,
  381. mode: str = 'fixed',
  382. image_rotation_angle: float = 0.0,
  383. skew_angle: float = 0.0) -> Dict:
  384. """构建表格结构信息(统一)"""
  385. # 生成行区间
  386. rows = []
  387. for i in range(len(horizontal_lines) - 1):
  388. rows.append({
  389. 'y_start': horizontal_lines[i],
  390. 'y_end': horizontal_lines[i + 1],
  391. 'bboxes': []
  392. })
  393. # 生成列区间
  394. columns = []
  395. for i in range(len(vertical_lines) - 1):
  396. columns.append({
  397. 'x_start': vertical_lines[i],
  398. 'x_end': vertical_lines[i + 1]
  399. })
  400. # ✅ 根据模式设置正确的 mode 值
  401. if mode == 'hybrid':
  402. mode_value = 'hybrid'
  403. elif mode == 'fixed':
  404. mode_value = 'fixed'
  405. else:
  406. mode_value = mode # 保留原始值
  407. return {
  408. 'rows': rows,
  409. 'columns': columns,
  410. 'horizontal_lines': horizontal_lines,
  411. 'vertical_lines': vertical_lines,
  412. 'col_widths': self.col_widths,
  413. 'row_height': self.row_height if mode == 'fixed' else None,
  414. 'table_bbox': [
  415. vertical_lines[0],
  416. horizontal_lines[0],
  417. vertical_lines[-1],
  418. horizontal_lines[-1]
  419. ],
  420. 'mode': mode_value, # ✅ 确保有 mode 字段
  421. 'anchor': {'x': anchor_x, 'y': anchor_y},
  422. 'modified_h_lines': [], # ✅ 添加修改记录字段
  423. 'modified_v_lines': [], # ✅ 添加修改记录字段
  424. 'image_rotation_angle': image_rotation_angle,
  425. 'skew_angle': skew_angle,
  426. 'is_skew_corrected': abs(skew_angle) > 0.1 or image_rotation_angle != 0
  427. }
  428. def apply_template_to_single_file(
  429. applier: TableTemplateApplier,
  430. image_file: Path,
  431. json_file: Path,
  432. output_dir: Path,
  433. structure_suffix: str = "_structure.json",
  434. use_hybrid_mode: bool = True,
  435. line_width: int = 2,
  436. line_color: Tuple[int, int, int] = (0, 0, 0)
  437. ) -> bool:
  438. """
  439. 应用模板到单个文件
  440. Args:
  441. applier: 模板应用器实例
  442. image_file: 图片文件路径
  443. json_file: OCR JSON文件路径
  444. output_dir: 输出目录
  445. use_hybrid_mode: 是否使用混合模式(需要 MinerU 格式)
  446. line_width: 线条宽度
  447. line_color: 线条颜色
  448. Returns:
  449. 是否成功
  450. """
  451. print(f"📄 处理: {image_file.name}")
  452. try:
  453. # 加载OCR数据
  454. with open(json_file, 'r', encoding='utf-8') as f:
  455. raw_data = json.load(f)
  456. # 🔑 自动检测 OCR 格式
  457. ocr_format = None
  458. if 'parsing_res_list' in raw_data and 'overall_ocr_res' in raw_data:
  459. # PPStructure 格式
  460. ocr_format = 'ppstructure'
  461. elif isinstance(raw_data, (list, dict)):
  462. # 尝试提取 MinerU 格式
  463. table_data = None
  464. if isinstance(raw_data, list):
  465. for item in raw_data:
  466. if isinstance(item, dict) and item.get('type') == 'table':
  467. table_data = item
  468. break
  469. elif isinstance(raw_data, dict) and raw_data.get('type') == 'table':
  470. table_data = raw_data
  471. if table_data and 'table_cells' in table_data:
  472. ocr_format = 'mineru'
  473. else:
  474. raise ValueError("未识别的 OCR 格式")
  475. else:
  476. raise ValueError("未识别的 OCR 格式(仅支持 PPStructure 或 MinerU)")
  477. table_bbox, ocr_data = TableLineGenerator.parse_ocr_data(
  478. raw_data,
  479. json_format=ocr_format
  480. )
  481. text_boxes = ocr_data.get('text_boxes', [])
  482. print(f" ✅ 加载OCR数据: {len(text_boxes)} 个文本框")
  483. print(f" 📋 OCR格式: {ocr_format}")
  484. # 加载图片
  485. image = Image.open(image_file)
  486. print(f" ✅ 加载图片: {image.size}")
  487. # 🔑 验证混合模式的格式要求
  488. if use_hybrid_mode and ocr_format != 'mineru':
  489. print(f" ⚠️ 警告: 混合模式需要 MinerU 格式,当前格式为 {ocr_format}")
  490. print(f" ℹ️ 自动切换到完全模板模式")
  491. use_hybrid_mode = False
  492. # 🆕 根据模式选择处理方式
  493. if use_hybrid_mode:
  494. print(f" 🔧 使用混合模式 (模板列 + MinerU 行)")
  495. img_with_lines, structure = applier.apply_template_hybrid(
  496. image,
  497. ocr_data,
  498. use_ocr_rows=True,
  499. line_width=line_width,
  500. line_color=line_color
  501. )
  502. else:
  503. print(f" 🔧 使用完全模板模式 (固定行高)")
  504. img_with_lines, structure = applier.apply_template_fixed(
  505. image,
  506. text_boxes,
  507. line_width=line_width,
  508. line_color=line_color
  509. )
  510. # 保存图片
  511. output_file = output_dir / f"{image_file.stem}.png"
  512. img_with_lines.save(output_file)
  513. # 保存结构配置
  514. structure_file = output_dir / f"{image_file.stem}{structure_suffix}"
  515. with open(structure_file, 'w', encoding='utf-8') as f:
  516. json.dump(structure, f, indent=2, ensure_ascii=False)
  517. print(f" ✅ 保存图片: {output_file.name}")
  518. print(f" ✅ 保存配置: {structure_file.name}")
  519. print(f" 📊 表格: {len(structure['rows'])}行 x {len(structure['columns'])}列")
  520. return True
  521. except Exception as e:
  522. print(f" ❌ 处理失败: {e}")
  523. import traceback
  524. traceback.print_exc()
  525. return False
  526. def apply_template_batch(
  527. template_config_path: str,
  528. image_dir: str,
  529. json_dir: str,
  530. output_dir: str,
  531. structure_suffix: str = "_structure.json",
  532. use_hybrid_mode: bool = False,
  533. line_width: int = 2,
  534. line_color: Tuple[int, int, int] = (0, 0, 0)
  535. ):
  536. """
  537. 批量应用模板到所有图片
  538. Args:
  539. template_config_path: 模板配置路径
  540. image_dir: 图片目录
  541. json_dir: OCR JSON目录
  542. output_dir: 输出目录
  543. line_width: 线条宽度
  544. line_color: 线条颜色
  545. """
  546. applier = TableTemplateApplier(template_config_path)
  547. image_path = Path(image_dir)
  548. json_path = Path(json_dir)
  549. output_path = Path(output_dir)
  550. output_path.mkdir(parents=True, exist_ok=True)
  551. # 查找所有图片
  552. image_files = list(image_path.glob("*.jpg")) + list(image_path.glob("*.png"))
  553. image_files.sort()
  554. print(f"\n🔍 找到 {len(image_files)} 个图片文件")
  555. print(f"📂 图片目录: {image_dir}")
  556. print(f"📂 JSON目录: {json_dir}")
  557. print(f"📂 输出目录: {output_dir}\n")
  558. results = []
  559. success_count = 0
  560. failed_count = 0
  561. for idx, image_file in enumerate(image_files, 1):
  562. print(f"\n{'='*60}")
  563. print(f"[{idx}/{len(image_files)}] 处理: {image_file.name}")
  564. print(f"{'='*60}")
  565. # 查找对应的JSON文件
  566. json_file = json_path / f"{image_file.stem}.json"
  567. if not json_file.exists():
  568. print(f"⚠️ 找不到OCR结果: {json_file.name}")
  569. results.append({
  570. 'source': str(image_file),
  571. 'status': 'skipped',
  572. 'reason': 'no_json'
  573. })
  574. failed_count += 1
  575. continue
  576. if apply_template_to_single_file(
  577. applier, image_file, json_file, output_path, structure_suffix, use_hybrid_mode,
  578. line_width, line_color
  579. ):
  580. results.append({
  581. 'source': str(image_file),
  582. 'json': str(json_file),
  583. 'status': 'success'
  584. })
  585. success_count += 1
  586. else:
  587. results.append({
  588. 'source': str(image_file),
  589. 'json': str(json_file),
  590. 'status': 'error'
  591. })
  592. failed_count += 1
  593. print()
  594. # 保存批处理结果
  595. result_file = output_path / "batch_results.json"
  596. with open(result_file, 'w', encoding='utf-8') as f:
  597. json.dump(results, f, indent=2, ensure_ascii=False)
  598. # 统计
  599. skipped_count = sum(1 for r in results if r['status'] == 'skipped')
  600. print(f"\n{'='*60}")
  601. print(f"🎉 批处理完成!")
  602. print(f"{'='*60}")
  603. print(f"✅ 成功: {success_count}")
  604. print(f"❌ 失败: {failed_count}")
  605. print(f"⚠️ 跳过: {skipped_count}")
  606. print(f"📊 总计: {len(results)}")
  607. print(f"📄 结果保存: {result_file}")
  608. def main():
  609. """主函数"""
  610. parser = argparse.ArgumentParser(
  611. description='应用表格模板到其他页面(支持混合模式)',
  612. formatter_class=argparse.RawDescriptionHelpFormatter,
  613. epilog="""
  614. 示例用法:
  615. 1. 混合模式(推荐,自适应行高):
  616. python table_template_applier.py \\
  617. --template template.json \\
  618. --image-dir /path/to/images \\
  619. --json-dir /path/to/jsons \\
  620. --output-dir /path/to/output \\
  621. --structure-suffix _structure.json \\
  622. --hybrid
  623. 2. 完全模板模式(固定行高):
  624. python table_template_applier.py \\
  625. --template template.json \\
  626. --image-file page.png \\
  627. --json-file page.json \\
  628. --output-dir /path/to/output \\
  629. --structure-suffix _structure.json \\
  630. 模式说明:
  631. - 混合模式(--hybrid): 列宽使用模板,行高根据OCR自适应
  632. - 完全模板模式: 列宽和行高都使用模板(适合固定格式表格)
  633. """
  634. )
  635. # 模板参数
  636. parser.add_argument(
  637. '-t', '--template',
  638. type=str,
  639. required=True,
  640. help='模板配置文件路径(人工标注的第一页结构)'
  641. )
  642. # 文件参数组
  643. file_group = parser.add_argument_group('文件参数(单文件模式)')
  644. file_group.add_argument(
  645. '--image-file',
  646. type=str,
  647. help='图片文件路径'
  648. )
  649. file_group.add_argument(
  650. '--json-file',
  651. type=str,
  652. help='OCR JSON文件路径'
  653. )
  654. # 目录参数组
  655. dir_group = parser.add_argument_group('目录参数(批量模式)')
  656. dir_group.add_argument(
  657. '--image-dir',
  658. type=str,
  659. help='图片目录'
  660. )
  661. dir_group.add_argument(
  662. '--json-dir',
  663. type=str,
  664. help='OCR JSON目录'
  665. )
  666. # 输出参数组
  667. output_group = parser.add_argument_group('输出参数')
  668. output_group.add_argument(
  669. '-o', '--output-dir',
  670. type=str,
  671. required=True,
  672. help='输出目录(必需)'
  673. )
  674. output_group.add_argument(
  675. '--structure-suffix',
  676. type=str,
  677. default='_structure.json',
  678. help='输出结构配置文件后缀(默认: _structure.json)'
  679. )
  680. # 绘图参数组
  681. draw_group = parser.add_argument_group('绘图参数')
  682. draw_group.add_argument(
  683. '-w', '--width',
  684. type=int,
  685. default=2,
  686. help='线条宽度(默认: 2)'
  687. )
  688. draw_group.add_argument(
  689. '-c', '--color',
  690. default='black',
  691. choices=['black', 'blue', 'red'],
  692. help='线条颜色(默认: black)'
  693. )
  694. # 🆕 新增模式参数
  695. mode_group = parser.add_argument_group('模式参数')
  696. mode_group.add_argument(
  697. '--hybrid',
  698. action='store_true',
  699. help='使用混合模式(模板列 + OCR行,自适应行高,推荐)'
  700. )
  701. args = parser.parse_args()
  702. # 颜色映射
  703. color_map = {
  704. 'black': (0, 0, 0),
  705. 'blue': (0, 0, 255),
  706. 'red': (255, 0, 0)
  707. }
  708. line_color = color_map[args.color]
  709. # 验证模板文件
  710. template_path = Path(args.template)
  711. if not template_path.exists():
  712. print(f"❌ 错误: 模板文件不存在: {template_path}")
  713. return
  714. output_path = Path(args.output_dir)
  715. output_path.mkdir(parents=True, exist_ok=True)
  716. # 判断模式
  717. if args.image_file and args.json_file:
  718. # 单文件模式
  719. image_file = Path(args.image_file)
  720. json_file = Path(args.json_file)
  721. if not image_file.exists():
  722. print(f"❌ 错误: 图片文件不存在: {image_file}")
  723. return
  724. if not json_file.exists():
  725. print(f"❌ 错误: JSON文件不存在: {json_file}")
  726. return
  727. print("\n🔧 单文件处理模式")
  728. print(f"📄 模板: {template_path.name}")
  729. print(f"📄 图片: {image_file.name}")
  730. print(f"📄 JSON: {json_file.name}")
  731. print(f"📂 输出: {output_path}\n")
  732. applier = TableTemplateApplier(str(template_path))
  733. success = apply_template_to_single_file(
  734. applier, image_file, json_file, output_path,
  735. use_hybrid_mode=args.hybrid, # 🆕 传递混合模式参数
  736. line_width=args.width,
  737. line_color=line_color
  738. )
  739. if success:
  740. print("\n✅ 处理完成!")
  741. else:
  742. print("\n❌ 处理失败!")
  743. elif args.image_dir and args.json_dir:
  744. # 批量模式
  745. image_dir = Path(args.image_dir)
  746. json_dir = Path(args.json_dir)
  747. if not image_dir.exists():
  748. print(f"❌ 错误: 图片目录不存在: {image_dir}")
  749. return
  750. if not json_dir.exists():
  751. print(f"❌ 错误: JSON目录不存在: {json_dir}")
  752. return
  753. print("\n🔧 批量处理模式")
  754. print(f"📄 模板: {template_path.name}")
  755. apply_template_batch(
  756. str(template_path),
  757. str(image_dir),
  758. str(json_dir),
  759. str(output_path),
  760. structure_suffix=args.structure_suffix,
  761. use_hybrid_mode=args.hybrid, # 🆕 传递混合模式参数
  762. line_width=args.width,
  763. line_color=line_color,
  764. )
  765. else:
  766. parser.print_help()
  767. print("\n❌ 错误: 请指定单文件模式或批量模式的参数")
  768. print("\n提示:")
  769. print(" 单文件模式: --image-file + --json-file")
  770. print(" 批量模式: --image-dir + --json-dir")
  771. if __name__ == "__main__":
  772. print("🚀 启动表格模板批量应用程序...")
  773. import sys
  774. if len(sys.argv) == 1:
  775. # 如果没有命令行参数,使用默认配置运行
  776. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  777. # 默认配置
  778. default_config = {
  779. "template": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行.wiredtable/康强_北京农村商业银行_page_001_structure.json",
  780. "image-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行/康强_北京农村商业银行_page_002.png",
  781. "json-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行_page_002.json",
  782. "output-dir": "output/batch_results",
  783. "width": "2",
  784. "color": "black"
  785. }
  786. # default_config = {
  787. # "template": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.wiredtable/B用户_扫描流水_page_001_structure.json",
  788. # "image-file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/mineru_vllm_results/B用户_扫描流水/B用户_扫描流水_page_002.png",
  789. # "json-file": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/mineru_vllm_results_cell_bbox/B用户_扫描流水_page_002.json",
  790. # "output-dir": "output/batch_results",
  791. # "width": "2",
  792. # "color": "black"
  793. # }
  794. print("⚙️ 默认参数:")
  795. for key, value in default_config.items():
  796. print(f" --{key}: {value}")
  797. # 构造参数
  798. sys.argv = [sys.argv[0]]
  799. for key, value in default_config.items():
  800. sys.argv.extend([f"--{key}", str(value)])
  801. sys.argv.append("--hybrid") # 使用混合模式
  802. sys.exit(main())