from typing import Dict, List, Any, Optional, Union
from pathlib import Path
from io import BytesIO

import numpy as np
from PIL import Image
import fitz  # PyMuPDF
from loguru import logger

from .model_factory import ModelFactory
from .config_manager import ConfigManager
from models.adapters import BaseAdapter


class FinancialDocPipeline:
    """Unified processing pipeline for financial documents."""

    def __init__(self, config_path: str):
        self.config = ConfigManager.load_config(config_path)
        self.scene_name = self.config.get('scene_name', 'unknown')

        # Initialize all processing components
        self._init_components()

    def _init_components(self):
        """Initialize the processing components."""
        try:
            # 1. Preprocessor (orientation classification, image rectification, etc.)
            self.preprocessor = ModelFactory.create_preprocessor(
                self.config['preprocessor']
            )

            # 2. Layout detector
            self.layout_detector = ModelFactory.create_layout_detector(
                self.config['layout_detection']
            )

            # 3. VL recognizer (tables, formulas, etc.)
            self.vl_recognizer = ModelFactory.create_vl_recognizer(
                self.config['vl_recognition']
            )

            # 4. OCR recognizer
            self.ocr_recognizer = ModelFactory.create_ocr_recognizer(
                self.config['ocr_recognition']
            )

            logger.info(f"✅ Initialized pipeline for scene: {self.scene_name}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize pipeline components: {e}")
            raise

    def process_document(self, document_path: str) -> Dict[str, Any]:
        """
        Main entry point for processing a document.

        Args:
            document_path: Path to the document.

        Returns:
            Processing results, including coordinates and content for every element.
        """
        results = {
            'scene': self.scene_name,
            'document_path': document_path,
            'pages': [],
            'metadata': self._extract_metadata(document_path)
        }

        try:
            # Load the document pages as images
            images = self._load_document_images(document_path)
            logger.info(f"📄 Loaded {len(images)} pages from document")

            for page_idx, image in enumerate(images):
                logger.info(f"🔍 Processing page {page_idx + 1}/{len(images)}")
                page_result = self._process_single_page(image, page_idx)
                results['pages'].append(page_result)

        except Exception as e:
            logger.error(f"❌ Failed to process document: {e}")
            raise

        return results

    def _load_document_images(self, document_path: str) -> List[np.ndarray]:
        """Load the document as a list of page images."""
        document_path = Path(document_path)
        if not document_path.exists():
            raise FileNotFoundError(f"Document not found: {document_path}")

        images = []

        if document_path.suffix.lower() == '.pdf':
            # Handle PDF files
            doc = fitz.open(document_path)
            try:
                for page_num in range(len(doc)):
                    page = doc.load_page(page_num)
                    # Render at the configured DPI
                    dpi = self.config.get('input', {}).get('dpi', 200)
                    mat = fitz.Matrix(dpi / 72, dpi / 72)
                    pix = page.get_pixmap(matrix=mat)
                    img_data = pix.tobytes("ppm")

                    # Convert to a numpy array
                    img = Image.open(BytesIO(img_data))
                    img_array = np.array(img)
                    images.append(img_array)
            finally:
                doc.close()

        elif document_path.suffix.lower() in ['.png', '.jpg', '.jpeg', '.bmp', '.tiff']:
            # Handle image files
            img = Image.open(document_path)
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img_array = np.array(img)
            images.append(img_array)

        else:
            raise ValueError(f"Unsupported file format: {document_path.suffix}")

        return images

    def _extract_metadata(self, document_path: str) -> Dict[str, Any]:
        """Extract document metadata."""
        document_path = Path(document_path)
        metadata = {
            'filename': document_path.name,
            'size': document_path.stat().st_size,
            'format': document_path.suffix.lower()
        }

        # For PDFs, pull additional metadata from the file
        if document_path.suffix.lower() == '.pdf':
            try:
                doc = fitz.open(document_path)
                metadata.update({
                    'page_count': len(doc),
                    'title': doc.metadata.get('title', ''),
                    'author': doc.metadata.get('author', ''),
                    'subject': doc.metadata.get('subject', ''),
                    'creator': doc.metadata.get('creator', '')
                })
                doc.close()
            except Exception:
                pass

        return metadata

    def _process_single_page(self, image: np.ndarray, page_idx: int) -> Dict[str, Any]:
        """Process a single page of the document."""
        # 1. Preprocessing (orientation correction, etc.)
        try:
            preprocessed_image, rotate_angle = self.preprocessor.process(image)
        except Exception as e:
            logger.warning(f"⚠️ Preprocessing failed for page {page_idx}: {e}")
            preprocessed_image = image
            rotate_angle = 0  # fall back to "no rotation" so the result dict stays valid

        # 2. Layout detection
        try:
            layout_results = self.layout_detector.detect(preprocessed_image)
            logger.info(f"📋 Detected {len(layout_results)} layout elements on page {page_idx}")
        except Exception as e:
            logger.error(f"❌ Layout detection failed for page {page_idx}: {e}")
            layout_results = []

        # 3. Dispatch each element to the appropriate recognizer for this scene
        page_elements = []

        for layout_item in layout_results:
            element_type = layout_item.get('category', 'unknown')
            try:
                if element_type in ['table_body', 'table']:
                    # Tables go through the VL model
                    element_result = self._process_table_element(
                        preprocessed_image, layout_item
                    )
                elif element_type in ['text', 'title', 'ocr_text']:
                    # Text goes through OCR
                    element_result = self._process_text_element(
                        preprocessed_image, layout_item
                    )
                elif element_type in ['interline_equation', 'inline_equation']:
                    # Formulas go through the VL model
                    element_result = self._process_formula_element(
                        preprocessed_image, layout_item
                    )
                else:
                    # Other elements pass through unchanged
                    element_result = layout_item.copy()
                    element_result['type'] = element_type

                page_elements.append(element_result)

            except Exception as e:
                logger.warning(f"⚠️ Failed to process element {element_type}: {e}")
                # Keep the failed element, marked as an error
                error_element = layout_item.copy()
                error_element['type'] = 'error'
                error_element['error'] = str(e)
                page_elements.append(error_element)

        return {
            'page_idx': page_idx,
            'elements': page_elements,
            'layout_raw': layout_results,
            'image_shape': preprocessed_image.shape,
            'processed_image': preprocessed_image,
            'angle': rotate_angle
        }

    def _process_table_element(self, image: np.ndarray,
                               layout_item: Dict[str, Any]) -> Dict[str, Any]:
        """Process a table element."""
        try:
            # Crop the table region
            cropped_table = self._crop_region(image, layout_item['bbox'])

            # Recognize the table with the VL model
            table_result = self.vl_recognizer.recognize_table(
                cropped_table,
                return_cells_coordinate=True  # key: return per-cell coordinates
            )

            # Convert cell coordinates back to the full-page coordinate system
            if 'cells' in table_result:
                for cell in table_result['cells']:
                    cell['absolute_bbox'] = self._convert_to_absolute_coords(
                        cell['bbox'], layout_item['bbox']
                    )

            result = {
                'type': 'table',
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': table_result,
                'scene_specific': self._add_scene_specific_info(table_result)
            }

            logger.info(f"✅ Table processed with {len(table_result.get('cells', []))} cells")
            return result

        except Exception as e:
            logger.error(f"❌ Table processing failed: {e}")
            return {
                'type': 'table',
                'bbox': layout_item['bbox'],
                'content': {'html': '', 'markdown': '', 'cells': []},
                'error': str(e)
            }

    def _process_text_element(self, image: np.ndarray,
                              layout_item: Dict[str, Any]) -> Dict[str, Any]:
        """Process a text element."""
        try:
            # Crop the text region
            cropped_text = self._crop_region(image, layout_item['bbox'])

            # Recognize the text with OCR
            text_results = self.ocr_recognizer.recognize_text(cropped_text)

            # Merge the recognized fragments, keeping only confident ones
            combined_text = ""
            if text_results:
                text_parts = [item['text'] for item in text_results if item['confidence'] > 0.5]
                combined_text = " ".join(text_parts)

            result = {
                'type': layout_item['category'],
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': {
                    'text': combined_text,
                    'ocr_details': text_results
                }
            }

            logger.info(f"✅ Text processed: '{combined_text[:50]}...'")
            return result

        except Exception as e:
            logger.error(f"❌ Text processing failed: {e}")
            return {
                'type': layout_item['category'],
                'bbox': layout_item['bbox'],
                'content': {'text': '', 'ocr_details': []},
                'error': str(e)
            }

    def _process_formula_element(self, image: np.ndarray,
                                 layout_item: Dict[str, Any]) -> Dict[str, Any]:
        """Process a formula element."""
        try:
            # Crop the formula region
            cropped_formula = self._crop_region(image, layout_item['bbox'])

            # Recognize the formula with the VL model
            formula_result = self.vl_recognizer.recognize_formula(cropped_formula)

            result = {
                'type': 'formula',
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': formula_result
            }

            logger.info(f"✅ Formula processed: {formula_result.get('latex', '')[:50]}...")
            return result

        except Exception as e:
            logger.error(f"❌ Formula processing failed: {e}")
            return {
                'type': 'formula',
                'bbox': layout_item['bbox'],
                'content': {'latex': '', 'confidence': 0.0},
                'error': str(e)
            }

    def _crop_region(self, image: np.ndarray, bbox: List[float]) -> np.ndarray:
        """Crop a region from the image."""
        if len(bbox) < 4:
            return image

        x1, y1, x2, y2 = map(int, bbox)

        # Clamp the box to the image bounds
        h, w = image.shape[:2]
        x1 = max(0, min(x1, w))
        y1 = max(0, min(y1, h))
        x2 = max(x1, min(x2, w))
        y2 = max(y1, min(y2, h))

        return image[y1:y2, x1:x2]

    def _convert_to_absolute_coords(self, relative_bbox: List[float],
                                    region_bbox: List[float]) -> List[float]:
        """Convert region-relative coordinates to absolute page coordinates."""
        if len(relative_bbox) < 4 or len(region_bbox) < 4:
            return relative_bbox

        rx1, ry1, rx2, ry2 = relative_bbox
        bx1, by1, bx2, by2 = region_bbox

        # Offset the relative box by the region's top-left corner
        abs_x1 = bx1 + rx1
        abs_y1 = by1 + ry1
        abs_x2 = bx1 + rx2
        abs_y2 = by1 + ry2

        return [abs_x1, abs_y1, abs_x2, abs_y2]

    def _add_scene_specific_info(self, content: Dict[str, Any]) -> Dict[str, Any]:
        """Attach scene-specific information based on the configured scene."""
        if self.scene_name == 'bank_statement':
            return self._process_bank_statement_table(content)
        elif self.scene_name == 'financial_report':
            return self._process_financial_report_table(content)
        return {}

    def _process_bank_statement_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
        """Scene-specific handling for bank statement tables."""
        scene_info = {
            'table_type': 'bank_statement',
            'expected_columns': ['日期', '摘要', '收入', '支出', '余额'],
            'validation_rules': {
                'amount_format': True,
                'date_format': True,
                'balance_consistency': True
            }
        }

        # Bank-statement-specific validation and post-processing
        if 'html' in content and content['html']:
            # Bank-statement-specific HTML post-processing can be added here
            pass

        return scene_info

    def _process_financial_report_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
        """Scene-specific handling for financial report tables."""
        scene_info = {
            'table_type': 'financial_report',
            'complex_headers': True,
            'merged_cells': True,
            'validation_rules': {
                'accounting_format': True,
                'sum_validation': True
            }
        }

        # Financial-report-specific validation and post-processing
        if 'html' in content and content['html']:
            # Financial-report-specific HTML post-processing can be added here
            pass

        return scene_info

    def cleanup(self):
        """Release model resources."""
        try:
            if hasattr(self, 'preprocessor'):
                self.preprocessor.cleanup()
            if hasattr(self, 'layout_detector'):
                self.layout_detector.cleanup()
            if hasattr(self, 'vl_recognizer'):
                self.vl_recognizer.cleanup()
            if hasattr(self, 'ocr_recognizer'):
                self.ocr_recognizer.cleanup()

            logger.info("✅ Pipeline cleanup completed")
        except Exception as e:
            logger.warning(f"⚠️ Cleanup failed: {e}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()
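

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of driving the pipeline end to end. The config path
# "configs/bank_statement.yaml" and the sample document name below are
# hypothetical placeholders; any scene config accepted by
# ConfigManager.load_config and any supported PDF/image file will do.
if __name__ == "__main__":
    import json

    # The context manager guarantees cleanup() runs even if processing fails.
    with FinancialDocPipeline("configs/bank_statement.yaml") as pipeline:
        doc_results = pipeline.process_document("sample_statement.pdf")

        # Page results carry the processed images as numpy arrays, which are
        # not JSON-serializable; drop them before printing a light summary.
        for page in doc_results["pages"]:
            page.pop("processed_image", None)

        print(json.dumps(doc_results["metadata"], ensure_ascii=False, indent=2))
        print(f"Processed {len(doc_results['pages'])} page(s) "
              f"for scene '{doc_results['scene']}'")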