"""
Batch processing service.

Applies a template derived from a first page to multiple files; suited to
scenarios such as multi-page bank statements.
"""
import json
import sys
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Callable
from concurrent.futures import ThreadPoolExecutor, as_completed

from loguru import logger
from PIL import Image
import io

# Put a project directory on sys.path (used to import `core` and `ocr_utils`).
# NOTE(review): for the layout
#   batch_service.py -> services -> backend -> table_line_generator -> ocr_platform
# `parents[2]` resolves to the `table_line_generator` directory, not
# `ocr_platform` (that would be `parents[3]`). The `table_line_generator.*`
# imports below therefore rely on `ocr_platform` already being importable
# (e.g. via the working directory). Left unchanged because other modules may
# depend on this exact sys.path entry; confirm the intended directory first.
_file_path = Path(__file__).resolve()
ocr_platform_root = _file_path.parents[2]
if str(ocr_platform_root) not in sys.path:
    sys.path.insert(0, str(ocr_platform_root))

from table_line_generator.core.table_analyzer import TableAnalyzer
from table_line_generator.core.ocr_parser import OcrParser
from table_line_generator.core.drawing_service import DrawingService
from table_line_generator.backend.services.template_service import TemplateService


class BatchProcessor:
    """
    Batch table processor.

    Workflow:
        1. Learn the template structure from the first page (mainly the
           vertical lines / column boundaries).
        2. Apply the template to every page (vertical lines are reused,
           horizontal lines are re-derived per page).
        3. Optionally process files in parallel with a thread pool.
    """

    def __init__(self, max_workers: int = 4):
        """
        Initialize the batch processor.

        Args:
            max_workers: Maximum number of worker threads for parallel processing.
        """
        self.max_workers = max_workers

    @staticmethod
    def _page_metadata(ocr_data: Dict) -> Dict[str, Any]:
        """Return the per-page orientation fields copied from a page's OCR data."""
        # Each page keeps its own rotation / skew values; they are never taken
        # from the template.
        return {
            'image_rotation_angle': ocr_data.get('image_rotation_angle', 0.0),
            'skew_angle': ocr_data.get('skew_angle', 0.0),
            'is_skew_corrected': ocr_data.get('is_skew_corrected', False),
        }

    def learn_template_from_structure(
        self,
        first_page_structure: Dict
    ) -> Dict:
        """
        Learn a reusable template from the first page's structure.

        Args:
            first_page_structure: Table structure of the first page.

        Returns:
            Template dict with the reusable parts: 'vertical_lines' (a copy),
            'table_bbox', 'col_widths', 'total_cols' and 'mode'
            (defaults to 'hybrid').
        """
        template = {
            # `or []` also tolerates an explicit None value.
            'vertical_lines': list(first_page_structure.get('vertical_lines') or []),
            'table_bbox': first_page_structure.get('table_bbox'),
            'col_widths': first_page_structure.get('col_widths'),
            'total_cols': first_page_structure.get('total_cols'),
            'mode': first_page_structure.get('mode', 'hybrid')
        }

        logger.info(f"从首页学习模板: {len(template['vertical_lines'])} 条竖线, {template['total_cols']} 列")

        return template

    def apply_template_to_structure(
        self,
        template: Dict,
        target_ocr_data: Dict,
        adjust_rows: bool = True,
        y_tolerance: int = 5,
        min_row_height: int = 20
    ) -> Dict:
        """
        Apply a learned template to a target page's OCR data.

        Args:
            template: Template learned from the first page.
            target_ocr_data: OCR data of the target page.
            adjust_rows: If True, re-analyze the row structure of the target
                page; otherwise reuse the template as-is. A template produced
                by `learn_template_from_structure` carries no
                'horizontal_lines', so none are produced in that case.
            y_tolerance: Y-axis clustering tolerance passed to the analyzer.
            min_row_height: Minimum row height passed to the analyzer.

        Returns:
            The structure for the target page.
        """
        if adjust_rows:
            # Re-derive the rows for this page (adaptive).
            analyzer = TableAnalyzer(None, target_ocr_data)
            analyzed = analyzer.analyze(
                y_tolerance=y_tolerance,
                min_row_height=min_row_height,
                method=template.get('mode', 'auto')
            )

            # Reuse the template's column information.
            new_structure = {
                'horizontal_lines': analyzed['horizontal_lines'],
                'vertical_lines': template['vertical_lines'].copy(),
                'table_bbox': template.get('table_bbox') or analyzed.get('table_bbox'),
                'row_height': analyzed.get('row_height'),
                'col_widths': template.get('col_widths'),
                'total_rows': analyzed.get('total_rows'),
                'total_cols': template.get('total_cols'),
                'mode': template.get('mode'),
                'modified_h_lines': [],
                'modified_v_lines': [],
                **self._page_metadata(target_ocr_data),
            }
        else:
            # Reuse the template wholesale. List values are copied so edits to
            # this page's lines cannot leak back into the shared template.
            new_structure = {
                key: (list(value) if isinstance(value, list) else value)
                for key, value in template.items()
            }
            new_structure.update(self._page_metadata(target_ocr_data))

        return new_structure

    @staticmethod
    def _collect_result(
        results: List[Dict],
        idx: int,
        total: int,
        fetch: Callable[[], Dict],
        progress_callback: Optional[Callable[[int, int], None]]
    ) -> None:
        """Run `fetch` for file `idx`, append exactly one result record, then report progress."""
        # Only the processing call is guarded. A failure in logging or in the
        # caller's progress callback must not be recorded as a second result.
        try:
            result = fetch()
        except Exception as e:
            logger.error(f"❌ 处理失败 [{idx+1}/{total}]: {e}")
            result = {
                'success': False,
                'error': str(e),
                'index': idx
            }
        else:
            result.setdefault('index', idx)
            status = "✅" if result.get('success') else "❌"
            logger.info(f"{status} [{idx+1}/{total}] {result.get('filename', 'unknown')}")
        results.append(result)

        if progress_callback:
            try:
                progress_callback(idx, total)
            except Exception as e:
                logger.warning(f"⚠️ 进度回调异常 [{idx+1}/{total}]: {e}")

    def process_batch_from_data_source(
        self,
        template_name: str,
        file_pairs: List[Dict],
        output_dir: str,
        parallel: bool = True,
        adjust_rows: bool = True,
        structure_suffix: str = "_structure.json",
        image_suffix: str = "_with_lines.png",
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> Dict:
        """
        Process every file of a data source with a stored template.

        Args:
            template_name: Template name; each file loads it via `TemplateService`.
            file_pairs: List of file pairs [{'json_path': ..., 'image_path': ...}, ...].
            output_dir: Output directory for the structure files.
            parallel: Whether to process files in a thread pool.
            adjust_rows: Whether to re-derive rows per page.
            structure_suffix: Suffix of the written structure file.
            image_suffix: Currently unused by this method; kept for API
                compatibility. The drawn-image filename is controlled by
                `draw_batch_images`.
            progress_callback: Called as `callback(index, total)` once per
                file (including failures), where `index` is the file's
                position in `file_pairs`. In parallel mode calls arrive in
                completion order, so `index` is not monotonic.

        Returns:
            Summary dict with 'total', 'success', 'failed' and 'results'.
            'results' is ordered like `file_pairs`, and each entry carries an
            'index' key.
        """
        total = len(file_pairs)
        results: List[Dict] = []

        logger.info(f"开始批量处理: {total} 个文件, 使用模板: {template_name}, 并行={parallel}")

        def run_one(pair: Dict) -> Dict:
            return self._process_single_file(
                pair,
                template_name,
                output_dir,
                adjust_rows,
                structure_suffix
            )

        if parallel and total > 1:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {
                    executor.submit(run_one, pair): idx
                    for idx, pair in enumerate(file_pairs)
                }
                for future in as_completed(futures):
                    self._collect_result(
                        results, futures[future], total, future.result, progress_callback
                    )
        else:
            for idx, pair in enumerate(file_pairs):
                self._collect_result(
                    results, idx, total, lambda pair=pair: run_one(pair), progress_callback
                )

        # as_completed yields in completion order; restore the input order so
        # results[i] corresponds to file_pairs[i].
        results.sort(key=lambda r: r['index'])

        success_count = sum(1 for r in results if r.get('success'))
        failed_count = total - success_count

        summary = {
            'total': total,
            'success': success_count,
            'failed': failed_count,
            'results': results
        }

        logger.info(f"📊 批量处理完成: 成功 {success_count}/{total}, 失败 {failed_count}")

        return summary

    def _process_single_file(
        self,
        file_pair: Dict,
        template_name: str,
        output_dir: str,
        adjust_rows: bool,
        structure_suffix: str = "_structure.json"
    ) -> Dict:
        """
        Process a single file and write its structure JSON.

        Args:
            file_pair: File pair {'json_path': ..., 'image_path': ...}.
            template_name: Template name loaded via `TemplateService`.
            output_dir: Output directory (created if missing).
            adjust_rows: Whether to re-derive rows for this page.
            structure_suffix: Suffix of the written structure file.

        Returns:
            Result dict with 'success' plus paths / row and column counts on
            success, or 'error' on failure. Errors inside the processing are
            logged and returned, not raised. A missing 'json_path' or
            'image_path' key in `file_pair` does raise `KeyError`.
        """
        json_path = Path(file_pair['json_path'])
        image_path = Path(file_pair['image_path'])

        try:
            # 1. Load the OCR result.
            with open(json_path, 'r', encoding='utf-8') as f:
                ocr_result = json.load(f)

            # 2. Parse it into the target page's table bbox and OCR data.
            target_table_bbox, ocr_data = OcrParser.parse(ocr_result)
            target_image_size = ocr_data.get('image_size', {'width': 1, 'height': 1})

            # 3. Map the stored template onto this page. Relative mode maps
            #    coordinates so pages of different sizes are supported.
            template_service = TemplateService()
            applied_template = template_service.preview_apply(
                template_name=template_name,
                target_image_size=target_image_size,
                target_table_bbox=target_table_bbox,
                mode='relative'
            )

            # 4. Build the final structure (applied template + page metadata).
            if adjust_rows:
                # Re-derive this page's rows; keep the template's mapped
                # vertical lines.
                analyzed = TableAnalyzer(None, ocr_data).analyze()
                new_structure = {
                    'horizontal_lines': analyzed['horizontal_lines'],
                    'vertical_lines': applied_template['vertical_lines'],
                    'table_bbox': applied_template['table_bbox'],  # target page's bbox
                    'row_height': analyzed.get('row_height'),
                    'col_widths': applied_template.get('col_widths'),
                    'total_rows': analyzed.get('total_rows'),
                    'total_cols': applied_template.get('total_cols'),
                    'mode': applied_template.get('mode', 'hybrid'),
                    'modified_h_lines': [],
                    'modified_v_lines': [],
                    **self._page_metadata(ocr_data),
                }
            else:
                # Reuse the applied template entirely.
                new_structure = applied_template.copy()
                new_structure.update(self._page_metadata(ocr_data))

            # 5. Save the structure file.
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)
            structure_path = output_path / f"{image_path.stem}{structure_suffix}"

            # Drop fields that should not be persisted.
            structure_to_save = {
                key: value
                for key, value in new_structure.items()
                if key not in ('rows', 'columns')
            }

            with open(structure_path, 'w', encoding='utf-8') as f:
                json.dump(structure_to_save, f, ensure_ascii=False, indent=2)

            return {
                'success': True,
                'json_path': str(json_path),
                'image_path': str(image_path),
                'structure_path': str(structure_path),
                'filename': image_path.name,
                'rows': new_structure.get('total_rows', 0),
                'cols': new_structure.get('total_cols', 0)
            }

        except Exception as e:
            logger.exception(f"处理文件失败: {json_path}")
            return {
                'success': False,
                'json_path': str(json_path),
                'image_path': str(image_path),
                'error': str(e),
                'filename': image_path.name
            }

    def draw_batch_images(
        self,
        results: List[Dict],
        line_width: int = 2,
        line_color: Tuple[int, int, int] = (0, 0, 0),
        image_suffix: str = ".png"
    ) -> List[Dict]:
        """
        Draw the table lines onto each successfully processed image.

        Args:
            results: The 'results' list returned by `process_batch_from_data_source`.
            line_width: Line width.
            line_color: Line color as RGB.
            image_suffix: Suffix appended to the image stem for the output
                file, written next to the structure file. The default ".png"
                keeps the previous naming. NOTE: if the output directory is
                the image's own directory and the source is a .png, the
                default overwrites the source; pass e.g. "_with_lines.png".

        Returns:
            One draw-result dict per successful input result.
        """
        draw_results = []

        for result in results:
            if not result.get('success'):
                continue

            try:
                image_path = Path(result['image_path'])
                structure_path = Path(result['structure_path'])

                # convert() always returns a new, fully loaded image, so the
                # source file handle is closed here instead of leaking until GC.
                with Image.open(image_path) as source:
                    image = source.convert('RGB')

                with open(structure_path, 'r', encoding='utf-8') as f:
                    structure = json.load(f)

                image_with_lines = DrawingService.draw_clean_lines(
                    image,
                    structure,
                    line_width=line_width,
                    line_color=line_color
                )

                output_path = structure_path.parent / f"{image_path.stem}{image_suffix}"
                image_with_lines.save(str(output_path), 'PNG')

                draw_results.append({
                    'success': True,
                    'image_path': str(output_path),
                    'filename': image_path.name
                })

            except Exception as e:
                logger.error(f"绘制图片失败 {result.get('filename')}: {e}")
                draw_results.append({
                    'success': False,
                    'error': str(e),
                    'filename': result.get('filename')
                })

        success_count = sum(1 for r in draw_results if r.get('success'))
        logger.info(f"🖼️ 绘制完成: {success_count}/{len(results)} 张图片")

        return draw_results