| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400 |
- """
- 批量处理服务
- 基于首页模板应用到多个文件,适用于多页银行流水等场景
- """
- import json
- import sys
- from pathlib import Path
- from typing import List, Dict, Optional, Tuple, Any, Callable
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from loguru import logger
- from PIL import Image
- import io
- # 添加 ocr_platform 根目录到 Python 路径(用于导入 core 和 ocr_utils)
- _file_path = Path(__file__).resolve()
- ocr_platform_root = _file_path.parents[2] # batch_service.py -> services -> backend -> table_line_generator -> ocr_platform
- if str(ocr_platform_root) not in sys.path:
- sys.path.insert(0, str(ocr_platform_root))
- from table_line_generator.core.table_analyzer import TableAnalyzer
- from table_line_generator.core.ocr_parser import OcrParser
- from table_line_generator.core.drawing_service import DrawingService
- from table_line_generator.backend.services.template_service import TemplateService
class BatchProcessor:
    """Batch table-structure processor.

    Workflow:
        1. Learn a template from the first page (mainly the vertical lines,
           i.e. the column boundaries).
        2. Apply the template to every page: vertical lines are reused while
           horizontal lines can be re-detected per page.
        3. Optionally process pages in parallel with a thread pool.
    """

    def __init__(self, max_workers: int = 4):
        """Initialise the batch processor.

        Args:
            max_workers: Maximum number of threads used for parallel
                processing.
        """
        self.max_workers = max_workers

    @staticmethod
    def _page_orientation(ocr_data: Dict) -> Dict:
        """Return the per-page rotation/skew fields read from ``ocr_data``.

        Every page keeps its own angles, so these values always come from the
        target page and never from the template. Missing keys fall back to
        ``0.0`` / ``False``.
        """
        return {
            'image_rotation_angle': ocr_data.get('image_rotation_angle', 0.0),
            'skew_angle': ocr_data.get('skew_angle', 0.0),
            'is_skew_corrected': ocr_data.get('is_skew_corrected', False),
        }

    def learn_template_from_structure(
        self,
        first_page_structure: Dict
    ) -> Dict:
        """Learn a reusable template from the first page's structure.

        Args:
            first_page_structure: Table structure of the first page.

        Returns:
            Template dict holding the reusable parts (``vertical_lines``,
            ``table_bbox``, ``col_widths``, ``total_cols``, ``mode``).
            ``mode`` defaults to ``'hybrid'`` when absent.
        """
        template = {
            # `or []` also covers an explicit None stored under the key, which
            # would otherwise fail on the copy.
            'vertical_lines': list(first_page_structure.get('vertical_lines') or []),
            'table_bbox': first_page_structure.get('table_bbox'),
            'col_widths': first_page_structure.get('col_widths'),
            'total_cols': first_page_structure.get('total_cols'),
            'mode': first_page_structure.get('mode', 'hybrid')
        }

        logger.info(f"从首页学习模板: {len(template['vertical_lines'])} 条竖线, {template['total_cols']} 列")

        return template

    def apply_template_to_structure(
        self,
        template: Dict,
        target_ocr_data: Dict,
        adjust_rows: bool = True,
        y_tolerance: int = 5,
        min_row_height: int = 20
    ) -> Dict:
        """Apply a learned template to the OCR data of a target page.

        Args:
            template: Template learned from the first page.
            target_ocr_data: OCR data of the target page.
            adjust_rows: When True, rows are re-analysed on the target page;
                when False, the template is reused as-is.
            y_tolerance: Y-axis clustering tolerance passed to the analyzer.
            min_row_height: Minimum row height passed to the analyzer.

        Returns:
            The structure for the target page. The input ``template`` is not
            modified.

        Raises:
            KeyError: If ``adjust_rows`` is True and the template has no
                ``'vertical_lines'`` entry.
        """
        orientation = self._page_orientation(target_ocr_data)

        if not adjust_rows:
            # Reuse the template wholesale; only the per-page angles differ.
            # No analyzer is needed on this path, so none is constructed.
            new_structure = dict(template)
            new_structure.update(orientation)
            return new_structure

        # Re-analyse the row structure of the target page (adaptive rows).
        analyzer = TableAnalyzer(None, target_ocr_data)
        analyzed = analyzer.analyze(
            y_tolerance=y_tolerance,
            min_row_height=min_row_height,
            method=template.get('mode', 'auto')
        )

        # Column information comes from the template, row information from
        # the fresh analysis.
        return {
            'horizontal_lines': analyzed['horizontal_lines'],
            'vertical_lines': list(template['vertical_lines']),
            'table_bbox': template.get('table_bbox') or analyzed.get('table_bbox'),
            'row_height': analyzed.get('row_height'),
            'col_widths': template.get('col_widths'),
            'total_rows': analyzed.get('total_rows'),
            'total_cols': template.get('total_cols'),
            'mode': template.get('mode'),
            'modified_h_lines': [],
            'modified_v_lines': [],
            **orientation
        }

    def _collect_result(
        self,
        idx: int,
        total: int,
        produce: Callable[[], Dict],
        progress_callback: Optional[Callable[[int, int], None]]
    ) -> Dict:
        """Obtain one file's result, log it and report progress.

        ``produce`` is called to get the result dict; any exception it raises
        is converted into a failure entry. The progress callback runs outside
        that ``try`` so a faulty callback cannot turn a processed file into a
        second, spurious failure entry.
        """
        try:
            result = produce()
        except Exception as e:
            logger.error(f"❌ 处理失败 [{idx+1}/{total}]: {e}")
            result = {
                'success': False,
                'error': str(e),
                'index': idx
            }
        else:
            status = "✅" if result.get('success') else "❌"
            logger.info(f"{status} [{idx+1}/{total}] {result.get('filename', 'unknown')}")

        if progress_callback:
            try:
                progress_callback(idx, total)
            except Exception as e:
                # Progress reporting is best-effort; never fail the batch on it.
                logger.warning(f"progress_callback failed [{idx+1}/{total}]: {e}")

        return result

    def process_batch_from_data_source(
        self,
        template_name: str,
        file_pairs: List[Dict],
        output_dir: str,
        parallel: bool = True,
        adjust_rows: bool = True,
        structure_suffix: str = "_structure.json",
        image_suffix: str = "_with_lines.png",
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> Dict:
        """Process every file pair of a data source with a named template.

        Args:
            template_name: Name of the template (resolved through
                ``TemplateService`` for each file).
            file_pairs: List of ``{'json_path': ..., 'image_path': ...}``.
            output_dir: Output directory for the structure files.
            parallel: Use the thread pool when more than one file is given.
            adjust_rows: Whether rows are re-detected per page.
            structure_suffix: Suffix of the written structure files.
            image_suffix: Suffix for output images. Accepted for interface
                compatibility; it is not used by this method.
            progress_callback: ``callback(index, total)`` invoked once per
                finished file, where ``index`` is the zero-based position of
                that file in ``file_pairs``. In parallel mode the calls arrive
                in completion order.

        Returns:
            Summary dict with ``total``, ``success``, ``failed`` and
            ``results``; ``results`` is ordered like ``file_pairs``.
        """
        total = len(file_pairs)
        # One slot per input so the output order matches the input order even
        # when files finish out of order.
        results: List[Any] = [None] * total

        logger.info(f"开始批量处理: {total} 个文件, 使用模板: {template_name}, 并行={parallel}")

        if parallel and total > 1:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = {
                    executor.submit(
                        self._process_single_file,
                        pair, template_name, output_dir, adjust_rows, structure_suffix
                    ): idx
                    for idx, pair in enumerate(file_pairs)
                }

                for future in as_completed(futures):
                    idx = futures[future]
                    results[idx] = self._collect_result(
                        idx, total, future.result, progress_callback
                    )
        else:
            for idx, pair in enumerate(file_pairs):
                results[idx] = self._collect_result(
                    idx,
                    total,
                    lambda pair=pair: self._process_single_file(
                        pair, template_name, output_dir, adjust_rows, structure_suffix
                    ),
                    progress_callback
                )

        success_count = sum(1 for r in results if r.get('success'))
        failed_count = total - success_count

        summary = {
            'total': total,
            'success': success_count,
            'failed': failed_count,
            'results': results
        }

        logger.info(f"📊 批量处理完成: 成功 {success_count}/{total}, 失败 {failed_count}")

        return summary

    def _process_single_file(
        self,
        file_pair: Dict,
        template_name: str,
        output_dir: str,
        adjust_rows: bool,
        structure_suffix: str = "_structure.json"
    ) -> Dict:
        """Process one file pair and write its structure JSON.

        Args:
            file_pair: ``{'json_path': ..., 'image_path': ...}``.
            template_name: Name of the template (resolved through
                ``TemplateService``).
            output_dir: Output directory; created if missing.
            adjust_rows: Whether rows are re-detected on this page.
            structure_suffix: Suffix appended to the image stem to form the
                structure file name.

        Returns:
            Result dict. ``success`` is True with ``structure_path``, ``rows``
            and ``cols`` on success; otherwise False with ``error``.

        Raises:
            KeyError: If ``file_pair`` lacks ``'json_path'`` or
                ``'image_path'`` (raised before the guarded section).
        """
        json_path = Path(file_pair['json_path'])
        image_path = Path(file_pair['image_path'])

        try:
            # 1. Read the OCR result.
            with open(json_path, 'r', encoding='utf-8') as f:
                ocr_result = json.load(f)

            # 2. Parse it into the target page's table bbox and OCR data.
            target_table_bbox, ocr_data = OcrParser.parse(ocr_result)
            target_image_size = ocr_data.get('image_size', {'width': 1, 'height': 1})

            # 3. Apply the template to the target page. Per the original
            #    author's note, 'relative' mode maps coordinates so pages of
            #    different sizes are handled — behaviour lives in
            #    TemplateService.preview_apply, not visible here.
            template_service = TemplateService()
            applied_template = template_service.preview_apply(
                template_name=template_name,
                target_image_size=target_image_size,
                target_table_bbox=target_table_bbox,
                mode='relative'
            )

            # 4. Merge the applied template with the page's own metadata.
            orientation = self._page_orientation(ocr_data)
            if adjust_rows:
                # Rows come from a fresh analysis of this page; columns come
                # from the (already mapped) template.
                analyzer = TableAnalyzer(None, ocr_data)
                analyzed = analyzer.analyze()

                new_structure = {
                    'horizontal_lines': analyzed['horizontal_lines'],
                    'vertical_lines': applied_template['vertical_lines'],
                    'table_bbox': applied_template['table_bbox'],
                    'row_height': analyzed.get('row_height'),
                    'col_widths': applied_template.get('col_widths'),
                    'total_rows': analyzed.get('total_rows'),
                    'total_cols': applied_template.get('total_cols'),
                    'mode': applied_template.get('mode', 'hybrid'),
                    'modified_h_lines': [],
                    'modified_v_lines': [],
                    **orientation
                }
            else:
                # Reuse the applied template as-is, with this page's angles.
                new_structure = dict(applied_template)
                new_structure.update(orientation)

            # 5. Save the structure file.
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)

            structure_filename = f"{image_path.stem}{structure_suffix}"
            structure_path = output_path / structure_filename

            # 'rows' / 'columns' are dropped from the persisted structure.
            structure_to_save = {
                key: value
                for key, value in new_structure.items()
                if key not in ('rows', 'columns')
            }

            with open(structure_path, 'w', encoding='utf-8') as f:
                json.dump(structure_to_save, f, ensure_ascii=False, indent=2)

            return {
                'success': True,
                'json_path': str(json_path),
                'image_path': str(image_path),
                'structure_path': str(structure_path),
                'filename': image_path.name,
                'rows': new_structure.get('total_rows', 0),
                'cols': new_structure.get('total_cols', 0)
            }

        except Exception as e:
            logger.exception(f"处理文件失败: {json_path}")
            return {
                'success': False,
                'json_path': str(json_path),
                'image_path': str(image_path),
                'error': str(e),
                'filename': image_path.name
            }

    def draw_batch_images(
        self,
        results: List[Dict],
        line_width: int = 2,
        line_color: Tuple[int, int, int] = (0, 0, 0)
    ) -> List[Dict]:
        """Draw the table lines of each successful result onto its image.

        Args:
            results: The ``results`` list returned by
                ``process_batch_from_data_source``. Entries whose ``success``
                is falsy are skipped.
            line_width: Line width.
            line_color: Line colour as an RGB tuple.

        Returns:
            One entry per attempted drawing, with ``success`` and either the
            written ``image_path`` or an ``error``.
        """
        draw_results = []

        for result in results:
            if not result.get('success'):
                continue

            try:
                image_path = Path(result['image_path'])
                structure_path = Path(result['structure_path'])

                # The context manager guarantees the source file handle is
                # released even when drawing or saving fails.
                with Image.open(image_path) as source:
                    image = source if source.mode == 'RGB' else source.convert('RGB')

                    with open(structure_path, 'r', encoding='utf-8') as f:
                        structure = json.load(f)

                    image_with_lines = DrawingService.draw_clean_lines(
                        image, structure, line_width=line_width, line_color=line_color
                    )

                    # NOTE(review): the output lands next to the structure file
                    # as "<stem>.png"; if the source image is a PNG in that
                    # same directory it would be overwritten — confirm intended.
                    output_path = structure_path.parent / f"{image_path.stem}.png"
                    image_with_lines.save(str(output_path), 'PNG')

                draw_results.append({
                    'success': True,
                    'image_path': str(output_path),
                    'filename': image_path.name
                })

            except Exception as e:
                logger.error(f"绘制图片失败 {result.get('filename')}: {e}")
                draw_results.append({
                    'success': False,
                    'error': str(e),
                    'filename': result.get('filename')
                })

        success_count = sum(1 for r in draw_results if r.get('success'))
        logger.info(f"🖼️ 绘制完成: {success_count}/{len(results)} 张图片")

        return draw_results
|