@@ -0,0 +1,419 @@
+from typing import Dict, List, Any, Optional, Union
+from io import BytesIO
+from pathlib import Path
+
+import numpy as np
+from PIL import Image
+import fitz  # PyMuPDF
+from loguru import logger
+
+from .model_factory import ModelFactory
+from .config_manager import ConfigManager
+from models.adapters import BaseAdapter
+
+class FinancialDocPipeline:
+    """Unified processing pipeline for financial documents."""
+
+    def __init__(self, config_path: str):
+        self.config = ConfigManager.load_config(config_path)
+        self.scene_name = self.config.get('scene_name', 'unknown')
+
+        # Initialize the individual components
+        self._init_components()
+
+    def _init_components(self):
+        """Initialize the processing components."""
+        try:
+            # 1. Preprocessor (orientation classification, image rectification, etc.)
+            self.preprocessor = ModelFactory.create_preprocessor(
+                self.config['preprocessor']
+            )
+
+            # 2. Layout detector
+            self.layout_detector = ModelFactory.create_layout_detector(
+                self.config['layout_detection']
+            )
+
+            # 3. VL recognizer (tables, formulas, etc.)
+            self.vl_recognizer = ModelFactory.create_vl_recognizer(
+                self.config['vl_recognition']
+            )
+
+            # 4. OCR recognizer
+            self.ocr_recognizer = ModelFactory.create_ocr_recognizer(
+                self.config['ocr_recognition']
+            )
+
+            logger.info(f"✅ Initialized pipeline for scene: {self.scene_name}")
+
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize pipeline components: {e}")
+            raise
+
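+    # For reference, a minimal sketch of the config sections _init_components reads.
+    # Only the key names below are taken from the lookups above; the values and the
+    # file format are illustrative and depend on ConfigManager / ModelFactory:
+    #
+    #   scene_name: bank_statement
+    #   input:
+    #     dpi: 200
+    #   preprocessor: {...}
+    #   layout_detection: {...}
+    #   vl_recognition: {...}
+    #   ocr_recognition: {...}
+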
+    def process_document(self, document_path: str) -> Dict[str, Any]:
+        """
+        Main document processing flow.
+
+        Args:
+            document_path: Path to the input document.
+
+        Returns:
+            Processing result with the coordinates and content of all detected
+            elements, keyed by 'scene', 'document_path', 'pages' and 'metadata'.
+        """
+        results = {
+            'scene': self.scene_name,
+            'document_path': document_path,
+            'pages': [],
+            'metadata': self._extract_metadata(document_path)
+        }
+
+        try:
+            # Load the document as page images
+            images = self._load_document_images(document_path)
+            logger.info(f"📄 Loaded {len(images)} pages from document")
+
+            for page_idx, image in enumerate(images):
+                logger.info(f"🔍 Processing page {page_idx + 1}/{len(images)}")
+                page_result = self._process_single_page(image, page_idx)
+                results['pages'].append(page_result)
+
+        except Exception as e:
+            logger.error(f"❌ Failed to process document: {e}")
+            raise
+
+        return results
+
+    def _load_document_images(self, document_path: str) -> List[np.ndarray]:
+        """Load the document as a list of page images."""
+        document_path = Path(document_path)
+
+        if not document_path.exists():
+            raise FileNotFoundError(f"Document not found: {document_path}")
+
+        images = []
+
+        if document_path.suffix.lower() == '.pdf':
+            # PDF files: render each page to an image
+            doc = fitz.open(document_path)
+            try:
+                for page_num in range(len(doc)):
+                    page = doc.load_page(page_num)
+                    # Render at the configured DPI (PyMuPDF's base resolution is 72 dpi)
+                    dpi = self.config.get('input', {}).get('dpi', 200)
+                    mat = fitz.Matrix(dpi/72, dpi/72)
+                    pix = page.get_pixmap(matrix=mat)
+                    img_data = pix.tobytes("ppm")
+
+                    # Convert to a numpy array
+                    img = Image.open(BytesIO(img_data))
+                    img_array = np.array(img)
+                    images.append(img_array)
+            finally:
+                doc.close()
+
+        elif document_path.suffix.lower() in ['.png', '.jpg', '.jpeg', '.bmp', '.tiff']:
+            # Plain image files
+            img = Image.open(document_path)
+            if img.mode != 'RGB':
+                img = img.convert('RGB')
+            img_array = np.array(img)
+            images.append(img_array)
+
+        else:
+            raise ValueError(f"Unsupported file format: {document_path.suffix}")
+
+        return images
+
+    def _extract_metadata(self, document_path: str) -> Dict[str, Any]:
+        """Extract document metadata."""
+        document_path = Path(document_path)
+
+        metadata = {
+            'filename': document_path.name,
+            'size': document_path.stat().st_size,
+            'format': document_path.suffix.lower()
+        }
+
+        # For PDFs, pull additional metadata from the file itself
+        if document_path.suffix.lower() == '.pdf':
+            try:
+                doc = fitz.open(document_path)
+                metadata.update({
+                    'page_count': len(doc),
+                    'title': doc.metadata.get('title', ''),
+                    'author': doc.metadata.get('author', ''),
+                    'subject': doc.metadata.get('subject', ''),
+                    'creator': doc.metadata.get('creator', '')
+                })
+                doc.close()
+            except Exception:
+                pass
+
+        return metadata
+
+    def _process_single_page(self, image: np.ndarray, page_idx: int) -> Dict[str, Any]:
+        """Process a single document page."""
+        # 1. Preprocessing (orientation correction, etc.)
+        try:
+            preprocessed_image, rotate_angle = self.preprocessor.process(image)
+        except Exception as e:
+            logger.warning(f"⚠️ Preprocessing failed for page {page_idx}: {e}")
+            preprocessed_image = image
+            rotate_angle = 0
+
+        # 2. Layout detection
+        try:
+            layout_results = self.layout_detector.detect(preprocessed_image)
+            logger.info(f"📋 Detected {len(layout_results)} layout elements on page {page_idx}")
+        except Exception as e:
+            logger.error(f"❌ Layout detection failed for page {page_idx}: {e}")
+            layout_results = []
+
+        # 3. Dispatch each detected element to the appropriate recognizer
+        page_elements = []
+
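+        # Assumed shape of each layout_results entry (inferred from the usage below,
+        # not a guarantee of the detector's API): a dict with at least
+        #   'category': str, e.g. 'table', 'text', 'title', 'interline_equation'
+        #   'bbox': [x1, y1, x2, y2] in pixels of the preprocessed page image
+        #   'confidence': float (optional; defaults to 0.0 downstream)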
+        for layout_item in layout_results:
+            element_type = layout_item.get('category', 'unknown')
+            try:
+                if element_type in ['table_body', 'table']:
+                    # Tables go through the VL model
+                    element_result = self._process_table_element(
+                        preprocessed_image, layout_item
+                    )
+                elif element_type in ['text', 'title', 'ocr_text']:
+                    # Text goes through OCR
+                    element_result = self._process_text_element(
+                        preprocessed_image, layout_item
+                    )
+                elif element_type in ['interline_equation', 'inline_equation']:
+                    # Formulas go through the VL model
+                    element_result = self._process_formula_element(
+                        preprocessed_image, layout_item
+                    )
+                else:
+                    # Other elements are passed through unchanged
+                    element_result = layout_item.copy()
+                    element_result['type'] = element_type
+
+                page_elements.append(element_result)
+
+            except Exception as e:
+                logger.warning(f"⚠️ Failed to process element {element_type}: {e}")
+                # Keep the failed element, marked as an error
+                error_element = layout_item.copy()
+                error_element['type'] = 'error'
+                error_element['error'] = str(e)
+                page_elements.append(error_element)
+
+        return {
+            'page_idx': page_idx,
+            'elements': page_elements,
+            'layout_raw': layout_results,
+            'image_shape': preprocessed_image.shape,
+            'processed_image': preprocessed_image,
+            'angle': rotate_angle
+        }
+
+    def _process_table_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a table element."""
+        try:
+            # Crop the table region
+            cropped_table = self._crop_region(image, layout_item['bbox'])
+
+            # Recognize the table with the VL model
+            table_result = self.vl_recognizer.recognize_table(
+                cropped_table,
+                return_cells_coordinate=True  # key point: also return per-cell coordinates
+            )
+
+            # Convert cell coordinates back to the full-page coordinate system
+            if 'cells' in table_result:
+                for cell in table_result['cells']:
+                    cell['absolute_bbox'] = self._convert_to_absolute_coords(
+                        cell['bbox'], layout_item['bbox']
+                    )
+
+            result = {
+                'type': 'table',
+                'bbox': layout_item['bbox'],
+                'confidence': layout_item.get('confidence', 0.0),
+                'content': table_result,
+                'scene_specific': self._add_scene_specific_info(table_result)
+            }
+
+            logger.info(f"✅ Table processed with {len(table_result.get('cells', []))} cells")
+            return result
+
+        except Exception as e:
+            logger.error(f"❌ Table processing failed: {e}")
+            return {
+                'type': 'table',
+                'bbox': layout_item['bbox'],
+                'content': {'html': '', 'markdown': '', 'cells': []},
+                'error': str(e)
+            }
+
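+    # Assumed shape of the recognize_table() result consumed above (inferred from this
+    # file, not from the recognizer's documented API): a dict with 'html' and 'markdown'
+    # strings plus a 'cells' list, each cell carrying a 'bbox' relative to the cropped
+    # table region; _process_table_element adds an 'absolute_bbox' in page coordinates.
+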
+    def _process_text_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a text element."""
+        try:
+            # Crop the text region
+            cropped_text = self._crop_region(image, layout_item['bbox'])
+
+            # Recognize the text with OCR
+            text_results = self.ocr_recognizer.recognize_text(cropped_text)
+
+            # Merge the recognition results, dropping low-confidence lines
+            combined_text = ""
+            if text_results:
+                text_parts = [item['text'] for item in text_results if item['confidence'] > 0.5]
+                combined_text = " ".join(text_parts)
+
+            result = {
+                'type': layout_item['category'],
+                'bbox': layout_item['bbox'],
+                'confidence': layout_item.get('confidence', 0.0),
+                'content': {
+                    'text': combined_text,
+                    'ocr_details': text_results
+                }
+            }
+
+            logger.info(f"✅ Text processed: '{combined_text[:50]}...'")
+            return result
+
+        except Exception as e:
+            logger.error(f"❌ Text processing failed: {e}")
+            return {
+                'type': layout_item['category'],
+                'bbox': layout_item['bbox'],
+                'content': {'text': '', 'ocr_details': []},
+                'error': str(e)
+            }
+
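+    # Assumed shape of the recognize_text() result consumed above (inferred from this
+    # file only): a list of dicts, each with at least 'text' and a 'confidence' in
+    # [0, 1]; lines below the hard-coded 0.5 threshold are dropped from the merged text.
+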
+    def _process_formula_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process a formula element."""
+        try:
+            # Crop the formula region
+            cropped_formula = self._crop_region(image, layout_item['bbox'])
+
+            # Recognize the formula with the VL model
+            formula_result = self.vl_recognizer.recognize_formula(cropped_formula)
+
+            result = {
+                'type': 'formula',
+                'bbox': layout_item['bbox'],
+                'confidence': layout_item.get('confidence', 0.0),
+                'content': formula_result
+            }
+
+            logger.info(f"✅ Formula processed: {formula_result.get('latex', '')[:50]}...")
+            return result
+
+        except Exception as e:
+            logger.error(f"❌ Formula processing failed: {e}")
+            return {
+                'type': 'formula',
+                'bbox': layout_item['bbox'],
+                'content': {'latex': '', 'confidence': 0.0},
+                'error': str(e)
+            }
+
+    def _crop_region(self, image: np.ndarray, bbox: List[float]) -> np.ndarray:
+        """Crop a region from the image."""
+        if len(bbox) < 4:
+            return image
+
+        x1, y1, x2, y2 = map(int, bbox)
+
+        # Clamp the coordinates to the image bounds
+        h, w = image.shape[:2]
+        x1 = max(0, min(x1, w))
+        y1 = max(0, min(y1, h))
+        x2 = max(x1, min(x2, w))
+        y2 = max(y1, min(y2, h))
+
+        return image[y1:y2, x1:x2]
+
+    def _convert_to_absolute_coords(self, relative_bbox: List[float], region_bbox: List[float]) -> List[float]:
+        """Convert region-relative coordinates to absolute page coordinates."""
+        if len(relative_bbox) < 4 or len(region_bbox) < 4:
+            return relative_bbox
+
+        rx1, ry1, rx2, ry2 = relative_bbox
+        bx1, by1, bx2, by2 = region_bbox
+
+        # Offset the relative box by the region's top-left corner
+        abs_x1 = bx1 + rx1
+        abs_y1 = by1 + ry1
+        abs_x2 = bx1 + rx2
+        abs_y2 = by1 + ry2
+
+        return [abs_x1, abs_y1, abs_x2, abs_y2]
+
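+    # Worked example for the conversion above: a cell detected at [10, 5, 60, 25]
+    # inside a table region whose page bbox is [100, 200, 400, 500] maps to the
+    # absolute page box [100 + 10, 200 + 5, 100 + 60, 200 + 25] = [110, 205, 160, 225].
+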
+    def _add_scene_specific_info(self, content: Dict[str, Any]) -> Dict[str, Any]:
+        """Add scene-specific information based on the configured scene."""
+        if self.scene_name == 'bank_statement':
+            return self._process_bank_statement_table(content)
+        elif self.scene_name == 'financial_report':
+            return self._process_financial_report_table(content)
+        return {}
+
+    def _process_bank_statement_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
+        """Bank-statement-specific table handling."""
+        scene_info = {
+            'table_type': 'bank_statement',
+            # Expected column headers: date, description, credit, debit, balance
+            'expected_columns': ['日期', '摘要', '收入', '支出', '余额'],
+            'validation_rules': {
+                'amount_format': True,
+                'date_format': True,
+                'balance_consistency': True
+            }
+        }
+
+        # Bank-statement-specific validation and post-processing
+        if 'html' in content and content['html']:
+            # Statement-specific HTML post-processing hooks can be added here
+            pass
+
+        return scene_info
+
+    def _process_financial_report_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
+        """Financial-report-specific table handling."""
+        scene_info = {
+            'table_type': 'financial_report',
+            'complex_headers': True,
+            'merged_cells': True,
+            'validation_rules': {
+                'accounting_format': True,
+                'sum_validation': True
+            }
+        }
+
+        # Financial-report-specific validation and post-processing
+        if 'html' in content and content['html']:
+            # Report-specific HTML post-processing hooks can be added here
+            pass
+
+        return scene_info
+
+    def cleanup(self):
+        """Release resources held by the pipeline components."""
+        try:
+            if hasattr(self, 'preprocessor'):
+                self.preprocessor.cleanup()
+            if hasattr(self, 'layout_detector'):
+                self.layout_detector.cleanup()
+            if hasattr(self, 'vl_recognizer'):
+                self.vl_recognizer.cleanup()
+            if hasattr(self, 'ocr_recognizer'):
+                self.ocr_recognizer.cleanup()
+
+            logger.info("✅ Pipeline cleanup completed")
+
+        except Exception as e:
+            logger.warning(f"⚠️ Cleanup failed: {e}")
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.cleanup()
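+
+
+# Minimal usage sketch (illustrative only; not exercised by tests). The config path
+# and sample document below are placeholders — any scene config accepted by
+# ConfigManager.load_config will do.
+if __name__ == "__main__":
+    with FinancialDocPipeline("configs/bank_statement.yaml") as pipeline:
+        doc_results = pipeline.process_document("samples/statement.pdf")
+        for page in doc_results["pages"]:
+            print(f"page {page['page_idx']}: {len(page['elements'])} elements")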