batch_processor.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. """
  2. 批量处理器
  3. 首页标注 -> 自动传播到所有页
  4. 适用于几十页甚至上百页的交易流水快速处理
  5. """
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import cv2

from .smart_generator import SmartTableLineGenerator, TableStructure
from .column_detector import ColumnRegion
from .adaptive_row_splitter import AdaptiveRowSplitter, RowRegion
  15. @dataclass
  16. class TableTemplate:
  17. """
  18. 表格模板 (从首页学习)
  19. 固定部分:列边界、表格水平范围
  20. 可变部分:行分割(每页自适应计算)
  21. """
  22. columns: List[Tuple[int, int]] # 列边界 [(x_left, x_right), ...]
  23. table_x_range: Tuple[int, int] # 表格水平范围 (x_left, x_right)
  24. header_y_bottom: int # 表头底部 Y 坐标 (首页)
  25. page_header_height: int # 页眉高度 (非首页)
  26. page_footer_height: int # 页脚高度
  27. row_splitter_params: Dict # 行分割参数
  28. def to_dict(self) -> Dict:
  29. return asdict(self)
  30. @classmethod
  31. def from_dict(cls, data: Dict) -> 'TableTemplate':
  32. return cls(**data)
  33. def save(self, path: Union[str, Path]):
  34. with open(path, 'w', encoding='utf-8') as f:
  35. json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
  36. @classmethod
  37. def load(cls, path: Union[str, Path]) -> 'TableTemplate':
  38. with open(path, 'r', encoding='utf-8') as f:
  39. data = json.load(f)
  40. return cls.from_dict(data)
  41. @dataclass
  42. class PageResult:
  43. """单页处理结果"""
  44. page_index: int
  45. structure: TableStructure
  46. table_data: List[List[str]]
  47. success: bool
  48. error_message: Optional[str] = None
  49. class BatchTableProcessor:
  50. """
  51. 批量表格处理器
  52. 工作流程:
  53. 1. 从首页学习模板(列边界、表格区域)
  54. 2. 将模板应用到所有页(列边界复用,行分割自适应)
  55. 3. 并行处理提升效率
  56. """
  57. def __init__(self,
  58. generator: Optional[SmartTableLineGenerator] = None,
  59. max_workers: int = 4):
  60. self.generator = generator or SmartTableLineGenerator()
  61. self.max_workers = max_workers
  62. def learn_template(self,
  63. first_page_ocr: List[Dict],
  64. page_size: Tuple[int, int],
  65. table_region: Optional[Tuple[int, int, int, int]] = None,
  66. header_row_count: int = 1,
  67. page_header_height: int = 80,
  68. page_footer_height: int = 50) -> TableTemplate:
  69. """
  70. 从首页学习模板
  71. Args:
  72. first_page_ocr: 首页 OCR 结果
  73. page_size: 页面尺寸
  74. table_region: 表格区域 (可选)
  75. header_row_count: 表头行数
  76. page_header_height: 非首页的页眉高度
  77. page_footer_height: 页脚高度
  78. Returns:
  79. 表格模板
  80. """
  81. # 1. 生成首页结构
  82. structure, _ = self.generator.generate(
  83. first_page_ocr, page_size, table_region
  84. )
  85. # 2. 提取列边界
  86. columns = [(c.x_left, c.x_right) for c in structure.columns]
  87. # 3. 确定表格水平范围
  88. if columns:
  89. table_x_range = (columns[0][0], columns[-1][1])
  90. else:
  91. table_x_range = (0, page_size[0])
  92. # 4. 确定表头底部(跳过表头行)
  93. if structure.rows and header_row_count > 0:
  94. if header_row_count < len(structure.rows):
  95. header_y_bottom = structure.rows[header_row_count - 1].y_bottom
  96. else:
  97. header_y_bottom = structure.rows[0].y_bottom
  98. else:
  99. header_y_bottom = structure.table_region[1] + 50
  100. # 5. 创建模板
  101. template = TableTemplate(
  102. columns=columns,
  103. table_x_range=table_x_range,
  104. header_y_bottom=header_y_bottom,
  105. page_header_height=page_header_height,
  106. page_footer_height=page_footer_height,
  107. row_splitter_params={
  108. 'min_gap_height': 6,
  109. 'density_threshold': 0.05,
  110. 'min_row_height': 15
  111. }
  112. )
  113. print(f"📐 模板学习完成: {len(columns)} 列, 表头底部 Y={header_y_bottom}")
  114. return template
  115. def apply_template(self,
  116. ocr_boxes: List[Dict],
  117. page_size: Tuple[int, int],
  118. template: TableTemplate,
  119. page_index: int) -> PageResult:
  120. """
  121. 将模板应用到某一页
  122. Args:
  123. ocr_boxes: 该页 OCR 结果
  124. page_size: 页面尺寸
  125. template: 表格模板
  126. page_index: 页码 (0-indexed)
  127. Returns:
  128. PageResult
  129. """
  130. try:
  131. width, height = page_size
  132. # 1. 确定表格区域
  133. if page_index == 0:
  134. # 首页:从表头底部开始
  135. table_top = template.header_y_bottom
  136. else:
  137. # 非首页:跳过页眉
  138. table_top = self._detect_content_top(
  139. ocr_boxes, template, template.page_header_height
  140. )
  141. table_bottom = height - template.page_footer_height
  142. table_region = (
  143. template.table_x_range[0],
  144. table_top,
  145. template.table_x_range[1],
  146. table_bottom
  147. )
  148. # 2. 复用列边界
  149. columns = [
  150. ColumnRegion(x_left=c[0], x_right=c[1], column_index=i)
  151. for i, c in enumerate(template.columns)
  152. ]
  153. # 3. 自适应行分割
  154. row_splitter = AdaptiveRowSplitter(**template.row_splitter_params)
  155. rows, _ = row_splitter.split_rows(ocr_boxes, table_region)
  156. # 4. 构建结构
  157. structure = TableStructure(
  158. table_region=table_region,
  159. columns=columns,
  160. rows=rows,
  161. page_size=page_size
  162. )
  163. # 5. 构建表格数据
  164. table_data = self.generator.build_table_data(ocr_boxes, structure)
  165. return PageResult(
  166. page_index=page_index,
  167. structure=structure,
  168. table_data=table_data,
  169. success=True
  170. )
  171. except Exception as e:
  172. return PageResult(
  173. page_index=page_index,
  174. structure=None,
  175. table_data=[],
  176. success=False,
  177. error_message=str(e)
  178. )
  179. def process_document(self,
  180. pages_ocr: List[List[Dict]],
  181. page_sizes: List[Tuple[int, int]],
  182. template: Optional[TableTemplate] = None,
  183. first_page_table_region: Optional[Tuple[int, int, int, int]] = None,
  184. parallel: bool = True,
  185. progress_callback: Optional[callable] = None) -> List[PageResult]:
  186. """
  187. 处理整个文档
  188. Args:
  189. pages_ocr: 每页的 OCR 结果
  190. page_sizes: 每页的尺寸
  191. template: 模板 (如果为 None,从首页自动学习)
  192. first_page_table_region: 首页表格区域 (用于模板学习)
  193. parallel: 是否并行处理
  194. progress_callback: 进度回调函数 (page_index, total_pages)
  195. Returns:
  196. 每页的处理结果
  197. """
  198. total_pages = len(pages_ocr)
  199. print(f"📚 开始处理文档: {total_pages} 页")
  200. # 1. 学习模板(如果未提供)
  201. if template is None:
  202. template = self.learn_template(
  203. pages_ocr[0], page_sizes[0], first_page_table_region
  204. )
  205. # 2. 处理所有页
  206. results = [None] * total_pages
  207. if parallel and total_pages > 1:
  208. # 并行处理
  209. with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
  210. futures = {
  211. executor.submit(
  212. self.apply_template,
  213. pages_ocr[i], page_sizes[i], template, i
  214. ): i
  215. for i in range(total_pages)
  216. }
  217. for future in as_completed(futures):
  218. page_idx = futures[future]
  219. try:
  220. result = future.result()
  221. results[page_idx] = result
  222. if progress_callback:
  223. progress_callback(page_idx, total_pages)
  224. status = "✅" if result.success else "❌"
  225. rows = len(result.structure.rows) if result.structure else 0
  226. print(f" {status} 页 {page_idx + 1}: {rows} 行")
  227. except Exception as e:
  228. results[page_idx] = PageResult(
  229. page_index=page_idx,
  230. structure=None,
  231. table_data=[],
  232. success=False,
  233. error_message=str(e)
  234. )
  235. else:
  236. # 串行处理
  237. for i in range(total_pages):
  238. result = self.apply_template(pages_ocr[i], page_sizes[i], template, i)
  239. results[i] = result
  240. if progress_callback:
  241. progress_callback(i, total_pages)
  242. status = "✅" if result.success else "❌"
  243. rows = len(result.structure.rows) if result.structure else 0
  244. print(f" {status} 页 {i + 1}: {rows} 行")
  245. # 统计
  246. success_count = sum(1 for r in results if r.success)
  247. print(f"📊 处理完成: {success_count}/{total_pages} 页成功")
  248. return results
  249. def _detect_content_top(self,
  250. ocr_boxes: List[Dict],
  251. template: TableTemplate,
  252. default_top: int) -> int:
  253. """
  254. 检测内容区域顶部
  255. 对于非首页,找到第一个在表格水平范围内的 box
  256. """
  257. x_left, x_right = template.table_x_range
  258. for box in sorted(ocr_boxes, key=lambda b: b['bbox'][1]):
  259. x_center = (box['bbox'][0] + box['bbox'][2]) / 2
  260. if x_left <= x_center <= x_right:
  261. # 返回该 box 的顶部 - margin
  262. return max(0, box['bbox'][1] - 5)
  263. return default_top
  264. def export_results(self,
  265. results: List[PageResult],
  266. output_dir: Union[str, Path],
  267. prefix: str = "page") -> List[Path]:
  268. """
  269. 导出处理结果
  270. Returns:
  271. 导出的文件路径列表
  272. """
  273. output_dir = Path(output_dir)
  274. output_dir.mkdir(parents=True, exist_ok=True)
  275. exported_files = []
  276. for result in results:
  277. if not result.success:
  278. continue
  279. filename = f"{prefix}_{result.page_index + 1:03d}_structure.json"
  280. filepath = output_dir / filename
  281. data = {
  282. 'page_index': result.page_index,
  283. 'structure': result.structure.to_dict(),
  284. 'table_data': result.table_data
  285. }
  286. with open(filepath, 'w', encoding='utf-8') as f:
  287. json.dump(data, f, ensure_ascii=False, indent=2)
  288. exported_files.append(filepath)
  289. print(f"📁 导出 {len(exported_files)} 个文件到 {output_dir}")
  290. return exported_files
  291. def draw_all_pages(self,
  292. image_paths: List[Union[str, Path]],
  293. results: List[PageResult],
  294. output_dir: Union[str, Path],
  295. line_color: Tuple[int, int, int] = (0, 0, 255),
  296. line_thickness: int = 1) -> List[Path]:
  297. """
  298. 在所有页面上绘制表格线
  299. Returns:
  300. 绘制后图片的路径列表
  301. """
  302. output_dir = Path(output_dir)
  303. output_dir.mkdir(parents=True, exist_ok=True)
  304. output_paths = []
  305. for image_path, result in zip(image_paths, results):
  306. if not result.success:
  307. continue
  308. image = cv2.imread(str(image_path))
  309. if image is None:
  310. continue
  311. # 绘制表格线
  312. drawn = self.generator.draw_table_lines(
  313. image, result.structure, line_color, line_thickness
  314. )
  315. # 保存
  316. output_filename = Path(image_path).stem + "_lined.png"
  317. output_path = output_dir / output_filename
  318. cv2.imwrite(str(output_path), drawn)
  319. output_paths.append(output_path)
  320. print(f"🖼️ 绘制 {len(output_paths)} 张图片到 {output_dir}")
  321. return output_paths