| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
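"""
Batch processor.

Annotate the first page, then propagate that layout automatically to every page.
Intended for quickly processing transaction statements that run to dozens or
even hundreds of pages.
"""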
- import json
- import cv2
- from pathlib import Path
- from typing import List, Dict, Tuple, Optional, Union
- from dataclasses import dataclass, asdict
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from .smart_generator import SmartTableLineGenerator, TableStructure
- from .column_detector import ColumnRegion
- from .adaptive_row_splitter import AdaptiveRowSplitter, RowRegion
@dataclass
class TableTemplate:
    """
    Table template learned from the first page.

    Fixed parts: column boundaries and the horizontal extent of the table.
    Variable part: row splitting, which is recomputed for every page.
    """
    # Column boundaries: [(x_left, x_right), ...]
    columns: List[Tuple[int, int]]
    # Horizontal extent of the table: (x_left, x_right)
    table_x_range: Tuple[int, int]
    # Y coordinate of the bottom of the header row(s) on the first page
    header_y_bottom: int
    # Page-header height used on pages other than the first
    page_header_height: int
    # Page-footer height
    page_footer_height: int
    # Keyword arguments forwarded to the row splitter
    row_splitter_params: Dict

    def to_dict(self) -> Dict:
        """Return the template as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> 'TableTemplate':
        """
        Build a template from a dict such as the one produced by ``to_dict``.

        JSON has no tuple type, so a dict that went through ``save``/``load``
        carries lists where the dataclass declares tuples. Those fields are
        converted back here so that a loaded template compares equal to the
        one that was saved. The caller's dict is not modified.

        Raises:
            TypeError: if ``data`` has missing or unexpected keys
                (raised by the dataclass constructor).
        """
        fields = dict(data)
        if 'columns' in fields:
            fields['columns'] = [tuple(bounds) for bounds in fields['columns']]
        if 'table_x_range' in fields:
            fields['table_x_range'] = tuple(fields['table_x_range'])
        if 'row_splitter_params' in fields:
            # Copy so the template does not share state with the input dict.
            fields['row_splitter_params'] = dict(fields['row_splitter_params'])
        return cls(**fields)

    def save(self, path: Union[str, Path]):
        """Write the template to ``path`` as UTF-8 encoded JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'TableTemplate':
        """Read a template previously written by ``save``."""
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.from_dict(data)
@dataclass
class PageResult:
    """Result of processing a single page."""
    # Page index as passed to the processor
    page_index: int
    # Detected table structure; None when processing of the page failed
    structure: Optional["TableStructure"]
    # Cell texts, one inner list per table row
    table_data: List[List[str]]
    # True when the page was processed without raising
    success: bool
    # Text of the exception that caused the failure, if any
    error_message: Optional[str] = None
class BatchTableProcessor:
    """
    Batch table processor.

    Workflow:
        1. Learn a template (column boundaries, table area) from the first page.
        2. Apply the template to every page: column boundaries are reused,
           row splitting is recomputed for each page.
        3. Optionally process the pages in a thread pool.
    """

    def __init__(self,
                 generator: Optional["SmartTableLineGenerator"] = None,
                 max_workers: int = 4):
        """
        Args:
            generator: Table line generator to use; a default
                SmartTableLineGenerator is created when omitted.
            max_workers: Thread-pool size used by ``process_document``.
        """
        self.generator = generator or SmartTableLineGenerator()
        self.max_workers = max_workers

    def learn_template(self,
                       first_page_ocr: List[Dict],
                       page_size: Tuple[int, int],
                       table_region: Optional[Tuple[int, int, int, int]] = None,
                       header_row_count: int = 1,
                       page_header_height: int = 80,
                       page_footer_height: int = 50) -> "TableTemplate":
        """
        Learn a template from the first page.

        Args:
            first_page_ocr: OCR result of the first page.
            page_size: Page size; index 0 is used as the page width.
            table_region: Table region (optional), forwarded to the generator.
            header_row_count: Number of header rows.
            page_header_height: Page-header height for pages after the first.
            page_footer_height: Page-footer height.

        Returns:
            The learned table template.
        """
        # 1. Generate the structure of the first page.
        structure, _ = self.generator.generate(
            first_page_ocr, page_size, table_region
        )

        # 2. Extract the column boundaries.
        columns = [(col.x_left, col.x_right) for col in structure.columns]

        # 3. Horizontal extent: first column's left edge to last column's
        #    right edge, or the full page width when no column was found.
        if columns:
            table_x_range = (columns[0][0], columns[-1][1])
        else:
            table_x_range = (0, page_size[0])

        # 4. Bottom of the header (the rows to skip on the first page).
        if structure.rows and header_row_count > 0:
            # rows[header_row_count - 1] exists whenever
            # header_row_count <= len(rows); the previous strict "<" wrongly
            # fell back to the first row when the two were equal.
            if header_row_count <= len(structure.rows):
                header_y_bottom = structure.rows[header_row_count - 1].y_bottom
            else:
                header_y_bottom = structure.rows[0].y_bottom
        else:
            # No usable rows: fixed offset below table_region[1]
            # (presumably the table's top Y - confirm against the generator).
            header_y_bottom = structure.table_region[1] + 50

        # 5. Assemble the template.
        template = TableTemplate(
            columns=columns,
            table_x_range=table_x_range,
            header_y_bottom=header_y_bottom,
            page_header_height=page_header_height,
            page_footer_height=page_footer_height,
            row_splitter_params={
                'min_gap_height': 6,
                'density_threshold': 0.05,
                'min_row_height': 15
            }
        )

        print(f"📐 模板学习完成: {len(columns)} 列, 表头底部 Y={header_y_bottom}")

        return template

    def apply_template(self,
                       ocr_boxes: List[Dict],
                       page_size: Tuple[int, int],
                       template: "TableTemplate",
                       page_index: int) -> "PageResult":
        """
        Apply the template to one page.

        Any exception raised while processing the page is caught and reported
        through a failed PageResult, so that one bad page does not abort a
        whole batch.

        Args:
            ocr_boxes: OCR result of the page.
            page_size: Page size as (width, height).
            template: Table template.
            page_index: Page index (0-based).

        Returns:
            PageResult; ``success`` is False and ``error_message`` is set when
            processing raised.
        """
        try:
            _, height = page_size

            # 1. Vertical extent of the table on this page.
            if page_index == 0:
                # First page: start below the header rows.
                table_top = template.header_y_bottom
            else:
                # Later pages: start at the first content box, falling back
                # to the configured page-header height.
                table_top = self._detect_content_top(
                    ocr_boxes, template, template.page_header_height
                )

            table_bottom = height - template.page_footer_height

            table_region = (
                template.table_x_range[0],
                table_top,
                template.table_x_range[1],
                table_bottom
            )

            # 2. Reuse the column boundaries from the template.
            columns = [
                ColumnRegion(x_left=bounds[0], x_right=bounds[1], column_index=idx)
                for idx, bounds in enumerate(template.columns)
            ]

            # 3. Split rows for this page.
            row_splitter = AdaptiveRowSplitter(**template.row_splitter_params)
            rows, _ = row_splitter.split_rows(ocr_boxes, table_region)

            # 4. Build the structure.
            structure = TableStructure(
                table_region=table_region,
                columns=columns,
                rows=rows,
                page_size=page_size
            )

            # 5. Build the table data.
            table_data = self.generator.build_table_data(ocr_boxes, structure)

            return PageResult(
                page_index=page_index,
                structure=structure,
                table_data=table_data,
                success=True
            )

        except Exception as e:
            # Deliberate page-level boundary: report the failure, keep going.
            return PageResult(
                page_index=page_index,
                structure=None,
                table_data=[],
                success=False,
                error_message=str(e)
            )

    def process_document(self,
                         pages_ocr: List[List[Dict]],
                         page_sizes: List[Tuple[int, int]],
                         template: Optional["TableTemplate"] = None,
                         first_page_table_region: Optional[Tuple[int, int, int, int]] = None,
                         parallel: bool = True,
                         progress_callback: Optional[callable] = None) -> List["PageResult"]:
        """
        Process a whole document.

        Args:
            pages_ocr: OCR result of every page.
            page_sizes: Size of every page.
            template: Template; when None it is learned from the first page.
            first_page_table_region: First-page table region (used only when
                learning the template).
            parallel: Whether to process pages in a thread pool.
            progress_callback: Called as ``(page_index, total_pages)`` after
                each page. In parallel mode pages are reported in completion
                order. Exceptions raised by the callback propagate to the
                caller in both modes.

        Returns:
            One PageResult per page, in page order. An empty document yields
            an empty list.
        """
        total_pages = len(pages_ocr)
        print(f"📚 开始处理文档: {total_pages} 页")

        # 1. Learn the template if none was given. Skipped for an empty
        #    document, which has no first page to learn from.
        if template is None and total_pages > 0:
            template = self.learn_template(
                pages_ocr[0], page_sizes[0], first_page_table_region
            )

        # 2. Process every page; results are stored by page index.
        results: List[Optional["PageResult"]] = [None] * total_pages

        if parallel and total_pages > 1:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {
                    executor.submit(
                        self.apply_template,
                        pages_ocr[i], page_sizes[i], template, i
                    ): i
                    for i in range(total_pages)
                }

                for future in as_completed(futures):
                    page_idx = futures[future]
                    # Only the worker call is guarded: a failing progress
                    # callback must not be mistaken for a failed page.
                    try:
                        result = future.result()
                    except Exception as e:
                        result = PageResult(
                            page_index=page_idx,
                            structure=None,
                            table_data=[],
                            success=False,
                            error_message=str(e)
                        )
                    results[page_idx] = result
                    self._report_page(page_idx, result, total_pages, progress_callback)
        else:
            for i in range(total_pages):
                result = self.apply_template(pages_ocr[i], page_sizes[i], template, i)
                results[i] = result
                self._report_page(i, result, total_pages, progress_callback)

        # Summary.
        success_count = sum(1 for r in results if r.success)
        print(f"📊 处理完成: {success_count}/{total_pages} 页成功")

        return results

    @staticmethod
    def _report_page(page_idx: int,
                     result: "PageResult",
                     total_pages: int,
                     progress_callback: Optional[callable]) -> None:
        """Invoke the progress callback (if any) and print one status line."""
        if progress_callback:
            progress_callback(page_idx, total_pages)

        status = "✅" if result.success else "❌"
        rows = len(result.structure.rows) if result.structure else 0
        print(f" {status} 页 {page_idx + 1}: {rows} 行")

    def _detect_content_top(self,
                            ocr_boxes: List[Dict],
                            template: "TableTemplate",
                            default_top: int) -> int:
        """
        Detect the top of the content area.

        Considers the boxes whose horizontal centre (mean of ``bbox[0]`` and
        ``bbox[2]``) lies inside the template's horizontal range, and returns
        the smallest ``bbox[1]`` among them minus a 5-unit margin, clamped at 0.

        Returns:
            The detected top, or ``default_top`` when no box qualifies.
        """
        x_left, x_right = template.table_x_range

        tops = [
            box['bbox'][1]
            for box in ocr_boxes
            if x_left <= (box['bbox'][0] + box['bbox'][2]) / 2 <= x_right
        ]
        if not tops:
            return default_top

        # Small margin above the topmost box, never negative.
        return max(0, min(tops) - 5)

    def export_results(self,
                       results: List["PageResult"],
                       output_dir: Union[str, Path],
                       prefix: str = "page") -> List[Path]:
        """
        Export the successful results as one JSON file per page.

        Failed pages are skipped. Files are named
        ``{prefix}_{page number, 1-based, 3 digits}_structure.json``.

        Returns:
            Paths of the exported files.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        exported_files = []

        for result in results:
            if not result.success:
                continue

            filepath = output_dir / f"{prefix}_{result.page_index + 1:03d}_structure.json"

            data = {
                'page_index': result.page_index,
                'structure': result.structure.to_dict(),
                'table_data': result.table_data
            }

            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)

            exported_files.append(filepath)

        print(f"📁 导出 {len(exported_files)} 个文件到 {output_dir}")

        return exported_files

    def draw_all_pages(self,
                       image_paths: List[Union[str, Path]],
                       results: List["PageResult"],
                       output_dir: Union[str, Path],
                       line_color: Tuple[int, int, int] = (0, 0, 255),
                       line_thickness: int = 1) -> List[Path]:
        """
        Draw the table lines on every page image.

        ``image_paths`` and ``results`` are paired by position. Failed pages,
        images that cannot be read and images that cannot be written are
        skipped.

        Returns:
            Paths of the images that were actually written.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        output_paths = []

        for image_path, result in zip(image_paths, results):
            if not result.success:
                continue

            image = cv2.imread(str(image_path))
            if image is None:
                # Image missing or unreadable.
                continue

            # Draw the table lines.
            drawn = self.generator.draw_table_lines(
                image, result.structure, line_color, line_thickness
            )

            # Save; only report files that were really written.
            output_path = output_dir / (Path(image_path).stem + "_lined.png")
            if not cv2.imwrite(str(output_path), drawn):
                continue

            output_paths.append(output_path)

        print(f"🖼️ 绘制 {len(output_paths)} 张图片到 {output_dir}")

        return output_paths
|