|
|
@@ -0,0 +1,1030 @@
|
|
|
+正在收集工作区信息正在筛选到最相关的信息根据您的需求,我设计了一个**通用的多模型文档解析框架**,支持灵活配置不同场景的模型组合。以下是完整的架构设计:
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 🎯 架构概览
|
|
|
+
|
|
|
+```mermaid
|
|
|
+graph TB
|
|
|
+ Config[配置文件 YAML] --> Parser[UniversalDocParser]
|
|
|
+ Parser --> PreProc[文档预处理]
|
|
|
+ PreProc --> Layout[版面检测]
|
|
|
+ Layout --> VL[VLM识别]
|
|
|
+ VL --> PostProc[后处理]
|
|
|
+ PostProc --> Output[输出 JSON/MD/Images]
|
|
|
+
|
|
|
+ subgraph "可配置模型池"
|
|
|
+ LayoutPool[PP-DocLayoutV2<br>YOLOv10<br>RT-DETR]
|
|
|
+ VLPool[MinerU-VLM<br>PaddleOCR-VL<br>GOT-OCR]
|
|
|
+ OrientPool[PP-LCNet<br>Custom Classifier]
|
|
|
+ end
|
|
|
+
|
|
|
+ Layout -.-> LayoutPool
|
|
|
+ VL -.-> VLPool
|
|
|
+ PreProc -.-> OrientPool
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 📂 项目结构
|
|
|
+
|
|
|
+```bash
|
|
|
+zhch/
|
|
|
+├── universal_doc_parser/
|
|
|
+│ ├── __init__.py
|
|
|
+│ ├── parser.py # 主解析器
|
|
|
+│ ├── config_manager.py # 配置管理
|
|
|
+│ ├── model_factory.py # 模型工厂
|
|
|
+│ ├── output_formatter.py # 输出格式化
|
|
|
+│ └── models/
|
|
|
+│ ├── __init__.py
|
|
|
+│ ├── layout_detector.py # 版面检测抽象层
|
|
|
+│ ├── vl_recognizer.py # VLM识别抽象层
|
|
|
+│ ├── preprocessor.py # 预处理抽象层
|
|
|
+│ └── adapters/
|
|
|
+│ ├── __init__.py
|
|
|
+│ ├── mineru_adapter.py # MinerU适配器
|
|
|
+│ ├── paddlex_adapter.py # PaddleX适配器
|
|
|
+│ └── custom_adapter.py # 自定义适配器
|
|
|
+├── configs/
|
|
|
+│ ├── financial_report.yaml # 财报场景配置
|
|
|
+│ ├── bank_statement.yaml # 流水场景配置
|
|
|
+│ └── default.yaml # 默认配置
|
|
|
+└── main.py # 入口程序
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 🔧 核心代码实现
|
|
|
+
|
|
|
+### 1. 配置文件定义
|
|
|
+
|
|
|
+#### `configs/financial_report.yaml` (财报场景)
|
|
|
+
|
|
|
+```yaml
|
|
|
+# 财报场景配置
|
|
|
+scene_name: "financial_report"
|
|
|
+description: "上市公司年报、财务报表等场景"
|
|
|
+
|
|
|
+# 输入配置
|
|
|
+input:
|
|
|
+ supported_formats: [".pdf", ".png", ".jpg", ".jpeg"]
|
|
|
+ dpi: 300
|
|
|
+
|
|
|
+# 文档预处理
|
|
|
+preprocessor:
|
|
|
+ module: "paddlex" # paddlex | mineru | custom
|
|
|
+ orientation_classifier:
|
|
|
+ enabled: true
|
|
|
+ model_name: "PP-LCNet_x1_0_doc_ori"
|
|
|
+ model_dir: null
|
|
|
+ unwarping:
|
|
|
+ enabled: false
|
|
|
+
|
|
|
+# 版面检测
|
|
|
+layout_detection:
|
|
|
+ module: "paddlex"
|
|
|
+ model_name: "PP-DocLayoutV2"
|
|
|
+ model_dir: null
|
|
|
+ device: "cpu"
|
|
|
+ batch_size: 8
|
|
|
+ threshold:
|
|
|
+ table: 0.5
|
|
|
+ text: 0.4
|
|
|
+ image: 0.5
|
|
|
+ seal: 0.45
|
|
|
+ layout_nms: true
|
|
|
+ layout_unclip_ratio: [1.0, 1.0]
|
|
|
+
|
|
|
+# VLM识别
|
|
|
+vl_recognition:
|
|
|
+ module: "paddlex"
|
|
|
+ model_name: "PaddleOCR-VL-0.9B"
|
|
|
+ model_dir: null
|
|
|
+ backend: "vllm-server"
|
|
|
+ server_url: "http://10.192.72.11:8110/v1"
|
|
|
+ batch_size: 2048
|
|
|
+ device: "cpu"
|
|
|
+
|
|
|
+# 输出配置
|
|
|
+output:
|
|
|
+ format: "mineru" # mineru | paddlex | custom
|
|
|
+ save_json: true
|
|
|
+ save_markdown: true
|
|
|
+ save_images:
|
|
|
+ layout: true
|
|
|
+ ocr: true
|
|
|
+ raw: false
|
|
|
+ normalize_numbers: true
|
|
|
+```
|
|
|
+
|
|
|
+#### `configs/bank_statement.yaml` (流水场景)
|
|
|
+
|
|
|
+```yaml
|
|
|
+# 银行流水场景配置
|
|
|
+scene_name: "bank_statement"
|
|
|
+description: "银行流水、对账单等场景"
|
|
|
+
|
|
|
+input:
|
|
|
+ supported_formats: [".pdf", ".png", ".jpg"]
|
|
|
+ dpi: 200
|
|
|
+
|
|
|
+preprocessor:
|
|
|
+ module: "mineru"
|
|
|
+ orientation_classifier:
|
|
|
+ enabled: true
|
|
|
+ model_name: "paddle_orientation_classification"
|
|
|
+ model_dir: null
|
|
|
+ unwarping:
|
|
|
+ enabled: false
|
|
|
+
|
|
|
+layout_detection:
|
|
|
+ module: "mineru"
|
|
|
+ model_name: "YOLOv10" # doclayout_yolo
|
|
|
+ model_dir: null
|
|
|
+ device: "cuda"
|
|
|
+ batch_size: 4
|
|
|
+ conf: 0.1
|
|
|
+ iou: 0.45
|
|
|
+ imgsz: 1280
|
|
|
+
|
|
|
+vl_recognition:
|
|
|
+ module: "mineru"
|
|
|
+ model_name: "MinerU-VLM"
|
|
|
+ backend: "vllm-server"
|
|
|
+ server_url: "http://10.192.72.11:8111/v1"
|
|
|
+ device: "cuda"
|
|
|
+ batch_size: 1
|
|
|
+
|
|
|
+output:
|
|
|
+ format: "mineru"
|
|
|
+ save_json: true
|
|
|
+ save_markdown: true
|
|
|
+ save_images:
|
|
|
+ layout: true
|
|
|
+ ocr: true
|
|
|
+ raw: true
|
|
|
+ normalize_numbers: false # 流水不需要数字标准化
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 2. 配置管理器
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/config_manager.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""配置管理器 - 加载和验证配置文件"""
|
|
|
+import yaml
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, Any, Optional
|
|
|
+from dataclasses import dataclass
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class SceneConfig:
|
|
|
+ """场景配置数据类"""
|
|
|
+ scene_name: str
|
|
|
+ description: str
|
|
|
+ input: Dict[str, Any]
|
|
|
+ preprocessor: Dict[str, Any]
|
|
|
+ layout_detection: Dict[str, Any]
|
|
|
+ vl_recognition: Dict[str, Any]
|
|
|
+ output: Dict[str, Any]
|
|
|
+
|
|
|
+class ConfigManager:
|
|
|
+ """配置管理器"""
|
|
|
+
|
|
|
+ def __init__(self, config_path: str):
|
|
|
+ """
|
|
|
+ 初始化配置管理器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: 配置文件路径
|
|
|
+ """
|
|
|
+ self.config_path = Path(config_path)
|
|
|
+ self.config = self._load_config()
|
|
|
+ self._validate_config()
|
|
|
+
|
|
|
+ def _load_config(self) -> Dict[str, Any]:
|
|
|
+ """加载YAML配置文件"""
|
|
|
+ if not self.config_path.exists():
|
|
|
+ raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
|
|
|
+
|
|
|
+ with open(self.config_path, 'r', encoding='utf-8') as f:
|
|
|
+ config = yaml.safe_load(f)
|
|
|
+
|
|
|
+ print(f"✅ 配置文件加载成功: {self.config_path}")
|
|
|
+ return config
|
|
|
+
|
|
|
+ def _validate_config(self):
|
|
|
+ """验证配置完整性"""
|
|
|
+ required_keys = [
|
|
|
+ 'scene_name', 'preprocessor', 'layout_detection',
|
|
|
+ 'vl_recognition', 'output'
|
|
|
+ ]
|
|
|
+
|
|
|
+ for key in required_keys:
|
|
|
+ if key not in self.config:
|
|
|
+ raise ValueError(f"配置文件缺少必需字段: {key}")
|
|
|
+
|
|
|
+ print(f"✅ 配置验证通过: {self.config['scene_name']}")
|
|
|
+
|
|
|
+ def get_scene_config(self) -> SceneConfig:
|
|
|
+ """获取场景配置对象"""
|
|
|
+ return SceneConfig(**self.config)
|
|
|
+
|
|
|
+ def get(self, key_path: str, default: Any = None) -> Any:
|
|
|
+ """
|
|
|
+ 获取嵌套配置值
|
|
|
+
|
|
|
+ Args:
|
|
|
+ key_path: 配置键路径,用'.'分隔,如 'layout_detection.model_name'
|
|
|
+ default: 默认值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 配置值
|
|
|
+ """
|
|
|
+ keys = key_path.split('.')
|
|
|
+ value = self.config
|
|
|
+
|
|
|
+ for key in keys:
|
|
|
+ if isinstance(value, dict) and key in value:
|
|
|
+ value = value[key]
|
|
|
+ else:
|
|
|
+ return default
|
|
|
+
|
|
|
+ return value
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 3. 模型工厂
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/model_factory.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""模型工厂 - 根据配置创建模型实例"""
|
|
|
+from typing import Any, Dict
|
|
|
+from .models.layout_detector import BaseLayoutDetector
|
|
|
+from .models.vl_recognizer import BaseVLRecognizer
|
|
|
+from .models.preprocessor import BasePreprocessor
|
|
|
+
|
|
|
+class ModelFactory:
|
|
|
+ """模型工厂类"""
|
|
|
+
|
|
|
+ # 注册的模型类
|
|
|
+ _layout_detectors = {}
|
|
|
+ _vl_recognizers = {}
|
|
|
+ _preprocessors = {}
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def register_layout_detector(cls, module_name: str, detector_class):
|
|
|
+ """注册版面检测模型"""
|
|
|
+ cls._layout_detectors[module_name] = detector_class
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def register_vl_recognizer(cls, module_name: str, recognizer_class):
|
|
|
+ """注册VLM识别模型"""
|
|
|
+ cls._vl_recognizers[module_name] = recognizer_class
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def register_preprocessor(cls, module_name: str, preprocessor_class):
|
|
|
+ """注册预处理器"""
|
|
|
+ cls._preprocessors[module_name] = preprocessor_class
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def create_layout_detector(cls, config: Dict[str, Any]) -> BaseLayoutDetector:
|
|
|
+ """
|
|
|
+ 创建版面检测器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config: 版面检测配置
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 版面检测器实例
|
|
|
+ """
|
|
|
+ module = config.get('module', 'paddlex')
|
|
|
+
|
|
|
+ if module not in cls._layout_detectors:
|
|
|
+ raise ValueError(f"未注册的版面检测模块: {module}")
|
|
|
+
|
|
|
+ detector_class = cls._layout_detectors[module]
|
|
|
+ return detector_class(config)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def create_vl_recognizer(cls, config: Dict[str, Any]) -> BaseVLRecognizer:
|
|
|
+ """创建VLM识别器"""
|
|
|
+ module = config.get('module', 'paddlex')
|
|
|
+
|
|
|
+ if module not in cls._vl_recognizers:
|
|
|
+ raise ValueError(f"未注册的VLM识别模块: {module}")
|
|
|
+
|
|
|
+ recognizer_class = cls._vl_recognizers[module]
|
|
|
+ return recognizer_class(config)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def create_preprocessor(cls, config: Dict[str, Any]) -> BasePreprocessor:
|
|
|
+ """创建预处理器"""
|
|
|
+ module = config.get('module', 'paddlex')
|
|
|
+
|
|
|
+ if module not in cls._preprocessors:
|
|
|
+ raise ValueError(f"未注册的预处理模块: {module}")
|
|
|
+
|
|
|
+ preprocessor_class = cls._preprocessors[module]
|
|
|
+ return preprocessor_class(config)
|
|
|
+
|
|
|
+
|
|
|
+# 自动注册所有适配器
|
|
|
+def auto_register_adapters():
|
|
|
+ """自动注册所有适配器"""
|
|
|
+ from .models.adapters.paddlex_adapter import (
|
|
|
+ PaddleXLayoutDetector,
|
|
|
+ PaddleXVLRecognizer,
|
|
|
+ PaddleXPreprocessor
|
|
|
+ )
|
|
|
+ from .models.adapters.mineru_adapter import (
|
|
|
+ MinerULayoutDetector,
|
|
|
+ MinerUVLRecognizer,
|
|
|
+ MinerUPreprocessor
|
|
|
+ )
|
|
|
+
|
|
|
+ # 注册 PaddleX 适配器
|
|
|
+ ModelFactory.register_layout_detector('paddlex', PaddleXLayoutDetector)
|
|
|
+ ModelFactory.register_vl_recognizer('paddlex', PaddleXVLRecognizer)
|
|
|
+ ModelFactory.register_preprocessor('paddlex', PaddleXPreprocessor)
|
|
|
+
|
|
|
+ # 注册 MinerU 适配器
|
|
|
+ ModelFactory.register_layout_detector('mineru', MinerULayoutDetector)
|
|
|
+ ModelFactory.register_vl_recognizer('mineru', MinerUVLRecognizer)
|
|
|
+ ModelFactory.register_preprocessor('mineru', MinerUPreprocessor)
|
|
|
+
|
|
|
+ print("✅ 所有模型适配器已注册")
|
|
|
+
|
|
|
+# 模块导入时自动注册
|
|
|
+auto_register_adapters()
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 4. 抽象基类
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/models/layout_detector.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""版面检测抽象基类"""
|
|
|
+from abc import ABC, abstractmethod
|
|
|
+from typing import List, Dict, Any
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+class BaseLayoutDetector(ABC):
|
|
|
+ """版面检测器基类"""
|
|
|
+
|
|
|
+ def __init__(self, config: Dict[str, Any]):
|
|
|
+ """
|
|
|
+ 初始化版面检测器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config: 版面检测配置
|
|
|
+ """
|
|
|
+ self.config = config
|
|
|
+ self.model_name = config.get('model_name')
|
|
|
+ self.device = config.get('device', 'cpu')
|
|
|
+ self.batch_size = config.get('batch_size', 1)
|
|
|
+
|
|
|
+ self._init_model()
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化模型 - 子类实现"""
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 检测单张图片
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 输入图片 (H, W, C)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 检测结果列表,每个元素包含:
|
|
|
+ - category_id: 类别ID
|
|
|
+ - label: 类别标签
|
|
|
+ - bbox: 边界框 [x1, y1, x2, y2]
|
|
|
+ - score: 置信度
|
|
|
+ """
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def batch_detect(self, images: List[np.ndarray]) -> List[List[Dict[str, Any]]]:
|
|
|
+ """批量检测"""
|
|
|
+ pass
|
|
|
+
|
|
|
+ def visualize(self, image: np.ndarray, results: List[Dict]) -> np.ndarray:
|
|
|
+ """可视化检测结果"""
|
|
|
+ import cv2
|
|
|
+
|
|
|
+ vis_img = image.copy()
|
|
|
+
|
|
|
+ for result in results:
|
|
|
+ bbox = result['bbox']
|
|
|
+ label = result.get('label', 'unknown')
|
|
|
+ score = result.get('score', 0.0)
|
|
|
+
|
|
|
+ x1, y1, x2, y2 = map(int, bbox)
|
|
|
+
|
|
|
+ # 绘制边界框
|
|
|
+ cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
|
+
|
|
|
+ # 绘制标签
|
|
|
+ text = f"{label} {score:.2f}"
|
|
|
+ cv2.putText(vis_img, text, (x1, y1-5),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
|
|
|
+
|
|
|
+ return vis_img
|
|
|
+```
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/models/vl_recognizer.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""VLM识别抽象基类"""
|
|
|
+from abc import ABC, abstractmethod
|
|
|
+from typing import List, Dict, Any
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+class BaseVLRecognizer(ABC):
|
|
|
+ """VLM识别器基类"""
|
|
|
+
|
|
|
+ def __init__(self, config: Dict[str, Any]):
|
|
|
+ """初始化VLM识别器"""
|
|
|
+ self.config = config
|
|
|
+ self.model_name = config.get('model_name')
|
|
|
+ self.backend = config.get('backend', 'local')
|
|
|
+ self.server_url = config.get('server_url')
|
|
|
+
|
|
|
+ self._init_model()
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化模型"""
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def recognize_region(self, image: np.ndarray, region: Dict[str, Any]) -> str:
|
|
|
+ """
|
|
|
+ 识别单个区域
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 完整图片
|
|
|
+ region: 区域信息 (包含bbox和label)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 识别文本
|
|
|
+ """
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abstractmethod
|
|
|
+ def recognize_table(self, image: np.ndarray, region: Dict[str, Any]) -> str:
|
|
|
+ """
|
|
|
+ 识别表格区域
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 完整图片
|
|
|
+ region: 表格区域
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ HTML格式表格
|
|
|
+ """
|
|
|
+ pass
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 5. PaddleX适配器
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/models/adapters/paddlex_adapter.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""PaddleX模型适配器"""
|
|
|
+from ..layout_detector import BaseLayoutDetector
|
|
|
+from ..vl_recognizer import BaseVLRecognizer
|
|
|
+from ..preprocessor import BasePreprocessor
|
|
|
+from typing import List, Dict, Any
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+class PaddleXLayoutDetector(BaseLayoutDetector):
|
|
|
+ """PaddleX版面检测适配器"""
|
|
|
+
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化PaddleX版面检测模型"""
|
|
|
+ from paddlex.inference.models import create_predictor
|
|
|
+
|
|
|
+ self.model = create_predictor(
|
|
|
+ model_name=self.model_name,
|
|
|
+ model_dir=self.config.get('model_dir'),
|
|
|
+ device=self.device
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"✅ PaddleX版面检测模型已加载: {self.model_name}")
|
|
|
+
|
|
|
+ def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
|
+ """检测单张图片"""
|
|
|
+ result = list(self.model(image))[0]
|
|
|
+
|
|
|
+ # 转换为统一格式
|
|
|
+ regions = []
|
|
|
+ for box in result.get('boxes', []):
|
|
|
+ regions.append({
|
|
|
+ 'category_id': box.get('cls_id'),
|
|
|
+ 'label': box.get('label'),
|
|
|
+ 'bbox': box.get('coordinate'), # [x1, y1, x2, y2]
|
|
|
+ 'score': box.get('score')
|
|
|
+ })
|
|
|
+
|
|
|
+ return regions
|
|
|
+
|
|
|
+ def batch_detect(self, images: List[np.ndarray]) -> List[List[Dict[str, Any]]]:
|
|
|
+ """批量检测"""
|
|
|
+ results = []
|
|
|
+ for image in images:
|
|
|
+ results.append(self.detect(image))
|
|
|
+ return results
|
|
|
+
|
|
|
+
|
|
|
+class PaddleXVLRecognizer(BaseVLRecognizer):
|
|
|
+ """PaddleX VLM识别适配器"""
|
|
|
+
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化PaddleX VLM模型"""
|
|
|
+ if self.backend == 'vllm-server':
|
|
|
+ # 使用远程服务
|
|
|
+ import requests
|
|
|
+ self.session = requests.Session()
|
|
|
+ print(f"✅ PaddleX VLM连接到服务器: {self.server_url}")
|
|
|
+ else:
|
|
|
+ # 本地模型
|
|
|
+ from paddlex.inference.models import create_predictor
|
|
|
+ self.model = create_predictor(
|
|
|
+ model_name=self.model_name,
|
|
|
+ device=self.config.get('device', 'cpu')
|
|
|
+ )
|
|
|
+ print(f"✅ PaddleX VLM本地模型已加载")
|
|
|
+
|
|
|
+ def recognize_region(self, image: np.ndarray, region: Dict[str, Any]) -> str:
|
|
|
+ """识别单个区域"""
|
|
|
+ # 裁剪区域
|
|
|
+ bbox = region['bbox']
|
|
|
+ x1, y1, x2, y2 = map(int, bbox)
|
|
|
+ cropped = image[y1:y2, x1:x2]
|
|
|
+
|
|
|
+ if self.backend == 'vllm-server':
|
|
|
+ # 调用远程API
|
|
|
+ from PIL import Image
|
|
|
+ import base64
|
|
|
+ from io import BytesIO
|
|
|
+
|
|
|
+ pil_img = Image.fromarray(cropped)
|
|
|
+ buffered = BytesIO()
|
|
|
+ pil_img.save(buffered, format="PNG")
|
|
|
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": self.model_name,
|
|
|
+ "messages": [{
|
|
|
+ "role": "user",
|
|
|
+ "content": [
|
|
|
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}},
|
|
|
+ {"type": "text", "text": "识别图片中的所有文字"}
|
|
|
+ ]
|
|
|
+ }]
|
|
|
+ }
|
|
|
+
|
|
|
+ response = self.session.post(
|
|
|
+ f"{self.server_url}/chat/completions",
|
|
|
+ json=payload,
|
|
|
+ timeout=30
|
|
|
+ )
|
|
|
+
|
|
|
+ if response.status_code == 200:
|
|
|
+ return response.json()['choices'][0]['message']['content'].strip()
|
|
|
+
|
|
|
+ return ""
|
|
|
+
|
|
|
+ def recognize_table(self, image: np.ndarray, region: Dict[str, Any]) -> str:
|
|
|
+ """识别表格"""
|
|
|
+ # 类似实现,使用表格专用提示词
|
|
|
+ return "<table></table>" # 简化示例
|
|
|
+
|
|
|
+
|
|
|
+class PaddleXPreprocessor(BasePreprocessor):
|
|
|
+ """PaddleX预处理适配器"""
|
|
|
+
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化预处理模型"""
|
|
|
+ from paddlex import create_pipeline
|
|
|
+
|
|
|
+ self.pipeline = create_pipeline(
|
|
|
+ "doc_preprocessor",
|
|
|
+ device=self.config.get('device', 'cpu')
|
|
|
+ )
|
|
|
+
|
|
|
+ print("✅ PaddleX预处理管线已加载")
|
|
|
+
|
|
|
+ def preprocess(self, image: np.ndarray) -> np.ndarray:
|
|
|
+ """预处理单张图片"""
|
|
|
+ result = list(self.pipeline(image))[0]
|
|
|
+ return result['output_img']
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 6. MinerU适配器
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/models/adapters/mineru_adapter.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""MinerU模型适配器"""
|
|
|
+from ..layout_detector import BaseLayoutDetector
|
|
|
+from ..vl_recognizer import BaseVLRecognizer
|
|
|
+from ..preprocessor import BasePreprocessor
|
|
|
+from typing import List, Dict, Any
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+class MinerULayoutDetector(BaseLayoutDetector):
|
|
|
+ """MinerU版面检测适配器"""
|
|
|
+
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化MinerU版面检测模型"""
|
|
|
+ from mineru.model.layout.doclayoutyolo import DocLayoutYOLOModel
|
|
|
+ from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
+ from mineru.utils.enum_class import ModelPath
|
|
|
+ import os
|
|
|
+
|
|
|
+ weight_path = os.path.join(
|
|
|
+ auto_download_and_get_model_root_path(ModelPath.doclayout_yolo),
|
|
|
+ ModelPath.doclayout_yolo
|
|
|
+ )
|
|
|
+
|
|
|
+ self.model = DocLayoutYOLOModel(
|
|
|
+ weight=weight_path,
|
|
|
+ device=self.device,
|
|
|
+ imgsz=self.config.get('imgsz', 1280),
|
|
|
+ conf=self.config.get('conf', 0.1),
|
|
|
+ iou=self.config.get('iou', 0.45)
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"✅ MinerU版面检测模型已加载: YOLOv10")
|
|
|
+
|
|
|
+ def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
|
+ """检测单张图片"""
|
|
|
+ from PIL import Image
|
|
|
+
|
|
|
+ pil_image = Image.fromarray(image)
|
|
|
+ results = self.model.predict(pil_image)
|
|
|
+
|
|
|
+ # 转换为统一格式
|
|
|
+ regions = []
|
|
|
+ for res in results:
|
|
|
+ poly = res['poly']
|
|
|
+ regions.append({
|
|
|
+ 'category_id': res['category_id'],
|
|
|
+ 'label': f"cat_{res['category_id']}",
|
|
|
+ 'bbox': [poly[0], poly[1], poly[4], poly[5]],
|
|
|
+ 'score': res['score']
|
|
|
+ })
|
|
|
+
|
|
|
+ return regions
|
|
|
+
|
|
|
+ def batch_detect(self, images: List[np.ndarray]) -> List[List[Dict[str, Any]]]:
|
|
|
+ """批量检测"""
|
|
|
+ from PIL import Image
|
|
|
+
|
|
|
+ pil_images = [Image.fromarray(img) for img in images]
|
|
|
+ batch_results = self.model.batch_predict(pil_images, self.batch_size)
|
|
|
+
|
|
|
+ all_regions = []
|
|
|
+ for results in batch_results:
|
|
|
+ regions = []
|
|
|
+ for res in results:
|
|
|
+ poly = res['poly']
|
|
|
+ regions.append({
|
|
|
+ 'category_id': res['category_id'],
|
|
|
+ 'label': f"cat_{res['category_id']}",
|
|
|
+ 'bbox': [poly[0], poly[1], poly[4], poly[5]],
|
|
|
+ 'score': res['score']
|
|
|
+ })
|
|
|
+ all_regions.append(regions)
|
|
|
+
|
|
|
+ return all_regions
|
|
|
+
|
|
|
+
|
|
|
+class MinerUVLRecognizer(BaseVLRecognizer):
|
|
|
+ """MinerU VLM识别适配器"""
|
|
|
+
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化MinerU VLM模型"""
|
|
|
+ # MinerU使用HTTP客户端连接vLLM服务
|
|
|
+ import requests
|
|
|
+ self.session = requests.Session()
|
|
|
+ print(f"✅ MinerU VLM连接到服务器: {self.server_url}")
|
|
|
+
|
|
|
+ def recognize_region(self, image: np.ndarray, region: Dict[str, Any]) -> str:
|
|
|
+ """识别单个区域"""
|
|
|
+ # 实现类似PaddleX的远程调用
|
|
|
+ return ""
|
|
|
+
|
|
|
+ def recognize_table(self, image: np.ndarray, region: Dict[str, Any]) -> str:
|
|
|
+ """识别表格"""
|
|
|
+ return "<table></table>"
|
|
|
+
|
|
|
+
|
|
|
+class MinerUPreprocessor(BasePreprocessor):
|
|
|
+ """MinerU预处理适配器"""
|
|
|
+
|
|
|
+ def _init_model(self):
|
|
|
+ """初始化MinerU预处理模型"""
|
|
|
+ from mineru.backend.pipeline.model_init import AtomModelSingleton, AtomicModel
|
|
|
+
|
|
|
+ self.model_manager = AtomModelSingleton()
|
|
|
+
|
|
|
+ if self.config.get('orientation_classifier', {}).get('enabled'):
|
|
|
+ self.ori_model = self.model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.ImgOrientationCls
|
|
|
+ )
|
|
|
+
|
|
|
+ print("✅ MinerU预处理模型已加载")
|
|
|
+
|
|
|
+ def preprocess(self, image: np.ndarray) -> np.ndarray:
|
|
|
+ """预处理单张图片"""
|
|
|
+ if hasattr(self, 'ori_model'):
|
|
|
+ # 旋转校正
|
|
|
+ import cv2
|
|
|
+ from PIL import Image
|
|
|
+
|
|
|
+ pil_img = Image.fromarray(image)
|
|
|
+ rotate_label = self.ori_model.predict(image)
|
|
|
+
|
|
|
+ if rotate_label != 0:
|
|
|
+ image = np.array(pil_img.rotate(rotate_label * 90, expand=True))
|
|
|
+
|
|
|
+ return image
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 7. 主解析器
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/parser.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""通用文档解析器"""
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Dict, Any
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
+from .config_manager import ConfigManager
|
|
|
+from .model_factory import ModelFactory
|
|
|
+from .output_formatter import OutputFormatter
|
|
|
+
|
|
|
+class UniversalDocParser:
|
|
|
+ """通用文档解析器"""
|
|
|
+
|
|
|
+ def __init__(self, config_path: str):
|
|
|
+ """
|
|
|
+ 初始化解析器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: 配置文件路径
|
|
|
+ """
|
|
|
+ # 加载配置
|
|
|
+ self.config_manager = ConfigManager(config_path)
|
|
|
+ self.config = self.config_manager.get_scene_config()
|
|
|
+
|
|
|
+ # 创建模型实例
|
|
|
+ self.preprocessor = ModelFactory.create_preprocessor(
|
|
|
+ self.config.preprocessor
|
|
|
+ )
|
|
|
+ self.layout_detector = ModelFactory.create_layout_detector(
|
|
|
+ self.config.layout_detection
|
|
|
+ )
|
|
|
+ self.vl_recognizer = ModelFactory.create_vl_recognizer(
|
|
|
+ self.config.vl_recognition
|
|
|
+ )
|
|
|
+
|
|
|
+ # 输出格式化器
|
|
|
+ self.output_formatter = OutputFormatter(self.config.output)
|
|
|
+
|
|
|
+ print(f"🚀 解析器初始化完成: {self.config.scene_name}")
|
|
|
+
|
|
|
+ def parse(self, input_path: str, output_dir: str) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 解析文档
|
|
|
+
|
|
|
+ Args:
|
|
|
+ input_path: 输入文件路径
|
|
|
+ output_dir: 输出目录
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 解析结果字典
|
|
|
+ """
|
|
|
+ print(f"📄 开始解析: {input_path}")
|
|
|
+
|
|
|
+ # 1. 读取图片
|
|
|
+ image = self._read_image(input_path)
|
|
|
+
|
|
|
+ # 2. 预处理
|
|
|
+ print("🔄 步骤1: 文档预处理...")
|
|
|
+ preprocessed_image = self.preprocessor.preprocess(image)
|
|
|
+
|
|
|
+ # 3. 版面检测
|
|
|
+ print("📍 步骤2: 版面检测...")
|
|
|
+ layout_results = self.layout_detector.detect(preprocessed_image)
|
|
|
+ print(f" 检测到 {len(layout_results)} 个区域")
|
|
|
+
|
|
|
+ # 4. VLM识别
|
|
|
+ print("🔍 步骤3: VLM识别...")
|
|
|
+ recognized_results = []
|
|
|
+
|
|
|
+ for region in layout_results:
|
|
|
+ if region['label'] == 'table':
|
|
|
+ # 表格识别
|
|
|
+ html = self.vl_recognizer.recognize_table(preprocessed_image, region)
|
|
|
+ recognized_results.append({
|
|
|
+ **region,
|
|
|
+ 'content': html,
|
|
|
+ 'type': 'html'
|
|
|
+ })
|
|
|
+ elif region['label'] in ['text', 'title']:
|
|
|
+ # 文本识别
|
|
|
+ text = self.vl_recognizer.recognize_region(preprocessed_image, region)
|
|
|
+ recognized_results.append({
|
|
|
+ **region,
|
|
|
+ 'content': text,
|
|
|
+ 'type': 'text'
|
|
|
+ })
|
|
|
+ else:
|
|
|
+ recognized_results.append(region)
|
|
|
+
|
|
|
+ # 5. 输出结果
|
|
|
+ print("💾 步骤4: 保存结果...")
|
|
|
+ result = {
|
|
|
+ 'input_path': input_path,
|
|
|
+ 'scene': self.config.scene_name,
|
|
|
+ 'layout_results': layout_results,
|
|
|
+ 'recognized_results': recognized_results
|
|
|
+ }
|
|
|
+
|
|
|
+ self.output_formatter.save(
|
|
|
+ result,
|
|
|
+ output_dir,
|
|
|
+ Path(input_path).stem,
|
|
|
+ preprocessed_image,
|
|
|
+ layout_results
|
|
|
+ )
|
|
|
+
|
|
|
+ print("✅ 解析完成!")
|
|
|
+ return result
|
|
|
+
|
|
|
+ def _read_image(self, path: str) -> np.ndarray:
|
|
|
+ """读取图片"""
|
|
|
+ img = Image.open(path).convert('RGB')
|
|
|
+ return np.array(img)
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+### 8. 输出格式化器
|
|
|
+
|
|
|
+创建 `zhch/universal_doc_parser/output_formatter.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""输出格式化器"""
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, Any, List
|
|
|
+import json
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+class OutputFormatter:
|
|
|
+ """输出格式化器"""
|
|
|
+
|
|
|
+ def __init__(self, output_config: Dict[str, Any]):
|
|
|
+ """初始化"""
|
|
|
+ self.config = output_config
|
|
|
+ self.format_type = output_config.get('format', 'mineru')
|
|
|
+
|
|
|
+ def save(self, result: Dict, output_dir: str, base_name: str,
|
|
|
+ image: np.ndarray, layout_results: List[Dict]):
|
|
|
+ """
|
|
|
+ 保存结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ result: 解析结果
|
|
|
+ output_dir: 输出目录
|
|
|
+ base_name: 文件基础名
|
|
|
+ image: 预处理后的图片
|
|
|
+ layout_results: 版面检测结果
|
|
|
+ """
|
|
|
+ output_path = Path(output_dir)
|
|
|
+ output_path.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ # 1. 保存JSON
|
|
|
+ if self.config.get('save_json'):
|
|
|
+ json_path = output_path / f"{base_name}.json"
|
|
|
+ with open(json_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(result, f, ensure_ascii=False, indent=2)
|
|
|
+ print(f" ✅ JSON: {json_path}")
|
|
|
+
|
|
|
+ # 2. 保存Markdown
|
|
|
+ if self.config.get('save_markdown'):
|
|
|
+ md_path = output_path / f"{base_name}.md"
|
|
|
+ markdown_content = self._to_markdown(result)
|
|
|
+ with open(md_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(markdown_content)
|
|
|
+ print(f" ✅ Markdown: {md_path}")
|
|
|
+
|
|
|
+ # 3. 保存可视化图片
|
|
|
+ if self.config.get('save_images', {}).get('layout'):
|
|
|
+ import cv2
|
|
|
+ from ..models.layout_detector import BaseLayoutDetector
|
|
|
+
|
|
|
+ vis_img = BaseLayoutDetector.visualize(None, image, layout_results)
|
|
|
+ layout_img_path = output_path / f"{base_name}_layout.jpg"
|
|
|
+ cv2.imwrite(str(layout_img_path), cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR))
|
|
|
+ print(f" ✅ 版面图: {layout_img_path}")
|
|
|
+
|
|
|
+ def _to_markdown(self, result: Dict) -> str:
|
|
|
+ """转换为Markdown格式"""
|
|
|
+ blocks = []
|
|
|
+
|
|
|
+ for item in result.get('recognized_results', []):
|
|
|
+ if item.get('type') == 'text':
|
|
|
+ blocks.append(item['content'])
|
|
|
+ elif item.get('type') == 'html':
|
|
|
+ blocks.append(f'<div style="text-align: center;">{item["content"]}</div>')
|
|
|
+
|
|
|
+ return '\n\n'.join(blocks)
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 🎮 使用示例
|
|
|
+
|
|
|
+### 主程序
|
|
|
+
|
|
|
+创建 `zhch/main.py`:
|
|
|
+
|
|
|
+```python
|
|
|
+"""主程序入口"""
|
|
|
+import argparse
|
|
|
+from universal_doc_parser.parser import UniversalDocParser
|
|
|
+
|
|
|
+def main():
|
|
|
+ parser = argparse.ArgumentParser(description="通用文档解析器")
|
|
|
+ parser.add_argument('--config', '-c', required=True, help='配置文件路径')
|
|
|
+ parser.add_argument('--input', '-i', required=True, help='输入文件路径')
|
|
|
+ parser.add_argument('--output', '-o', default='./output', help='输出目录')
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ # 创建解析器
|
|
|
+ doc_parser = UniversalDocParser(args.config)
|
|
|
+
|
|
|
+ # 执行解析
|
|
|
+ result = doc_parser.parse(args.input, args.output)
|
|
|
+
|
|
|
+ print("\n🎉 处理完成!")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|
|
|
+```
|
|
|
+
|
|
|
+### 运行命令
|
|
|
+
|
|
|
+```bash
|
|
|
+# 财报场景
|
|
|
+python zhch/main.py \
|
|
|
+ --config zhch/configs/financial_report.yaml \
|
|
|
+ --input "/path/to/annual_report.pdf" \
|
|
|
+ --output "./output/financial"
|
|
|
+
|
|
|
+# 流水场景
|
|
|
+python zhch/main.py \
|
|
|
+ --config zhch/configs/bank_statement.yaml \
|
|
|
+ --input "/path/to/bank_statement.png" \
|
|
|
+ --output "./output/statement"
|
|
|
+```
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+## 📊 架构优势
|
|
|
+
|
|
|
+| 特性 | 说明 |
|
|
|
+|------|------|
|
|
|
+| ✅ **灵活配置** | YAML配置文件,无需修改代码 |
|
|
|
+| ✅ **模型可插拔** | 支持任意模型组合 |
|
|
|
+| ✅ **统一接口** | 抽象基类确保一致性 |
|
|
|
+| ✅ **多场景支持** | 一套代码适配多种业务 |
|
|
|
+| ✅ **易于扩展** | 只需实现适配器即可添加新模型 |
|
|
|
+
|
|
|
+这个架构完全满足您的需求,并且具有很强的可扩展性!🎯
|