Sfoglia il codice sorgente

feat: 新增配置管理器、模型工厂和金融文档处理流水线,支持文档加载、处理及适配器注册

zhch158_admin 2 settimane fa
parent
commit
306142de56

+ 135 - 0
zhch/universal_doc_parser/core/config_manager.py

@@ -0,0 +1,135 @@
+"""配置管理器 - 加载和验证配置文件"""
+import yaml
+from pathlib import Path
+from typing import Dict, Any, Optional
+
+class ConfigManager:
+    """配置管理器"""
+    
+    _config_cache = {}
+    
+    @classmethod
+    def load_config(cls, config_path: str) -> Dict[str, Any]:
+        """加载配置文件"""
+        config_path = Path(config_path)
+        
+        # 缓存机制
+        cache_key = str(config_path.absolute())
+        if cache_key in cls._config_cache:
+            return cls._config_cache[cache_key]
+        
+        if not config_path.exists():
+            raise FileNotFoundError(f"Config file not found: {config_path}")
+        
+        with open(config_path, 'r', encoding='utf-8') as f:
+            config = yaml.safe_load(f)
+        
+        # 配置验证和默认值设置
+        config = cls._validate_and_set_defaults(config)
+        
+        # 缓存配置
+        cls._config_cache[cache_key] = config
+        
+        return config
+    
+    @classmethod
+    def _validate_and_set_defaults(cls, config: Dict[str, Any]) -> Dict[str, Any]:
+        """验证配置并设置默认值"""
+        # 设置默认场景名称
+        if 'scene_name' not in config:
+            config['scene_name'] = 'unknown'
+        
+        # 验证必需的配置项
+        required_sections = ['preprocessor', 'layout_detection', 'vl_recognition', 'ocr_recognition']
+        for section in required_sections:
+            if section not in config:
+                config[section] = {'module': 'mineru'}
+        
+        # 设置预处理器默认配置
+        preprocessor_defaults = {
+            'module': 'mineru',
+            'orientation_classifier': {'enabled': True},
+            'unwarping': {'enabled': False}
+        }
+        config['preprocessor'] = cls._merge_defaults(
+            config.get('preprocessor', {}), preprocessor_defaults
+        )
+        
+        # 设置版式检测默认配置
+        layout_defaults = {
+            'module': 'mineru',
+            'model_name': 'RT-DETR-H_layout_17cls',
+            'device': 'cpu',
+            'batch_size': 1,
+            'conf': 0.25,
+            'iou': 0.45
+        }
+        config['layout_detection'] = cls._merge_defaults(
+            config.get('layout_detection', {}), layout_defaults
+        )
+        
+        # 设置VL识别默认配置
+        vl_defaults = {
+            'module': 'mineru',
+            'backend': 'vllm-http-client',
+            'server_url': 'http://localhost:8111/v1',
+            'device': 'cpu',
+            'batch_size': 1,
+            'model_params': {'max_concurrency': 10, 'http_timeout': 600}
+        }
+        config['vl_recognition'] = cls._merge_defaults(
+            config.get('vl_recognition', {}), vl_defaults
+        )
+        
+        # 设置OCR默认配置
+        ocr_defaults = {
+            'module': 'mineru',
+            'language': 'ch',
+            'det_threshold': 0.3,
+            'unclip_ratio': 1.8,
+            'batch_size': 8,
+            'device': 'cpu'
+        }
+        config['ocr_recognition'] = cls._merge_defaults(
+            config.get('ocr_recognition', {}), ocr_defaults
+        )
+        
+        # 设置输出默认配置
+        output_defaults = {
+            'format': 'enhanced_json',
+            'save_json': True,
+            'save_markdown': True,
+            'save_html': True,
+            'save_images': {'layout': True, 'ocr': True, 'table_cells': True},
+            'coordinate_precision': 2
+        }
+        config['output'] = cls._merge_defaults(
+            config.get('output', {}), output_defaults
+        )
+        
+        return config
+    
+    @classmethod
+    def _merge_defaults(cls, user_config: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
+        """合并用户配置和默认配置"""
+        result = defaults.copy()
+        for key, value in user_config.items():
+            if isinstance(value, dict) and key in result and isinstance(result[key], dict):
+                result[key] = cls._merge_defaults(value, result[key])
+            else:
+                result[key] = value
+        return result
+    
+    @classmethod
+    def save_config(cls, config: Dict[str, Any], config_path: str):
+        """保存配置文件"""
+        config_path = Path(config_path)
+        config_path.parent.mkdir(parents=True, exist_ok=True)
+        
+        with open(config_path, 'w', encoding='utf-8') as f:
+            yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
+    
+    @classmethod
+    def clear_cache(cls):
+        """清空配置缓存"""
+        cls._config_cache.clear()

+ 88 - 0
zhch/universal_doc_parser/core/model_factory.py

@@ -0,0 +1,88 @@
+"""模型工厂 - 根据配置创建模型实例"""
+from typing import Dict, Any, Optional
+from models.adapters import BasePreprocessor, BaseLayoutDetector, BaseVLRecognizer, BaseOCRRecognizer
+
+class ModelFactory:
+    """模型工厂类,负责创建和管理各种模型适配器"""
+    
+    _adapters_registry = {}
+    
+    @classmethod
+    def register_adapter(cls, adapter_type: str, module_name: str, class_name: str):
+        """注册适配器"""
+        cls._adapters_registry[adapter_type] = {
+            'module': module_name,
+            'class': class_name
+        }
+    
+    @classmethod
+    def create_preprocessor(cls, config: Dict[str, Any]) -> BasePreprocessor:
+        """创建预处理器"""
+        module_name = config.get('module', 'mineru')
+        
+        if module_name == 'mineru':
+            from models.adapters import MinerUPreprocessor
+            preprocessor = MinerUPreprocessor(config)
+        else:
+            raise ValueError(f"Unknown preprocessor module: {module_name}")
+            
+        preprocessor.initialize()
+        return preprocessor
+    
+    @classmethod
+    def create_layout_detector(cls, config: Dict[str, Any]) -> BaseLayoutDetector:
+        # 根据配置创建检测器
+        module_name = config.get('module', 'mineru')
+        if module_name == 'paddle':
+            from models.adapters import PaddleLayoutDetector
+            detector = PaddleLayoutDetector(config)
+        elif module_name == 'mineru':
+            from models.adapters import MinerULayoutDetector
+            detector = MinerULayoutDetector(config)
+        else:
+            raise ValueError(f"Unknown layout detector module: {module_name}")
+
+        detector.initialize()
+        return detector
+    
+    @classmethod
+    def create_vl_recognizer(cls, config: Dict[str, Any]) -> BaseVLRecognizer:
+        """创建VL识别器"""
+        module_name = config.get('module', 'mineru')
+        
+        if module_name == 'mineru':
+            from models.adapters import MinerUVLRecognizer
+            recognizer = MinerUVLRecognizer(config)
+        else:
+            raise ValueError(f"Unknown VL recognizer module: {module_name}")
+            
+        recognizer.initialize()
+        return recognizer
+    
+    @classmethod
+    def create_ocr_recognizer(cls, config: Dict[str, Any]) -> BaseOCRRecognizer:
+        """创建OCR识别器"""
+        module_name = config.get('module', 'mineru')
+        
+        if module_name == 'mineru':
+            from models.adapters import MinerUOCRRecognizer
+            recognizer = MinerUOCRRecognizer(config)
+        else:
+            raise ValueError(f"Unknown OCR recognizer module: {module_name}")
+            
+        recognizer.initialize()
+        return recognizer
+    
+    @classmethod
+    def cleanup_all(cls):
+        """清理所有模型资源"""
+        # 在实际应用中,可以维护一个活跃模型列表进行清理
+        pass
+
+# 注册默认适配器
+ModelFactory.register_adapter('preprocessor', 'mineru_adapter', 'MinerUPreprocessor')
+ModelFactory.register_adapter('layout_detector', 'mineru_adapter', 'MinerULayoutDetector')
+ModelFactory.register_adapter('vl_recognizer', 'mineru_adapter', 'MinerUVLRecognizer')
+ModelFactory.register_adapter('ocr_recognizer', 'mineru_adapter', 'MinerUOCRRecognizer')
+
+ModelFactory.register_adapter('layout_detector', 'paddle_adapter', 'PaddleLayoutDetector')

+ 419 - 0
zhch/universal_doc_parser/core/pipeline_manager.py

@@ -0,0 +1,419 @@
+from typing import Dict, List, Any, Optional, Union
+from pathlib import Path
+import numpy as np
+from PIL import Image
+import fitz  # PyMuPDF
+from loguru import logger
+
+from .model_factory import ModelFactory
+from .config_manager import ConfigManager
+from models.adapters import BaseAdapter
+
+class FinancialDocPipeline:
+    """金融文档处理统一流水线"""
+    
+    def __init__(self, config_path: str):
+        self.config = ConfigManager.load_config(config_path)
+        self.scene_name = self.config.get('scene_name', 'unknown')
+        
+        # 初始化各个组件
+        self._init_components()
+        
+    def _init_components(self):
+        """初始化处理组件"""
+        try:
+            # 1. 预处理器(方向分类、图像矫正等)
+            self.preprocessor = ModelFactory.create_preprocessor(
+                self.config['preprocessor']
+            )
+            
+            # 2. 版式检测器
+            self.layout_detector = ModelFactory.create_layout_detector(
+                self.config['layout_detection']
+            )
+            
+            # 3. VL识别器(表格、公式等)
+            self.vl_recognizer = ModelFactory.create_vl_recognizer(
+                self.config['vl_recognition']
+            )
+            
+            # 4. OCR识别器
+            self.ocr_recognizer = ModelFactory.create_ocr_recognizer(
+                self.config['ocr_recognition']
+            )
+            
+            logger.info(f"✅ Initialized pipeline for scene: {self.scene_name}")
+            
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize pipeline components: {e}")
+            raise
+    
+    def process_document(self, document_path: str) -> Dict[str, Any]:
+        """
+        处理文档的主流程
+        
+        Args:
+            document_path: 文档路径
+            
+        Returns:
+            处理结果,包含所有元素的坐标和内容信息
+        """
+        results = {
+            'scene': self.scene_name,
+            'document_path': document_path,
+            'pages': [],
+            'metadata': self._extract_metadata(document_path)
+        }
+        
+        try:
+            # 加载文档图像
+            images = self._load_document_images(document_path)
+            logger.info(f"📄 Loaded {len(images)} pages from document")
+            
+            for page_idx, image in enumerate(images):
+                logger.info(f"🔍 Processing page {page_idx + 1}/{len(images)}")
+                page_result = self._process_single_page(image, page_idx)
+                results['pages'].append(page_result)
+                
+        except Exception as e:
+            logger.error(f"❌ Failed to process document: {e}")
+            raise
+            
+        return results
+    
+    def _load_document_images(self, document_path: str) -> List[np.ndarray]:
+        """加载文档图像"""
+        document_path = Path(document_path)
+        
+        if not document_path.exists():
+            raise FileNotFoundError(f"Document not found: {document_path}")
+        
+        images = []
+        
+        if document_path.suffix.lower() == '.pdf':
+            # 处理PDF文件
+            doc = fitz.open(document_path)
+            try:
+                for page_num in range(len(doc)):
+                    page = doc.load_page(page_num)
+                    # 设置合适的DPI
+                    dpi = self.config.get('input', {}).get('dpi', 200)
+                    mat = fitz.Matrix(dpi/72, dpi/72)
+                    pix = page.get_pixmap(matrix=mat)
+                    img_data = pix.tobytes("ppm")
+                    
+                    # 转换为numpy数组
+                    from io import BytesIO
+                    img = Image.open(BytesIO(img_data))
+                    img_array = np.array(img)
+                    images.append(img_array)
+            finally:
+                doc.close()
+                
+        elif document_path.suffix.lower() in ['.png', '.jpg', '.jpeg', '.bmp', '.tiff']:
+            # 处理图像文件
+            img = Image.open(document_path)
+            if img.mode != 'RGB':
+                img = img.convert('RGB')
+            img_array = np.array(img)
+            images.append(img_array)
+            
+        else:
+            raise ValueError(f"Unsupported file format: {document_path.suffix}")
+        
+        return images
+    
+    def _extract_metadata(self, document_path: str) -> Dict[str, Any]:
+        """提取文档元数据"""
+        document_path = Path(document_path)
+        
+        metadata = {
+            'filename': document_path.name,
+            'size': document_path.stat().st_size,
+            'format': document_path.suffix.lower()
+        }
+        
+        # 如果是PDF,提取更多元数据
+        if document_path.suffix.lower() == '.pdf':
+            try:
+                doc = fitz.open(document_path)
+                metadata.update({
+                    'page_count': len(doc),
+                    'title': doc.metadata.get('title', ''),
+                    'author': doc.metadata.get('author', ''),
+                    'subject': doc.metadata.get('subject', ''),
+                    'creator': doc.metadata.get('creator', '')
+                })
+                doc.close()
+            except Exception:
+                pass
+        
+        return metadata
+    
+    def _process_single_page(self, image: np.ndarray, page_idx: int) -> Dict[str, Any]:
+        """处理单页文档"""
+        # 1. 预处理(方向校正等)
+        try:
+            preprocessed_image, rotate_angle = self.preprocessor.process(image)
+        except Exception as e:
+            logger.warning(f"⚠️ Preprocessing failed for page {page_idx}: {e}")
+            preprocessed_image = image
+        
+        # 2. 版式检测
+        try:
+            layout_results = self.layout_detector.detect(preprocessed_image)
+            logger.info(f"📋 Detected {len(layout_results)} layout elements on page {page_idx}")
+        except Exception as e:
+            logger.error(f"❌ Layout detection failed for page {page_idx}: {e}")
+            layout_results = []
+        
+        # 3. 根据场景类型分别处理不同元素
+        page_elements = []
+        
+        for layout_item in layout_results:
+            try:
+                element_type = layout_item['category']
+                
+                if element_type in ['table_body', 'table']:
+                    # 表格使用VL模型处理
+                    element_result = self._process_table_element(
+                        preprocessed_image, layout_item
+                    )
+                elif element_type in ['text', 'title', 'ocr_text']:
+                    # 文本使用OCR处理
+                    element_result = self._process_text_element(
+                        preprocessed_image, layout_item
+                    )
+                elif element_type in ['interline_equation', 'inline_equation']:
+                    # 公式使用VL模型处理
+                    element_result = self._process_formula_element(
+                        preprocessed_image, layout_item
+                    )
+                else:
+                    # 其他元素保持原样
+                    element_result = layout_item.copy()
+                    element_result['type'] = element_type
+                    
+                page_elements.append(element_result)
+                
+            except Exception as e:
+                logger.warning(f"⚠️ Failed to process element {element_type}: {e}")
+                # 添加失败的元素,标记为错误
+                error_element = layout_item.copy()
+                error_element['type'] = 'error'
+                error_element['error'] = str(e)
+                page_elements.append(error_element)
+        
+        return {
+            'page_idx': page_idx,
+            'elements': page_elements,
+            'layout_raw': layout_results,
+            'image_shape': preprocessed_image.shape,
+            'processed_image': preprocessed_image,
+            'angle': rotate_angle
+        }
+    
+    def _process_table_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
+        """处理表格元素"""
+        try:
+            # 裁剪表格区域
+            cropped_table = self._crop_region(image, layout_item['bbox'])
+            
+            # 使用VL模型识别表格
+            table_result = self.vl_recognizer.recognize_table(
+                cropped_table,
+                return_cells_coordinate=True  # 关键:返回单元格坐标
+            )
+            
+            # 转换坐标到原图坐标系
+            if 'cells' in table_result:
+                for cell in table_result['cells']:
+                    cell['absolute_bbox'] = self._convert_to_absolute_coords(
+                        cell['bbox'], layout_item['bbox']
+                    )
+            
+            result = {
+                'type': 'table',
+                'bbox': layout_item['bbox'],
+                'confidence': layout_item.get('confidence', 0.0),
+                'content': table_result,
+                'scene_specific': self._add_scene_specific_info(table_result)
+            }
+            
+            logger.info(f"✅ Table processed with {len(table_result.get('cells', []))} cells")
+            return result
+            
+        except Exception as e:
+            logger.error(f"❌ Table processing failed: {e}")
+            return {
+                'type': 'table',
+                'bbox': layout_item['bbox'],
+                'content': {'html': '', 'markdown': '', 'cells': []},
+                'error': str(e)
+            }
+    
+    def _process_text_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
+        """处理文本元素"""
+        try:
+            # 裁剪文本区域
+            cropped_text = self._crop_region(image, layout_item['bbox'])
+            
+            # 使用OCR识别文本
+            text_results = self.ocr_recognizer.recognize_text(cropped_text)
+            
+            # 合并识别结果
+            combined_text = ""
+            if text_results:
+                text_parts = [item['text'] for item in text_results if item['confidence'] > 0.5]
+                combined_text = " ".join(text_parts)
+            
+            result = {
+                'type': layout_item['category'],
+                'bbox': layout_item['bbox'],
+                'confidence': layout_item.get('confidence', 0.0),
+                'content': {
+                    'text': combined_text,
+                    'ocr_details': text_results
+                }
+            }
+            
+            logger.info(f"✅ Text processed: '{combined_text[:50]}...'")
+            return result
+            
+        except Exception as e:
+            logger.error(f"❌ Text processing failed: {e}")
+            return {
+                'type': layout_item['category'],
+                'bbox': layout_item['bbox'],
+                'content': {'text': '', 'ocr_details': []},
+                'error': str(e)
+            }
+    
+    def _process_formula_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
+        """处理公式元素"""
+        try:
+            # 裁剪公式区域
+            cropped_formula = self._crop_region(image, layout_item['bbox'])
+            
+            # 使用VL模型识别公式
+            formula_result = self.vl_recognizer.recognize_formula(cropped_formula)
+            
+            result = {
+                'type': 'formula',
+                'bbox': layout_item['bbox'],
+                'confidence': layout_item.get('confidence', 0.0),
+                'content': formula_result
+            }
+            
+            logger.info(f"✅ Formula processed: {formula_result.get('latex', '')[:50]}...")
+            return result
+            
+        except Exception as e:
+            logger.error(f"❌ Formula processing failed: {e}")
+            return {
+                'type': 'formula',
+                'bbox': layout_item['bbox'],
+                'content': {'latex': '', 'confidence': 0.0},
+                'error': str(e)
+            }
+    
+    def _crop_region(self, image: np.ndarray, bbox: List[float]) -> np.ndarray:
+        """裁剪图像区域"""
+        if len(bbox) < 4:
+            return image
+            
+        x1, y1, x2, y2 = map(int, bbox)
+        
+        # 边界检查
+        h, w = image.shape[:2]
+        x1 = max(0, min(x1, w))
+        y1 = max(0, min(y1, h))
+        x2 = max(x1, min(x2, w))
+        y2 = max(y1, min(y2, h))
+        
+        return image[y1:y2, x1:x2]
+    
+    def _convert_to_absolute_coords(self, relative_bbox: List[float], region_bbox: List[float]) -> List[float]:
+        """将相对坐标转换为绝对坐标"""
+        if len(relative_bbox) < 4 or len(region_bbox) < 4:
+            return relative_bbox
+            
+        rx1, ry1, rx2, ry2 = relative_bbox
+        bx1, by1, bx2, by2 = region_bbox
+        
+        # 计算绝对坐标
+        abs_x1 = bx1 + rx1
+        abs_y1 = by1 + ry1
+        abs_x2 = bx1 + rx2
+        abs_y2 = by1 + ry2
+        
+        return [abs_x1, abs_y1, abs_x2, abs_y2]
+    
+    def _add_scene_specific_info(self, content: Dict[str, Any]) -> Dict[str, Any]:
+        """根据场景添加特定信息"""
+        if self.scene_name == 'bank_statement':
+            return self._process_bank_statement_table(content)
+        elif self.scene_name == 'financial_report':
+            return self._process_financial_report_table(content)
+        return {}
+    
+    def _process_bank_statement_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
+        """处理银行流水表格特定逻辑"""
+        scene_info = {
+            'table_type': 'bank_statement',
+            'expected_columns': ['日期', '摘要', '收入', '支出', '余额'],
+            'validation_rules': {
+                'amount_format': True,
+                'date_format': True,
+                'balance_consistency': True
+            }
+        }
+        
+        # 进行银行流水特定的验证和处理
+        if 'html' in content and content['html']:
+            # 这里可以添加银行流水特定的HTML后处理逻辑
+            pass
+            
+        return scene_info
+    
+    def _process_financial_report_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
+        """处理财务报表特定逻辑"""
+        scene_info = {
+            'table_type': 'financial_report',
+            'complex_headers': True,
+            'merged_cells': True,
+            'validation_rules': {
+                'accounting_format': True,
+                'sum_validation': True
+            }
+        }
+        
+        # 进行财务报表特定的验证和处理
+        if 'html' in content and content['html']:
+            # 这里可以添加财务报表特定的HTML后处理逻辑
+            pass
+            
+        return scene_info
+    
+    def cleanup(self):
+        """清理资源"""
+        try:
+            if hasattr(self, 'preprocessor'):
+                self.preprocessor.cleanup()
+            if hasattr(self, 'layout_detector'):
+                self.layout_detector.cleanup()
+            if hasattr(self, 'vl_recognizer'):
+                self.vl_recognizer.cleanup()
+            if hasattr(self, 'ocr_recognizer'):
+                self.ocr_recognizer.cleanup()
+                
+            logger.info("✅ Pipeline cleanup completed")
+            
+        except Exception as e:
+            logger.warning(f"⚠️ Cleanup failed: {e}")
+    
+    def __enter__(self):
+        return self
+    
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.cleanup()