table_template_applier.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. """
  2. 表格模板应用器
  3. 将人工标注的表格结构应用到其他页面
  4. """
  5. import json
  6. from pathlib import Path
  7. from PIL import Image, ImageDraw
  8. from typing import Dict, List, Tuple
  9. import numpy as np
  10. import argparse
  11. try:
  12. from table_line_generator import TableLineGenerator
  13. except ImportError:
  14. from .table_line_generator import TableLineGenerator
  15. class TableTemplateApplier:
  16. """表格模板应用器"""
  17. def __init__(self, template_config_path: str):
  18. """
  19. 初始化模板应用器
  20. Args:
  21. template_config_path: 模板配置文件路径(人工标注的结果)
  22. """
  23. with open(template_config_path, 'r', encoding='utf-8') as f:
  24. self.template = json.load(f)
  25. # 🎯 从标注结果提取固定参数
  26. self.col_widths = self.template['col_widths']
  27. # 🔧 计算数据行的标准行高(排除表头)
  28. rows = self.template['rows']
  29. if len(rows) > 1:
  30. # 计算每行的实际高度
  31. row_heights = [row['y_end'] - row['y_start'] for row in rows]
  32. # 🎯 假设第一行是表头,从第二行开始计算
  33. data_row_heights = row_heights[1:] if len(row_heights) > 1 else row_heights
  34. # 使用中位数作为标准行高(更稳健)
  35. self.row_height = int(np.median(data_row_heights))
  36. self.header_height = row_heights[0] if row_heights else self.row_height
  37. print(f"📏 表头高度: {self.header_height}px")
  38. print(f"📏 数据行高度: {self.row_height}px")
  39. print(f" (从 {len(data_row_heights)} 行数据中计算,中位数)")
  40. else:
  41. # 兜底方案
  42. self.row_height = self.template.get('row_height', 60)
  43. self.header_height = self.row_height
  44. # 🎯 计算列的相对位置(从第一列开始的偏移量)
  45. self.col_offsets = [0]
  46. for width in self.col_widths:
  47. self.col_offsets.append(self.col_offsets[-1] + width)
  48. # 🎯 提取表头的Y坐标(作为参考)
  49. self.template_header_y = rows[0]['y_start'] if rows else 0
  50. print(f"\n✅ 加载模板配置:")
  51. print(f" 表头高度: {self.header_height}px")
  52. print(f" 数据行高度: {self.row_height}px")
  53. print(f" 列数: {len(self.col_widths)}")
  54. print(f" 列宽: {self.col_widths}")
  55. def detect_table_anchor(self, ocr_data: List[Dict]) -> Tuple[int, int]:
  56. """
  57. 检测表格的锚点位置(表头左上角)
  58. 策略:
  59. 1. 找到Y坐标最小的文本框(表头第一行)
  60. 2. 找到X坐标最小的文本框(第一列)
  61. Args:
  62. ocr_data: OCR识别结果
  63. Returns:
  64. (anchor_x, anchor_y): 表格左上角坐标
  65. """
  66. if not ocr_data:
  67. return (0, 0)
  68. # 找到最小的X和Y坐标
  69. min_x = min(item['bbox'][0] for item in ocr_data)
  70. min_y = min(item['bbox'][1] for item in ocr_data)
  71. return (min_x, min_y)
  72. def detect_table_rows(self, ocr_data: List[Dict], header_y: int) -> int:
  73. """
  74. 检测表格的行数(包括表头)
  75. 策略:
  76. 1. 找到Y坐标最大的文本框
  77. 2. 根据数据行高计算行数
  78. 3. 加上表头行
  79. Args:
  80. ocr_data: OCR识别结果
  81. header_y: 表头起始Y坐标
  82. Returns:
  83. 总行数(包括表头)
  84. """
  85. if not ocr_data:
  86. return 1 # 至少有表头
  87. max_y = max(item['bbox'][3] for item in ocr_data)
  88. # 🔧 计算数据区的高度(排除表头)
  89. data_start_y = header_y + self.header_height
  90. data_height = max_y - data_start_y
  91. # 计算数据行数
  92. num_data_rows = max(int(data_height / self.row_height), 0)
  93. # 总行数 = 1行表头 + n行数据
  94. total_rows = 1 + num_data_rows
  95. print(f"📊 行数计算:")
  96. print(f" 表头Y: {header_y}, 数据区起始Y: {data_start_y}")
  97. print(f" 最大Y: {max_y}, 数据区高度: {data_height}px")
  98. print(f" 数据行数: {num_data_rows}, 总行数: {total_rows}")
  99. return total_rows
  100. def apply_to_image(self,
  101. image: Image.Image,
  102. ocr_data: List[Dict],
  103. anchor_x: int = None,
  104. anchor_y: int = None,
  105. num_rows: int = None,
  106. line_width: int = 2,
  107. line_color: Tuple[int, int, int] = (0, 0, 0)) -> Image.Image:
  108. """
  109. 将模板应用到图片
  110. Args:
  111. image: 目标图片
  112. ocr_data: OCR识别结果(用于自动检测锚点)
  113. anchor_x: 表格起始X坐标(None=自动检测)
  114. anchor_y: 表头起始Y坐标(None=自动检测)
  115. num_rows: 总行数(None=自动检测)
  116. line_width: 线条宽度
  117. line_color: 线条颜色
  118. Returns:
  119. 绘制了表格线的图片
  120. """
  121. img_with_lines = image.copy()
  122. draw = ImageDraw.Draw(img_with_lines)
  123. # 🔍 自动检测锚点
  124. if anchor_x is None or anchor_y is None:
  125. detected_x, detected_y = self.detect_table_anchor(ocr_data)
  126. anchor_x = anchor_x or detected_x
  127. anchor_y = anchor_y or detected_y
  128. # 🔍 自动检测行数
  129. if num_rows is None:
  130. num_rows = self.detect_table_rows(ocr_data, anchor_y)
  131. print(f"\n📍 表格锚点: ({anchor_x}, {anchor_y})")
  132. print(f"📊 总行数: {num_rows} (1表头 + {num_rows-1}数据)")
  133. # 🎨 生成横线坐标
  134. horizontal_lines = []
  135. # 第1条线:表头顶部
  136. horizontal_lines.append(anchor_y)
  137. # 第2条线:表头底部/数据区顶部
  138. horizontal_lines.append(anchor_y + self.header_height)
  139. # 后续横线:数据行分隔线
  140. current_y = anchor_y + self.header_height
  141. for i in range(num_rows - 1): # 减1因为表头已经占了1行
  142. current_y += self.row_height
  143. horizontal_lines.append(current_y)
  144. # 🎨 生成竖线坐标
  145. vertical_lines = []
  146. for offset in self.col_offsets:
  147. x = anchor_x + offset
  148. vertical_lines.append(x)
  149. print(f"📏 横线坐标: {horizontal_lines[:3]}... (共{len(horizontal_lines)}条)")
  150. print(f"📏 竖线坐标: {vertical_lines[:3]}... (共{len(vertical_lines)}条)")
  151. # 🖊️ 绘制横线
  152. x_start = vertical_lines[0]
  153. x_end = vertical_lines[-1]
  154. for y in horizontal_lines:
  155. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  156. # 🖊️ 绘制竖线
  157. y_start = horizontal_lines[0]
  158. y_end = horizontal_lines[-1]
  159. for x in vertical_lines:
  160. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  161. return img_with_lines
  162. def generate_structure_for_image(self,
  163. ocr_data: List[Dict],
  164. anchor_x: int = None,
  165. anchor_y: int = None,
  166. num_rows: int = None) -> Dict:
  167. """
  168. 为新图片生成表格结构配置
  169. Args:
  170. ocr_data: OCR识别结果
  171. anchor_x: 表格起始X坐标(None=自动检测)
  172. anchor_y: 表头起始Y坐标(None=自动检测)
  173. num_rows: 总行数(None=自动检测)
  174. Returns:
  175. 表格结构配置
  176. """
  177. # 🔍 自动检测锚点
  178. if anchor_x is None or anchor_y is None:
  179. detected_x, detected_y = self.detect_table_anchor(ocr_data)
  180. anchor_x = anchor_x or detected_x
  181. anchor_y = anchor_y or detected_y
  182. # 🔍 自动检测行数
  183. if num_rows is None:
  184. num_rows = self.detect_table_rows(ocr_data, anchor_y)
  185. # 🎨 生成横线坐标
  186. horizontal_lines = []
  187. horizontal_lines.append(anchor_y)
  188. horizontal_lines.append(anchor_y + self.header_height)
  189. current_y = anchor_y + self.header_height
  190. for i in range(num_rows - 1):
  191. current_y += self.row_height
  192. horizontal_lines.append(current_y)
  193. # 🎨 生成竖线坐标
  194. vertical_lines = []
  195. for offset in self.col_offsets:
  196. x = anchor_x + offset
  197. vertical_lines.append(x)
  198. # 🎨 生成行区间
  199. rows = []
  200. for i in range(num_rows):
  201. rows.append({
  202. 'y_start': horizontal_lines[i],
  203. 'y_end': horizontal_lines[i + 1],
  204. 'bboxes': []
  205. })
  206. # 🎨 生成列区间
  207. columns = []
  208. for i in range(len(vertical_lines) - 1):
  209. columns.append({
  210. 'x_start': vertical_lines[i],
  211. 'x_end': vertical_lines[i + 1]
  212. })
  213. return {
  214. 'rows': rows,
  215. 'columns': columns,
  216. 'horizontal_lines': horizontal_lines,
  217. 'vertical_lines': vertical_lines,
  218. 'header_height': self.header_height,
  219. 'row_height': self.row_height,
  220. 'col_widths': self.col_widths,
  221. 'table_bbox': [
  222. vertical_lines[0],
  223. horizontal_lines[0],
  224. vertical_lines[-1],
  225. horizontal_lines[-1]
  226. ],
  227. 'anchor': {'x': anchor_x, 'y': anchor_y},
  228. 'num_rows': num_rows
  229. }
  230. def apply_template_to_single_file(
  231. applier: TableTemplateApplier,
  232. image_file: Path,
  233. json_file: Path,
  234. output_dir: Path,
  235. line_width: int = 2,
  236. line_color: Tuple[int, int, int] = (0, 0, 0)
  237. ) -> bool:
  238. """
  239. 应用模板到单个文件
  240. Args:
  241. applier: 模板应用器实例
  242. image_file: 图片文件路径
  243. json_file: OCR JSON文件路径
  244. output_dir: 输出目录
  245. line_width: 线条宽度
  246. line_color: 线条颜色
  247. Returns:
  248. 是否成功
  249. """
  250. print(f"📄 处理: {image_file.name}")
  251. try:
  252. # 加载OCR数据
  253. with open(json_file, 'r', encoding='utf-8') as f:
  254. raw_data = json.load(f)
  255. # 🔧 解析OCR数据(支持PPStructure格式)
  256. if 'parsing_res_list' in raw_data and 'overall_ocr_res' in raw_data:
  257. table_bbox, ocr_data = TableLineGenerator.parse_ppstructure_result(raw_data)
  258. else:
  259. raise ValueError("不是PPStructure格式的OCR结果")
  260. print(f" ✅ 加载OCR数据: {len(ocr_data)} 个文本框")
  261. # 加载图片
  262. image = Image.open(image_file)
  263. print(f" ✅ 加载图片: {image.size}")
  264. # 🎯 应用模板
  265. img_with_lines = applier.apply_to_image(
  266. image,
  267. ocr_data,
  268. line_width=line_width,
  269. line_color=line_color
  270. )
  271. # 保存图片
  272. output_file = output_dir / f"{image_file.stem}_with_lines.png"
  273. img_with_lines.save(output_file)
  274. # 🆕 生成并保存结构配置
  275. structure = applier.generate_structure_for_image(ocr_data)
  276. structure_file = output_dir / f"{image_file.stem}_structure.json"
  277. with open(structure_file, 'w', encoding='utf-8') as f:
  278. json.dump(structure, f, indent=2, ensure_ascii=False)
  279. print(f" ✅ 保存图片: {output_file.name}")
  280. print(f" ✅ 保存配置: {structure_file.name}")
  281. print(f" 📊 表格: {structure['num_rows']}行 x {len(structure['columns'])}列")
  282. return True
  283. except Exception as e:
  284. print(f" ❌ 处理失败: {e}")
  285. import traceback
  286. traceback.print_exc()
  287. return False
  288. def apply_template_batch(
  289. template_config_path: str,
  290. image_dir: str,
  291. json_dir: str,
  292. output_dir: str,
  293. line_width: int = 2,
  294. line_color: Tuple[int, int, int] = (0, 0, 0)
  295. ):
  296. """
  297. 批量应用模板到所有图片
  298. Args:
  299. template_config_path: 模板配置路径
  300. image_dir: 图片目录
  301. json_dir: OCR JSON目录
  302. output_dir: 输出目录
  303. line_width: 线条宽度
  304. line_color: 线条颜色
  305. """
  306. applier = TableTemplateApplier(template_config_path)
  307. image_path = Path(image_dir)
  308. json_path = Path(json_dir)
  309. output_path = Path(output_dir)
  310. output_path.mkdir(parents=True, exist_ok=True)
  311. # 查找所有图片
  312. image_files = list(image_path.glob("*.jpg")) + list(image_path.glob("*.png"))
  313. image_files.sort()
  314. print(f"\n🔍 找到 {len(image_files)} 个图片文件")
  315. print(f"📂 图片目录: {image_dir}")
  316. print(f"📂 JSON目录: {json_dir}")
  317. print(f"📂 输出目录: {output_dir}\n")
  318. results = []
  319. success_count = 0
  320. failed_count = 0
  321. for idx, image_file in enumerate(image_files, 1):
  322. print(f"\n{'='*60}")
  323. print(f"[{idx}/{len(image_files)}] 处理: {image_file.name}")
  324. print(f"{'='*60}")
  325. # 查找对应的JSON文件
  326. json_file = json_path / f"{image_file.stem}.json"
  327. if not json_file.exists():
  328. print(f"⚠️ 找不到OCR结果: {json_file.name}")
  329. results.append({
  330. 'source': str(image_file),
  331. 'status': 'skipped',
  332. 'reason': 'no_json'
  333. })
  334. failed_count += 1
  335. continue
  336. if apply_template_to_single_file(
  337. applier, image_file, json_file, output_path,
  338. line_width, line_color
  339. ):
  340. results.append({
  341. 'source': str(image_file),
  342. 'json': str(json_file),
  343. 'status': 'success'
  344. })
  345. success_count += 1
  346. else:
  347. results.append({
  348. 'source': str(image_file),
  349. 'json': str(json_file),
  350. 'status': 'error'
  351. })
  352. failed_count += 1
  353. print()
  354. # 保存批处理结果
  355. result_file = output_path / "batch_results.json"
  356. with open(result_file, 'w', encoding='utf-8') as f:
  357. json.dump(results, f, indent=2, ensure_ascii=False)
  358. # 统计
  359. skipped_count = sum(1 for r in results if r['status'] == 'skipped')
  360. print(f"\n{'='*60}")
  361. print(f"🎉 批处理完成!")
  362. print(f"{'='*60}")
  363. print(f"✅ 成功: {success_count}")
  364. print(f"❌ 失败: {failed_count}")
  365. print(f"⚠️ 跳过: {skipped_count}")
  366. print(f"📊 总计: {len(results)}")
  367. print(f"📄 结果保存: {result_file}")
  368. def main():
  369. """主函数"""
  370. parser = argparse.ArgumentParser(
  371. description='应用表格模板到其他页面',
  372. formatter_class=argparse.RawDescriptionHelpFormatter,
  373. epilog="""
  374. 示例用法:
  375. 1. 批量处理整个目录:
  376. python table_template_applier.py \\
  377. --template output/康强_北京农村商业银行_page_001_structure.json \\
  378. --image-dir /path/to/images \\
  379. --json-dir /path/to/jsons \\
  380. --output-dir /path/to/output
  381. 2. 处理单个文件:
  382. python table_template_applier.py \\
  383. --template output/康强_北京农村商业银行_page_001_structure.json \\
  384. --image-file /path/to/page_002.png \\
  385. --json-file /path/to/page_002.json \\
  386. --output-dir /path/to/output
  387. 输出内容:
  388. - {name}_with_lines.png: 带表格线的图片
  389. - {name}_structure.json: 表格结构配置
  390. - batch_results.json: 批处理统计结果
  391. """
  392. )
  393. # 模板参数
  394. parser.add_argument(
  395. '-t', '--template',
  396. type=str,
  397. required=True,
  398. help='模板配置文件路径(人工标注的第一页结构)'
  399. )
  400. # 文件参数组
  401. file_group = parser.add_argument_group('文件参数(单文件模式)')
  402. file_group.add_argument(
  403. '--image-file',
  404. type=str,
  405. help='图片文件路径'
  406. )
  407. file_group.add_argument(
  408. '--json-file',
  409. type=str,
  410. help='OCR JSON文件路径'
  411. )
  412. # 目录参数组
  413. dir_group = parser.add_argument_group('目录参数(批量模式)')
  414. dir_group.add_argument(
  415. '--image-dir',
  416. type=str,
  417. help='图片目录'
  418. )
  419. dir_group.add_argument(
  420. '--json-dir',
  421. type=str,
  422. help='OCR JSON目录'
  423. )
  424. # 输出参数组
  425. output_group = parser.add_argument_group('输出参数')
  426. output_group.add_argument(
  427. '-o', '--output-dir',
  428. type=str,
  429. required=True,
  430. help='输出目录(必需)'
  431. )
  432. # 绘图参数组
  433. draw_group = parser.add_argument_group('绘图参数')
  434. draw_group.add_argument(
  435. '-w', '--width',
  436. type=int,
  437. default=2,
  438. help='线条宽度(默认: 2)'
  439. )
  440. draw_group.add_argument(
  441. '-c', '--color',
  442. default='black',
  443. choices=['black', 'blue', 'red'],
  444. help='线条颜色(默认: black)'
  445. )
  446. args = parser.parse_args()
  447. # 颜色映射
  448. color_map = {
  449. 'black': (0, 0, 0),
  450. 'blue': (0, 0, 255),
  451. 'red': (255, 0, 0)
  452. }
  453. line_color = color_map[args.color]
  454. # 验证模板文件
  455. template_path = Path(args.template)
  456. if not template_path.exists():
  457. print(f"❌ 错误: 模板文件不存在: {template_path}")
  458. return
  459. output_path = Path(args.output_dir)
  460. output_path.mkdir(parents=True, exist_ok=True)
  461. # 判断模式
  462. if args.image_file and args.json_file:
  463. # 单文件模式
  464. image_file = Path(args.image_file)
  465. json_file = Path(args.json_file)
  466. if not image_file.exists():
  467. print(f"❌ 错误: 图片文件不存在: {image_file}")
  468. return
  469. if not json_file.exists():
  470. print(f"❌ 错误: JSON文件不存在: {json_file}")
  471. return
  472. print("\n🔧 单文件处理模式")
  473. print(f"📄 模板: {template_path.name}")
  474. print(f"📄 图片: {image_file.name}")
  475. print(f"📄 JSON: {json_file.name}")
  476. print(f"📂 输出: {output_path}\n")
  477. applier = TableTemplateApplier(str(template_path))
  478. success = apply_template_to_single_file(
  479. applier, image_file, json_file, output_path,
  480. args.width, line_color
  481. )
  482. if success:
  483. print("\n✅ 处理完成!")
  484. else:
  485. print("\n❌ 处理失败!")
  486. elif args.image_dir and args.json_dir:
  487. # 批量模式
  488. image_dir = Path(args.image_dir)
  489. json_dir = Path(args.json_dir)
  490. if not image_dir.exists():
  491. print(f"❌ 错误: 图片目录不存在: {image_dir}")
  492. return
  493. if not json_dir.exists():
  494. print(f"❌ 错误: JSON目录不存在: {json_dir}")
  495. return
  496. print("\n🔧 批量处理模式")
  497. print(f"📄 模板: {template_path.name}")
  498. apply_template_batch(
  499. str(template_path),
  500. str(image_dir),
  501. str(json_dir),
  502. str(output_path),
  503. args.width,
  504. line_color
  505. )
  506. else:
  507. parser.print_help()
  508. print("\n❌ 错误: 请指定单文件模式或批量模式的参数")
  509. print("\n提示:")
  510. print(" 单文件模式: --image-file + --json-file")
  511. print(" 批量模式: --image-dir + --json-dir")
  512. if __name__ == "__main__":
  513. print("🚀 启动表格模板批量应用程序...")
  514. import sys
  515. if len(sys.argv) == 1:
  516. # 如果没有命令行参数,使用默认配置运行
  517. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  518. # 默认配置
  519. default_config = {
  520. "template": "output/table_structures/康强_北京农村商业银行_page_001_structure.json",
  521. "image-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行/康强_北京农村商业银行_page_002.png",
  522. "json-file": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行/ppstructurev3_client_results/康强_北京农村商业银行_page_002.json",
  523. "output-dir": "output/batch_results",
  524. "width": "2",
  525. "color": "black"
  526. }
  527. print("⚙️ 默认参数:")
  528. for key, value in default_config.items():
  529. print(f" --{key}: {value}")
  530. # 构造参数
  531. sys.argv = [sys.argv[0]]
  532. for key, value in default_config.items():
  533. sys.argv.extend([f"--{key}", str(value)])
  534. sys.exit(main())