|
|
@@ -0,0 +1,400 @@
|
|
|
+"""
|
|
|
+批量处理服务
|
|
|
+基于首页模板应用到多个文件,适用于多页银行流水等场景
|
|
|
+"""
|
|
|
+
|
|
|
+import json
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Dict, Optional, Tuple, Any, Callable
|
|
|
+from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
+from loguru import logger
|
|
|
+from PIL import Image
|
|
|
+import io
|
|
|
+
|
|
|
+# 添加 ocr_platform 根目录到 Python 路径(用于导入 core 和 ocr_utils)
|
|
|
+_file_path = Path(__file__).resolve()
|
|
|
+ocr_platform_root = _file_path.parents[2] # batch_service.py -> services -> backend -> table_line_generator -> ocr_platform
|
|
|
+if str(ocr_platform_root) not in sys.path:
|
|
|
+ sys.path.insert(0, str(ocr_platform_root))
|
|
|
+
|
|
|
+from table_line_generator.core.table_analyzer import TableAnalyzer
|
|
|
+from table_line_generator.core.ocr_parser import OcrParser
|
|
|
+from table_line_generator.core.drawing_service import DrawingService
|
|
|
+from table_line_generator.backend.services.template_service import TemplateService
|
|
|
+
|
|
|
+
|
|
|
+class BatchProcessor:
|
|
|
+ """
|
|
|
+ 批量表格处理器
|
|
|
+
|
|
|
+ 工作流程:
|
|
|
+ 1. 从首页学习模板结构(主要是竖线/列边界)
|
|
|
+ 2. 将模板应用到所有页(竖线复用,横线自适应)
|
|
|
+ 3. 支持并行处理提升效率
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, max_workers: int = 4):
|
|
|
+ """
|
|
|
+ 初始化批量处理器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ max_workers: 并行处理的最大线程数
|
|
|
+ """
|
|
|
+ self.max_workers = max_workers
|
|
|
+
|
|
|
+ def learn_template_from_structure(
|
|
|
+ self,
|
|
|
+ first_page_structure: Dict
|
|
|
+ ) -> Dict:
|
|
|
+ """
|
|
|
+ 从首页结构中学习模板
|
|
|
+
|
|
|
+ Args:
|
|
|
+ first_page_structure: 首页的表格结构
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 模板字典,包含可复用的部分(主要是列边界)
|
|
|
+ """
|
|
|
+ template = {
|
|
|
+ 'vertical_lines': first_page_structure.get('vertical_lines', []).copy(),
|
|
|
+ 'table_bbox': first_page_structure.get('table_bbox'),
|
|
|
+ 'col_widths': first_page_structure.get('col_widths'),
|
|
|
+ 'total_cols': first_page_structure.get('total_cols'),
|
|
|
+ 'mode': first_page_structure.get('mode', 'hybrid')
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info(f"从首页学习模板: {len(template['vertical_lines'])} 条竖线, {template['total_cols']} 列")
|
|
|
+
|
|
|
+ return template
|
|
|
+
|
|
|
+ def apply_template_to_structure(
|
|
|
+ self,
|
|
|
+ template: Dict,
|
|
|
+ target_ocr_data: Dict,
|
|
|
+ adjust_rows: bool = True,
|
|
|
+ y_tolerance: int = 5,
|
|
|
+ min_row_height: int = 20
|
|
|
+ ) -> Dict:
|
|
|
+ """
|
|
|
+ 将模板应用到目标页面的 OCR 数据
|
|
|
+
|
|
|
+ Args:
|
|
|
+ template: 从首页学习的模板
|
|
|
+ target_ocr_data: 目标页面的 OCR 数据
|
|
|
+ adjust_rows: 是否自适应调整行分割
|
|
|
+ y_tolerance: Y轴聚类容差
|
|
|
+ min_row_height: 最小行高
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 应用模板后的结构
|
|
|
+ """
|
|
|
+ # 创建分析器
|
|
|
+ analyzer = TableAnalyzer(None, target_ocr_data)
|
|
|
+
|
|
|
+ if adjust_rows:
|
|
|
+ # 重新分析行结构(自适应)
|
|
|
+ analyzed = analyzer.analyze(
|
|
|
+ y_tolerance=y_tolerance,
|
|
|
+ min_row_height=min_row_height,
|
|
|
+ method=template.get('mode', 'auto')
|
|
|
+ )
|
|
|
+
|
|
|
+ # 复用模板的列信息
|
|
|
+ new_structure = {
|
|
|
+ 'horizontal_lines': analyzed['horizontal_lines'],
|
|
|
+ 'vertical_lines': template['vertical_lines'].copy(),
|
|
|
+ 'table_bbox': template.get('table_bbox') or analyzed.get('table_bbox'),
|
|
|
+ 'row_height': analyzed.get('row_height'),
|
|
|
+ 'col_widths': template.get('col_widths'),
|
|
|
+ 'total_rows': analyzed.get('total_rows'),
|
|
|
+ 'total_cols': template.get('total_cols'),
|
|
|
+ 'mode': template.get('mode'),
|
|
|
+ 'modified_h_lines': [],
|
|
|
+ 'modified_v_lines': [],
|
|
|
+ 'image_rotation_angle': target_ocr_data.get('image_rotation_angle', 0.0),
|
|
|
+ 'skew_angle': target_ocr_data.get('skew_angle', 0.0),
|
|
|
+ 'is_skew_corrected': target_ocr_data.get('is_skew_corrected', False)
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ # 完全复用模板(包括横线)
|
|
|
+ new_structure = template.copy()
|
|
|
+ new_structure['image_rotation_angle'] = target_ocr_data.get('image_rotation_angle', 0.0)
|
|
|
+ new_structure['skew_angle'] = target_ocr_data.get('skew_angle', 0.0)
|
|
|
+ new_structure['is_skew_corrected'] = target_ocr_data.get('is_skew_corrected', False)
|
|
|
+
|
|
|
+ return new_structure
|
|
|
+
|
|
|
+ def process_batch_from_data_source(
|
|
|
+ self,
|
|
|
+ template_name: str,
|
|
|
+ file_pairs: List[Dict],
|
|
|
+ output_dir: str,
|
|
|
+ parallel: bool = True,
|
|
|
+ adjust_rows: bool = True,
|
|
|
+ structure_suffix: str = "_structure.json",
|
|
|
+ image_suffix: str = "_with_lines.png",
|
|
|
+ progress_callback: Optional[Callable[[int, int], None]] = None
|
|
|
+ ) -> Dict:
|
|
|
+ """
|
|
|
+ 批量处理数据源中的文件
|
|
|
+
|
|
|
+ Args:
|
|
|
+ template_name: 模板名称(从 TemplateService 加载)
|
|
|
+ file_pairs: 文件对列表 [{'json_path': ..., 'image_path': ...}, ...]
|
|
|
+ output_dir: 输出目录
|
|
|
+ parallel: 是否并行处理
|
|
|
+ adjust_rows: 是否自适应调整行分割
|
|
|
+ structure_suffix: 结构文件后缀
|
|
|
+ image_suffix: 输出图片后缀
|
|
|
+ progress_callback: 进度回调 callback(index, total)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 处理结果摘要
|
|
|
+ """
|
|
|
+ total = len(file_pairs)
|
|
|
+ results = []
|
|
|
+
|
|
|
+ # 加载模板
|
|
|
+ template_service = TemplateService()
|
|
|
+ logger.info(f"开始批量处理: {total} 个文件, 使用模板: {template_name}, 并行={parallel}")
|
|
|
+
|
|
|
+ if parallel and total > 1:
|
|
|
+ # 并行处理
|
|
|
+ with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
|
|
+ futures = {
|
|
|
+ executor.submit(
|
|
|
+ self._process_single_file,
|
|
|
+ pair, template_name, output_dir, adjust_rows, structure_suffix
|
|
|
+ ): idx
|
|
|
+ for idx, pair in enumerate(file_pairs)
|
|
|
+ }
|
|
|
+
|
|
|
+ for future in as_completed(futures):
|
|
|
+ idx = futures[future]
|
|
|
+ try:
|
|
|
+ result = future.result()
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ if progress_callback:
|
|
|
+ progress_callback(idx, total)
|
|
|
+
|
|
|
+ status = "✅" if result['success'] else "❌"
|
|
|
+ logger.info(f"{status} [{idx+1}/{total}] {result.get('filename', 'unknown')}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 处理失败 [{idx+1}/{total}]: {e}")
|
|
|
+ results.append({
|
|
|
+ 'success': False,
|
|
|
+ 'error': str(e),
|
|
|
+ 'index': idx
|
|
|
+ })
|
|
|
+ else:
|
|
|
+ # 串行处理
|
|
|
+ for idx, pair in enumerate(file_pairs):
|
|
|
+ try:
|
|
|
+ result = self._process_single_file(
|
|
|
+ pair, template_name, output_dir, adjust_rows, structure_suffix
|
|
|
+ )
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ if progress_callback:
|
|
|
+ progress_callback(idx, total)
|
|
|
+
|
|
|
+ status = "✅" if result['success'] else "❌"
|
|
|
+ logger.info(f"{status} [{idx+1}/{total}] {result.get('filename', 'unknown')}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"❌ 处理失败 [{idx+1}/{total}]: {e}")
|
|
|
+ results.append({
|
|
|
+ 'success': False,
|
|
|
+ 'error': str(e),
|
|
|
+ 'index': idx
|
|
|
+ })
|
|
|
+
|
|
|
+ # 统计结果
|
|
|
+ success_count = sum(1 for r in results if r.get('success'))
|
|
|
+ failed_count = total - success_count
|
|
|
+
|
|
|
+ summary = {
|
|
|
+ 'total': total,
|
|
|
+ 'success': success_count,
|
|
|
+ 'failed': failed_count,
|
|
|
+ 'results': results
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info(f"📊 批量处理完成: 成功 {success_count}/{total}, 失败 {failed_count}")
|
|
|
+
|
|
|
+ return summary
|
|
|
+
|
|
|
+ def _process_single_file(
|
|
|
+ self,
|
|
|
+ file_pair: Dict,
|
|
|
+ template_name: str,
|
|
|
+ output_dir: str,
|
|
|
+ adjust_rows: bool,
|
|
|
+ structure_suffix: str = "_structure.json"
|
|
|
+ ) -> Dict:
|
|
|
+ """
|
|
|
+ 处理单个文件
|
|
|
+
|
|
|
+ Args:
|
|
|
+ file_pair: 文件对 {'json_path': ..., 'image_path': ...}
|
|
|
+ template_name: 模板名称(从 TemplateService 加载)
|
|
|
+ output_dir: 输出目录
|
|
|
+ adjust_rows: 是否调整行
|
|
|
+ structure_suffix: 结构文件后缀
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 处理结果
|
|
|
+ """
|
|
|
+ json_path = Path(file_pair['json_path'])
|
|
|
+ image_path = Path(file_pair['image_path'])
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 1. 读取 OCR 数据
|
|
|
+ with open(json_path, 'r', encoding='utf-8') as f:
|
|
|
+ ocr_result = json.load(f)
|
|
|
+
|
|
|
+ # 2. 解析 OCR 数据(获取目标页面的 table_bbox 和 ocr_data)
|
|
|
+ target_table_bbox, ocr_data = OcrParser.parse(ocr_result)
|
|
|
+ target_image_size = ocr_data.get('image_size', {'width': 1, 'height': 1})
|
|
|
+
|
|
|
+ # 3. 使用 TemplateService.preview_apply() 应用模板到目标页面
|
|
|
+ # 这会自动处理坐标映射,适配不同尺寸的图片
|
|
|
+ template_service = TemplateService()
|
|
|
+ applied_template = template_service.preview_apply(
|
|
|
+ template_name=template_name,
|
|
|
+ target_image_size=target_image_size,
|
|
|
+ target_table_bbox=target_table_bbox,
|
|
|
+ mode='relative' # 使用相对坐标映射,适应不同尺寸图片
|
|
|
+ )
|
|
|
+
|
|
|
+ # 4. 构建最终结构(合并应用的模板和目标页面的元数据)
|
|
|
+ if adjust_rows:
|
|
|
+ # 如果启用自适应行,则自动分析目标页面的行结构
|
|
|
+ analyzer = TableAnalyzer(None, ocr_data)
|
|
|
+ analyzed = analyzer.analyze()
|
|
|
+
|
|
|
+ # 关键:使用模板的竖线(已通过 preview_apply 映射),结合目标的横线
|
|
|
+ new_structure = {
|
|
|
+ 'horizontal_lines': analyzed['horizontal_lines'], # 自适应调整
|
|
|
+ 'vertical_lines': applied_template['vertical_lines'], # 来自模板,已映射
|
|
|
+ 'table_bbox': applied_template['table_bbox'], # 目标页面的 bbox
|
|
|
+ 'row_height': analyzed.get('row_height'),
|
|
|
+ 'col_widths': applied_template.get('col_widths'),
|
|
|
+ 'total_rows': analyzed.get('total_rows'),
|
|
|
+ 'total_cols': applied_template.get('total_cols'),
|
|
|
+ 'mode': applied_template.get('mode', 'hybrid'),
|
|
|
+ 'modified_h_lines': [],
|
|
|
+ 'modified_v_lines': [],
|
|
|
+ # 各页使用自己的旋转角度
|
|
|
+ 'image_rotation_angle': ocr_data.get('image_rotation_angle', 0.0),
|
|
|
+ 'skew_angle': ocr_data.get('skew_angle', 0.0),
|
|
|
+ 'is_skew_corrected': ocr_data.get('is_skew_corrected', False)
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ # 完全复用应用的模板(包括竖线和横线)
|
|
|
+ new_structure = applied_template.copy()
|
|
|
+ new_structure['image_rotation_angle'] = ocr_data.get('image_rotation_angle', 0.0)
|
|
|
+ new_structure['skew_angle'] = ocr_data.get('skew_angle', 0.0)
|
|
|
+ new_structure['is_skew_corrected'] = ocr_data.get('is_skew_corrected', False)
|
|
|
+
|
|
|
+ # 5. 保存结构文件
|
|
|
+ output_path = Path(output_dir)
|
|
|
+ output_path.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ structure_filename = f"{image_path.stem}{structure_suffix}"
|
|
|
+ structure_path = output_path / structure_filename
|
|
|
+
|
|
|
+ # 准备保存的数据(移除不需要的字段)
|
|
|
+ structure_to_save = new_structure.copy()
|
|
|
+ for key in ['rows', 'columns']:
|
|
|
+ structure_to_save.pop(key, None)
|
|
|
+
|
|
|
+ with open(structure_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(structure_to_save, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'success': True,
|
|
|
+ 'json_path': str(json_path),
|
|
|
+ 'image_path': str(image_path),
|
|
|
+ 'structure_path': str(structure_path),
|
|
|
+ 'filename': image_path.name,
|
|
|
+ 'rows': new_structure.get('total_rows', 0),
|
|
|
+ 'cols': new_structure.get('total_cols', 0)
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception(f"处理文件失败: {json_path}")
|
|
|
+ return {
|
|
|
+ 'success': False,
|
|
|
+ 'json_path': str(json_path),
|
|
|
+ 'image_path': str(image_path),
|
|
|
+ 'error': str(e),
|
|
|
+ 'filename': image_path.name if image_path else 'unknown'
|
|
|
+ }
|
|
|
+
|
|
|
+ def draw_batch_images(
|
|
|
+ self,
|
|
|
+ results: List[Dict],
|
|
|
+ line_width: int = 2,
|
|
|
+ line_color: Tuple[int, int, int] = (0, 0, 0)
|
|
|
+ ) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 批量绘制表格线到图片上
|
|
|
+
|
|
|
+ Args:
|
|
|
+ results: process_batch_from_data_source 的返回结果中的 results 列表
|
|
|
+ line_width: 线条宽度
|
|
|
+ line_color: 线条颜色 RGB
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 绘制结果列表
|
|
|
+ """
|
|
|
+ draw_results = []
|
|
|
+
|
|
|
+ for result in results:
|
|
|
+ if not result.get('success'):
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ image_path = Path(result['image_path'])
|
|
|
+ structure_path = Path(result['structure_path'])
|
|
|
+
|
|
|
+ # 读取图片
|
|
|
+ image = Image.open(image_path)
|
|
|
+ if image.mode != 'RGB':
|
|
|
+ image = image.convert('RGB')
|
|
|
+
|
|
|
+ # 读取结构
|
|
|
+ with open(structure_path, 'r', encoding='utf-8') as f:
|
|
|
+ structure = json.load(f)
|
|
|
+
|
|
|
+ # 绘制线条
|
|
|
+ image_with_lines = DrawingService.draw_clean_lines(
|
|
|
+ image, structure, line_width=line_width, line_color=line_color
|
|
|
+ )
|
|
|
+
|
|
|
+ # 保存
|
|
|
+ output_path = structure_path.parent / f"{image_path.stem}.png"
|
|
|
+ image_with_lines.save(str(output_path), 'PNG')
|
|
|
+
|
|
|
+ draw_results.append({
|
|
|
+ 'success': True,
|
|
|
+ 'image_path': str(output_path),
|
|
|
+ 'filename': image_path.name
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"绘制图片失败 {result.get('filename')}: {e}")
|
|
|
+ draw_results.append({
|
|
|
+ 'success': False,
|
|
|
+ 'error': str(e),
|
|
|
+ 'filename': result.get('filename')
|
|
|
+ })
|
|
|
+
|
|
|
+ success_count = sum(1 for r in draw_results if r.get('success'))
|
|
|
+ logger.info(f"🖼️ 绘制完成: {success_count}/{len(results)} 张图片")
|
|
|
+
|
|
|
+ return draw_results
|