| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419 |
- from typing import Dict, List, Any, Optional, Union
- from pathlib import Path
- import numpy as np
- from PIL import Image
- import fitz # PyMuPDF
- from loguru import logger
- from .model_factory import ModelFactory
- from .config_manager import ConfigManager
- from models.adapters import BaseAdapter
class FinancialDocPipeline:
    """Unified processing pipeline for financial documents.

    The pipeline loads a document (PDF or image), runs a preprocessor and a
    layout detector on every page, and then routes each detected layout
    element to the recognizer that matches its category (vision-language
    model for tables and formulas, OCR for text).

    The instance can be used as a context manager; ``cleanup()`` is called
    on exit.
    """

    def __init__(self, config_path: str):
        """Load the configuration at *config_path* and build all components."""
        self.config = ConfigManager.load_config(config_path)
        self.scene_name = self.config.get('scene_name', 'unknown')

        # Build the individual processing components.
        self._init_components()

    def _init_components(self):
        """Create the preprocessor, layout detector and recognizers.

        Raises:
            Exception: Any error raised while creating a component is logged
                and re-raised unchanged.
        """
        try:
            # 1. Preprocessor (orientation classification, image
            #    rectification, etc.)
            self.preprocessor = ModelFactory.create_preprocessor(
                self.config['preprocessor']
            )

            # 2. Layout detector
            self.layout_detector = ModelFactory.create_layout_detector(
                self.config['layout_detection']
            )

            # 3. Vision-language recognizer (tables, formulas, etc.)
            self.vl_recognizer = ModelFactory.create_vl_recognizer(
                self.config['vl_recognition']
            )

            # 4. OCR recognizer
            self.ocr_recognizer = ModelFactory.create_ocr_recognizer(
                self.config['ocr_recognition']
            )

            logger.info(f"✅ Initialized pipeline for scene: {self.scene_name}")

        except Exception as e:
            logger.error(f"❌ Failed to initialize pipeline components: {e}")
            raise

    def process_document(self, document_path: str) -> Dict[str, Any]:
        """Run the full pipeline on one document.

        Args:
            document_path: Path to the document (PDF or image file).

        Returns:
            A dict with the keys ``'scene'``, ``'document_path'``,
            ``'pages'`` (one result dict per page, in order) and
            ``'metadata'``.

        Raises:
            Exception: Any error raised while loading or processing the
                document is logged and re-raised unchanged.
        """
        results = {
            'scene': self.scene_name,
            'document_path': document_path,
            'pages': [],
            'metadata': self._extract_metadata(document_path)
        }

        try:
            # Render / load every page as an image array.
            images = self._load_document_images(document_path)
            logger.info(f"📄 Loaded {len(images)} pages from document")

            for page_idx, image in enumerate(images):
                logger.info(f"🔍 Processing page {page_idx + 1}/{len(images)}")
                page_result = self._process_single_page(image, page_idx)
                results['pages'].append(page_result)

        except Exception as e:
            logger.error(f"❌ Failed to process document: {e}")
            raise

        return results

    def _load_document_images(self, document_path: str) -> List[np.ndarray]:
        """Load a document as a list of page images.

        PDFs are rendered page by page at the DPI configured under
        ``config['input']['dpi']`` (default 200); image files produce a
        single RGB page.

        Args:
            document_path: Path to a PDF or image file.

        Returns:
            One numpy array per page.

        Raises:
            FileNotFoundError: If *document_path* does not exist.
            ValueError: If the file extension is not supported.
        """
        document_path = Path(document_path)

        if not document_path.exists():
            raise FileNotFoundError(f"Document not found: {document_path}")

        suffix = document_path.suffix.lower()
        images = []

        if suffix == '.pdf':
            from io import BytesIO

            # The zoom matrix is the same for every page, so build it once.
            # PDF user space is 72 units per inch, hence dpi / 72.
            dpi = self.config.get('input', {}).get('dpi', 200)
            mat = fitz.Matrix(dpi / 72, dpi / 72)

            doc = fitz.open(document_path)
            try:
                for page_num in range(len(doc)):
                    page = doc.load_page(page_num)
                    pix = page.get_pixmap(matrix=mat)
                    img_data = pix.tobytes("ppm")

                    # Decode the rendered page into a numpy array.
                    img = Image.open(BytesIO(img_data))
                    images.append(np.array(img))
            finally:
                doc.close()

        elif suffix in ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'):
            # The context manager closes the underlying file handle once
            # the pixel data has been copied into the array.
            with Image.open(document_path) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                images.append(np.array(rgb))

        else:
            raise ValueError(f"Unsupported file format: {document_path.suffix}")

        return images

    def _extract_metadata(self, document_path: str) -> Dict[str, Any]:
        """Collect basic file metadata, plus PDF metadata when available.

        Args:
            document_path: Path to the document.

        Returns:
            A dict with ``'filename'``, ``'size'`` and ``'format'``; for
            PDFs that can be opened it also carries ``'page_count'``,
            ``'title'``, ``'author'``, ``'subject'`` and ``'creator'``.

        Raises:
            OSError: If the file cannot be stat'ed (for example, it does
                not exist).
        """
        document_path = Path(document_path)

        metadata = {
            'filename': document_path.name,
            'size': document_path.stat().st_size,
            'format': document_path.suffix.lower()
        }

        # PDFs carry additional document-level metadata.
        if metadata['format'] == '.pdf':
            # Best effort: a failure here must not abort processing, but it
            # is logged instead of being silently swallowed.
            try:
                doc = fitz.open(document_path)
                try:
                    info = doc.metadata or {}
                    metadata.update({
                        'page_count': len(doc),
                        'title': info.get('title', ''),
                        'author': info.get('author', ''),
                        'subject': info.get('subject', ''),
                        'creator': info.get('creator', '')
                    })
                finally:
                    doc.close()
            except Exception as e:
                logger.warning(f"⚠️ Failed to extract PDF metadata: {e}")

        return metadata

    def _process_single_page(self, image: np.ndarray, page_idx: int) -> Dict[str, Any]:
        """Process one page image.

        Args:
            image: Page image as a numpy array.
            page_idx: Zero-based page index (used in logs and the result).

        Returns:
            A dict with ``'page_idx'``, ``'elements'``, ``'layout_raw'``,
            ``'image_shape'``, ``'processed_image'`` and ``'angle'``.
            When preprocessing fails the original image is used and
            ``'angle'`` is 0.
        """
        # 1. Preprocessing (orientation correction, etc.). On failure fall
        #    back to the untouched image, which by definition was rotated
        #    by 0 degrees.
        rotate_angle = 0
        try:
            preprocessed_image, rotate_angle = self.preprocessor.process(image)
        except Exception as e:
            logger.warning(f"⚠️ Preprocessing failed for page {page_idx}: {e}")
            preprocessed_image = image

        # 2. Layout detection
        try:
            layout_results = self.layout_detector.detect(preprocessed_image)
            logger.info(f"📋 Detected {len(layout_results)} layout elements on page {page_idx}")
        except Exception as e:
            logger.error(f"❌ Layout detection failed for page {page_idx}: {e}")
            layout_results = []

        # 3. Route every element to the handler for its category.
        page_elements = []

        for layout_item in layout_results:
            # Bound before the try so the except handler can always log it,
            # even when reading the category itself is what fails.
            element_type = 'unknown'
            try:
                element_type = layout_item['category']

                if element_type in ['table_body', 'table']:
                    # Tables are handled by the vision-language model.
                    element_result = self._process_table_element(
                        preprocessed_image, layout_item
                    )
                elif element_type in ['text', 'title', 'ocr_text']:
                    # Text is handled by OCR.
                    element_result = self._process_text_element(
                        preprocessed_image, layout_item
                    )
                elif element_type in ['interline_equation', 'inline_equation']:
                    # Formulas are handled by the vision-language model.
                    element_result = self._process_formula_element(
                        preprocessed_image, layout_item
                    )
                else:
                    # Any other element is passed through unchanged.
                    element_result = layout_item.copy()
                    element_result['type'] = element_type

                page_elements.append(element_result)

            except Exception as e:
                logger.warning(f"⚠️ Failed to process element {element_type}: {e}")
                # Keep the failed element in the output, flagged as an error.
                error_element = layout_item.copy()
                error_element['type'] = 'error'
                error_element['error'] = str(e)
                page_elements.append(error_element)

        return {
            'page_idx': page_idx,
            'elements': page_elements,
            'layout_raw': layout_results,
            'image_shape': preprocessed_image.shape,
            'processed_image': preprocessed_image,
            'angle': rotate_angle
        }

    def _process_table_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
        """Recognize a table region with the vision-language model.

        Args:
            image: Full page image.
            layout_item: Layout detection entry; its ``'bbox'`` is used.

        Returns:
            A dict of type ``'table'``. On failure the ``'content'`` is an
            empty table and an ``'error'`` message is attached.
        """
        try:
            # Crop the table region out of the page.
            cropped_table = self._crop_region(image, layout_item['bbox'])

            # Ask the recognizer for cell coordinates as well.
            table_result = self.vl_recognizer.recognize_table(
                cropped_table,
                return_cells_coordinate=True
            )

            # Map each cell box from crop coordinates to page coordinates.
            if 'cells' in table_result:
                for cell in table_result['cells']:
                    cell['absolute_bbox'] = self._convert_to_absolute_coords(
                        cell['bbox'], layout_item['bbox']
                    )

            result = {
                'type': 'table',
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': table_result,
                'scene_specific': self._add_scene_specific_info(table_result)
            }

            logger.info(f"✅ Table processed with {len(table_result.get('cells', []))} cells")
            return result

        except Exception as e:
            logger.error(f"❌ Table processing failed: {e}")
            return {
                'type': 'table',
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': {'html': '', 'markdown': '', 'cells': []},
                'error': str(e)
            }

    def _process_text_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
        """Recognize a text region with OCR.

        Only OCR items whose ``'confidence'`` exceeds 0.5 contribute to the
        combined text; all items are kept under ``'ocr_details'``.

        Args:
            image: Full page image.
            layout_item: Layout detection entry; ``'bbox'`` and
                ``'category'`` are used.

        Returns:
            A dict typed with the layout category. On failure the
            ``'content'`` is empty and an ``'error'`` message is attached.
        """
        try:
            # Crop the text region out of the page.
            cropped_text = self._crop_region(image, layout_item['bbox'])

            # Run OCR on the crop.
            text_results = self.ocr_recognizer.recognize_text(cropped_text)

            # Join the sufficiently confident fragments with spaces.
            combined_text = ""
            if text_results:
                text_parts = [item['text'] for item in text_results if item['confidence'] > 0.5]
                combined_text = " ".join(text_parts)

            result = {
                'type': layout_item['category'],
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': {
                    'text': combined_text,
                    'ocr_details': text_results
                }
            }

            logger.info(f"✅ Text processed: '{combined_text[:50]}...'")
            return result

        except Exception as e:
            logger.error(f"❌ Text processing failed: {e}")
            return {
                'type': layout_item['category'],
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': {'text': '', 'ocr_details': []},
                'error': str(e)
            }

    def _process_formula_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
        """Recognize a formula region with the vision-language model.

        Args:
            image: Full page image.
            layout_item: Layout detection entry; its ``'bbox'`` is used.

        Returns:
            A dict of type ``'formula'``. On failure the ``'content'`` is an
            empty formula and an ``'error'`` message is attached.
        """
        try:
            # Crop the formula region out of the page.
            cropped_formula = self._crop_region(image, layout_item['bbox'])

            # Run formula recognition on the crop.
            formula_result = self.vl_recognizer.recognize_formula(cropped_formula)

            result = {
                'type': 'formula',
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': formula_result
            }

            logger.info(f"✅ Formula processed: {formula_result.get('latex', '')[:50]}...")
            return result

        except Exception as e:
            logger.error(f"❌ Formula processing failed: {e}")
            return {
                'type': 'formula',
                'bbox': layout_item['bbox'],
                'confidence': layout_item.get('confidence', 0.0),
                'content': {'latex': '', 'confidence': 0.0},
                'error': str(e)
            }

    def _crop_region(self, image: np.ndarray, bbox: List[float]) -> np.ndarray:
        """Return the part of *image* covered by *bbox*.

        Args:
            image: Image array whose first two axes are rows and columns.
            bbox: ``[x1, y1, x2, y2]``; values are truncated to ints and
                clamped to the image bounds.

        Returns:
            The cropped slice of *image*, or *image* itself when *bbox* has
            fewer than four values.
        """
        if len(bbox) < 4:
            return image

        x1, y1, x2, y2 = map(int, bbox)

        # Clamp to the image; the far corner never precedes the near one,
        # so an inverted box yields an empty crop instead of a wrapped one.
        h, w = image.shape[:2]
        x1 = max(0, min(x1, w))
        y1 = max(0, min(y1, h))
        x2 = max(x1, min(x2, w))
        y2 = max(y1, min(y2, h))

        return image[y1:y2, x1:x2]

    def _convert_to_absolute_coords(self, relative_bbox: List[float], region_bbox: List[float]) -> List[float]:
        """Translate a box from region-local to page coordinates.

        Args:
            relative_bbox: ``[x1, y1, x2, y2]`` relative to the region's
                top-left corner.
            region_bbox: ``[x1, y1, x2, y2]`` of the region on the page.

        Returns:
            The box shifted by the region's top-left corner, or
            *relative_bbox* unchanged when either box has fewer than four
            values.
        """
        if len(relative_bbox) < 4 or len(region_bbox) < 4:
            return relative_bbox

        rx1, ry1, rx2, ry2 = relative_bbox
        bx1, by1, bx2, by2 = region_bbox

        # Only the region origin matters; both corners shift by it.
        abs_x1 = bx1 + rx1
        abs_y1 = by1 + ry1
        abs_x2 = bx1 + rx2
        abs_y2 = by1 + ry2

        return [abs_x1, abs_y1, abs_x2, abs_y2]

    def _add_scene_specific_info(self, content: Dict[str, Any]) -> Dict[str, Any]:
        """Return scene-dependent table information for *content*.

        Returns an empty dict for scenes other than ``'bank_statement'``
        and ``'financial_report'``.
        """
        if self.scene_name == 'bank_statement':
            return self._process_bank_statement_table(content)
        elif self.scene_name == 'financial_report':
            return self._process_financial_report_table(content)
        return {}

    def _process_bank_statement_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
        """Return the bank-statement table descriptor.

        The descriptor is currently static; *content* is only inspected by
        a placeholder hook.
        """
        scene_info = {
            'table_type': 'bank_statement',
            'expected_columns': ['日期', '摘要', '收入', '支出', '余额'],
            'validation_rules': {
                'amount_format': True,
                'date_format': True,
                'balance_consistency': True
            }
        }

        # Hook for bank-statement specific validation and post-processing.
        if 'html' in content and content['html']:
            # Bank-statement specific HTML post-processing can be added here.
            pass

        return scene_info

    def _process_financial_report_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
        """Return the financial-report table descriptor.

        The descriptor is currently static; *content* is only inspected by
        a placeholder hook.
        """
        scene_info = {
            'table_type': 'financial_report',
            'complex_headers': True,
            'merged_cells': True,
            'validation_rules': {
                'accounting_format': True,
                'sum_validation': True
            }
        }

        # Hook for financial-report specific validation and post-processing.
        if 'html' in content and content['html']:
            # Financial-report specific HTML post-processing can be added here.
            pass

        return scene_info

    def cleanup(self):
        """Release the resources held by every initialized component.

        Each component is cleaned up independently, so a failure in one
        does not prevent the others from being released. Failures are
        logged as warnings and never raised.
        """
        all_succeeded = True

        for attr in ('preprocessor', 'layout_detector', 'vl_recognizer', 'ocr_recognizer'):
            # Components may be missing when initialization failed midway.
            if not hasattr(self, attr):
                continue
            try:
                getattr(self, attr).cleanup()
            except Exception as e:
                all_succeeded = False
                logger.warning(f"⚠️ Cleanup failed: {e}")

        if all_succeeded:
            logger.info("✅ Pipeline cleanup completed")

    def __enter__(self):
        """Return the pipeline itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clean up on leaving the ``with`` block; exceptions propagate."""
        self.cleanup()
|