""" 批量处理器 首页标注 -> 自动传播到所有页 适用于几十页甚至上百页的交易流水快速处理 """ import json import cv2 from pathlib import Path from typing import List, Dict, Tuple, Optional, Union from dataclasses import dataclass, asdict from concurrent.futures import ThreadPoolExecutor, as_completed from .smart_generator import SmartTableLineGenerator, TableStructure from .column_detector import ColumnRegion from .adaptive_row_splitter import AdaptiveRowSplitter, RowRegion @dataclass class TableTemplate: """ 表格模板 (从首页学习) 固定部分:列边界、表格水平范围 可变部分:行分割(每页自适应计算) """ columns: List[Tuple[int, int]] # 列边界 [(x_left, x_right), ...] table_x_range: Tuple[int, int] # 表格水平范围 (x_left, x_right) header_y_bottom: int # 表头底部 Y 坐标 (首页) page_header_height: int # 页眉高度 (非首页) page_footer_height: int # 页脚高度 row_splitter_params: Dict # 行分割参数 def to_dict(self) -> Dict: return asdict(self) @classmethod def from_dict(cls, data: Dict) -> 'TableTemplate': return cls(**data) def save(self, path: Union[str, Path]): with open(path, 'w', encoding='utf-8') as f: json.dump(self.to_dict(), f, ensure_ascii=False, indent=2) @classmethod def load(cls, path: Union[str, Path]) -> 'TableTemplate': with open(path, 'r', encoding='utf-8') as f: data = json.load(f) return cls.from_dict(data) @dataclass class PageResult: """单页处理结果""" page_index: int structure: TableStructure table_data: List[List[str]] success: bool error_message: Optional[str] = None class BatchTableProcessor: """ 批量表格处理器 工作流程: 1. 从首页学习模板(列边界、表格区域) 2. 将模板应用到所有页(列边界复用,行分割自适应) 3. 并行处理提升效率 """ def __init__(self, generator: Optional[SmartTableLineGenerator] = None, max_workers: int = 4): self.generator = generator or SmartTableLineGenerator() self.max_workers = max_workers def learn_template(self, first_page_ocr: List[Dict], page_size: Tuple[int, int], table_region: Optional[Tuple[int, int, int, int]] = None, header_row_count: int = 1, page_header_height: int = 80, page_footer_height: int = 50) -> TableTemplate: """ 从首页学习模板 Args: first_page_ocr: 首页 OCR 结果 page_size: 页面尺寸 table_region: 表格区域 (可选) header_row_count: 表头行数 page_header_height: 非首页的页眉高度 page_footer_height: 页脚高度 Returns: 表格模板 """ # 1. 生成首页结构 structure, _ = self.generator.generate( first_page_ocr, page_size, table_region ) # 2. 提取列边界 columns = [(c.x_left, c.x_right) for c in structure.columns] # 3. 确定表格水平范围 if columns: table_x_range = (columns[0][0], columns[-1][1]) else: table_x_range = (0, page_size[0]) # 4. 确定表头底部(跳过表头行) if structure.rows and header_row_count > 0: if header_row_count < len(structure.rows): header_y_bottom = structure.rows[header_row_count - 1].y_bottom else: header_y_bottom = structure.rows[0].y_bottom else: header_y_bottom = structure.table_region[1] + 50 # 5. 创建模板 template = TableTemplate( columns=columns, table_x_range=table_x_range, header_y_bottom=header_y_bottom, page_header_height=page_header_height, page_footer_height=page_footer_height, row_splitter_params={ 'min_gap_height': 6, 'density_threshold': 0.05, 'min_row_height': 15 } ) print(f"📐 模板学习完成: {len(columns)} 列, 表头底部 Y={header_y_bottom}") return template def apply_template(self, ocr_boxes: List[Dict], page_size: Tuple[int, int], template: TableTemplate, page_index: int) -> PageResult: """ 将模板应用到某一页 Args: ocr_boxes: 该页 OCR 结果 page_size: 页面尺寸 template: 表格模板 page_index: 页码 (0-indexed) Returns: PageResult """ try: width, height = page_size # 1. 确定表格区域 if page_index == 0: # 首页:从表头底部开始 table_top = template.header_y_bottom else: # 非首页:跳过页眉 table_top = self._detect_content_top( ocr_boxes, template, template.page_header_height ) table_bottom = height - template.page_footer_height table_region = ( template.table_x_range[0], table_top, template.table_x_range[1], table_bottom ) # 2. 复用列边界 columns = [ ColumnRegion(x_left=c[0], x_right=c[1], column_index=i) for i, c in enumerate(template.columns) ] # 3. 自适应行分割 row_splitter = AdaptiveRowSplitter(**template.row_splitter_params) rows, _ = row_splitter.split_rows(ocr_boxes, table_region) # 4. 构建结构 structure = TableStructure( table_region=table_region, columns=columns, rows=rows, page_size=page_size ) # 5. 构建表格数据 table_data = self.generator.build_table_data(ocr_boxes, structure) return PageResult( page_index=page_index, structure=structure, table_data=table_data, success=True ) except Exception as e: return PageResult( page_index=page_index, structure=None, table_data=[], success=False, error_message=str(e) ) def process_document(self, pages_ocr: List[List[Dict]], page_sizes: List[Tuple[int, int]], template: Optional[TableTemplate] = None, first_page_table_region: Optional[Tuple[int, int, int, int]] = None, parallel: bool = True, progress_callback: Optional[callable] = None) -> List[PageResult]: """ 处理整个文档 Args: pages_ocr: 每页的 OCR 结果 page_sizes: 每页的尺寸 template: 模板 (如果为 None,从首页自动学习) first_page_table_region: 首页表格区域 (用于模板学习) parallel: 是否并行处理 progress_callback: 进度回调函数 (page_index, total_pages) Returns: 每页的处理结果 """ total_pages = len(pages_ocr) print(f"📚 开始处理文档: {total_pages} 页") # 1. 学习模板(如果未提供) if template is None: template = self.learn_template( pages_ocr[0], page_sizes[0], first_page_table_region ) # 2. 处理所有页 results = [None] * total_pages if parallel and total_pages > 1: # 并行处理 with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = { executor.submit( self.apply_template, pages_ocr[i], page_sizes[i], template, i ): i for i in range(total_pages) } for future in as_completed(futures): page_idx = futures[future] try: result = future.result() results[page_idx] = result if progress_callback: progress_callback(page_idx, total_pages) status = "✅" if result.success else "❌" rows = len(result.structure.rows) if result.structure else 0 print(f" {status} 页 {page_idx + 1}: {rows} 行") except Exception as e: results[page_idx] = PageResult( page_index=page_idx, structure=None, table_data=[], success=False, error_message=str(e) ) else: # 串行处理 for i in range(total_pages): result = self.apply_template(pages_ocr[i], page_sizes[i], template, i) results[i] = result if progress_callback: progress_callback(i, total_pages) status = "✅" if result.success else "❌" rows = len(result.structure.rows) if result.structure else 0 print(f" {status} 页 {i + 1}: {rows} 行") # 统计 success_count = sum(1 for r in results if r.success) print(f"📊 处理完成: {success_count}/{total_pages} 页成功") return results def _detect_content_top(self, ocr_boxes: List[Dict], template: TableTemplate, default_top: int) -> int: """ 检测内容区域顶部 对于非首页,找到第一个在表格水平范围内的 box """ x_left, x_right = template.table_x_range for box in sorted(ocr_boxes, key=lambda b: b['bbox'][1]): x_center = (box['bbox'][0] + box['bbox'][2]) / 2 if x_left <= x_center <= x_right: # 返回该 box 的顶部 - margin return max(0, box['bbox'][1] - 5) return default_top def export_results(self, results: List[PageResult], output_dir: Union[str, Path], prefix: str = "page") -> List[Path]: """ 导出处理结果 Returns: 导出的文件路径列表 """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) exported_files = [] for result in results: if not result.success: continue filename = f"{prefix}_{result.page_index + 1:03d}_structure.json" filepath = output_dir / filename data = { 'page_index': result.page_index, 'structure': result.structure.to_dict(), 'table_data': result.table_data } with open(filepath, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) exported_files.append(filepath) print(f"📁 导出 {len(exported_files)} 个文件到 {output_dir}") return exported_files def draw_all_pages(self, image_paths: List[Union[str, Path]], results: List[PageResult], output_dir: Union[str, Path], line_color: Tuple[int, int, int] = (0, 0, 255), line_thickness: int = 1) -> List[Path]: """ 在所有页面上绘制表格线 Returns: 绘制后图片的路径列表 """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) output_paths = [] for image_path, result in zip(image_paths, results): if not result.success: continue image = cv2.imread(str(image_path)) if image is None: continue # 绘制表格线 drawn = self.generator.draw_table_lines( image, result.structure, line_color, line_thickness ) # 保存 output_filename = Path(image_path).stem + "_lined.png" output_path = output_dir / output_filename cv2.imwrite(str(output_path), drawn) output_paths.append(output_path) print(f"🖼️ 绘制 {len(output_paths)} 张图片到 {output_dir}") return output_paths