Browse Source

feat: 添加PaddleOCR表格分类器适配器,支持有线/无线表格分类

zhch158_admin 1 day ago
parent
commit
0b7809226c

+ 192 - 0
ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py

@@ -0,0 +1,192 @@
+# 文件路径: /Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py
+
+"""
+PaddleOCR表格分类适配器
+
+适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
+"""
+import sys
+from typing import Dict, Any, Union
+from pathlib import Path
+import numpy as np
+from PIL import Image
+from loguru import logger
+
+from .base import BaseAdapter
+
+# # 确保 MinerU 库可导入
+# mineru_root = Path(__file__).parents[5] / "MinerU"
+# if str(mineru_root) not in sys.path:
+#     sys.path.insert(0, str(mineru_root))
+
+try:
+    from mineru.model.table.cls.paddle_table_cls import PaddleTableClsModel
+    from mineru.backend.pipeline.model_list import AtomicModel
+    MINERU_TABLE_CLS_AVAILABLE = True
+except ImportError as e:
+    logger.warning(f"MinerU table classifier not available: {e}")
+    MINERU_TABLE_CLS_AVAILABLE = False
+    PaddleTableClsModel = None
+    AtomicModel = None
+
+class PaddleTableClassifier(BaseAdapter):
+    """
+    PaddleOCR表格分类器适配器
+    
+    用于将表格图像分类为:
+    - wired_table: 有线表格(带边框)
+    - wireless_table: 无线表格(无边框)
+    """
+    
+    def __init__(self, config: Dict[str, Any]):
+        """
+        初始化表格分类器
+        
+        Args:
+            config: 配置字典,支持以下参数:
+                - confidence_threshold: 置信度阈值(默认 0.5)
+                - batch_size: 批处理大小(默认 16)
+        """
+        super().__init__(config)
+        self.model = None
+        self.confidence_threshold = config.get('confidence_threshold', 0.5)
+        self.batch_size = config.get('batch_size', 16)
+        
+    def initialize(self):
+        """初始化模型"""
+        if not MINERU_TABLE_CLS_AVAILABLE:
+            raise RuntimeError("MinerU table classifier not available")
+        
+        try:
+            self.model = PaddleTableClsModel()
+            logger.info("✅ PaddleTableClsModel initialized successfully")
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize PaddleTableClsModel: {e}")
+            raise
+    
+    def cleanup(self):
+        """清理资源"""
+        if self.model:
+            del self.model
+            self.model = None
+            logger.info("✅ PaddleTableClsModel cleaned up")
+    
+    def classify(
+        self, 
+        image: Union[np.ndarray, Image.Image]
+    ) -> Dict[str, Any]:
+        """
+        分类单个表格图像
+        
+        Args:
+            image: 表格图像(numpy数组或PIL图像)
+            
+        Returns:
+            分类结果字典:
+            {
+                'table_type': 'wired' | 'wireless',
+                'confidence': float,
+                'raw_label': str  # AtomicModel.WiredTable 或 AtomicModel.WirelessTable
+            }
+        """
+        if self.model is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        try:
+            # 调用 MinerU 的预测接口
+            label, confidence = self.model.predict(image)
+            
+            # 转换标签为简化形式
+            if AtomicModel and label == AtomicModel.WiredTable:
+                table_type = 'wired'
+            elif AtomicModel and label == AtomicModel.WirelessTable:
+                table_type = 'wireless'
+            else:
+                # 兜底:基于字符串判断
+                table_type = 'wired' if 'wired' in str(label).lower() else 'wireless'
+            
+            result = {
+                'table_type': table_type,
+                'confidence': float(confidence),
+                'raw_label': str(label)
+            }
+            
+            logger.debug(f"Table classified as '{table_type}' (confidence: {confidence:.3f})")
+            return result
+            
+        except Exception as e:
+            logger.error(f"Table classification failed: {e}")
+            # 降级:返回默认值
+            return {
+                'table_type': 'wireless',  # 默认使用无线表格(更通用)
+                'confidence': 0.0,
+                'raw_label': 'unknown',
+                'error': str(e)
+            }
+    
+    def batch_classify(
+        self, 
+        images: list[Union[np.ndarray, Image.Image]]
+    ) -> list[Dict[str, Any]]:
+        """
+        批量分类表格图像
+        
+        Args:
+            images: 表格图像列表
+            
+        Returns:
+            分类结果列表
+        """
+        if self.model is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        if not images:
+            return []
+        
+        try:
+            # 构造 MinerU 期望的输入格式
+            img_info_list = []
+            for i, img in enumerate(images):
+                img_info_list.append({
+                    'wired_table_img': img,
+                    'table_res': {}  # MinerU 会在这里填充结果
+                })
+            
+            # 调用批量预测
+            self.model.batch_predict(img_info_list, batch_size=self.batch_size)
+            
+            # 提取结果
+            results = []
+            for img_info in img_info_list:
+                table_res = img_info['table_res']
+                label = table_res.get('cls_label', '')
+                confidence = table_res.get('cls_score', 0.0)
+                
+                # 转换标签
+                if AtomicModel and label == AtomicModel.WiredTable:
+                    table_type = 'wired'
+                elif AtomicModel and label == AtomicModel.WirelessTable:
+                    table_type = 'wireless'
+                else:
+                    table_type = 'wired' if 'wired' in str(label).lower() else 'wireless'
+                
+                results.append({
+                    'table_type': table_type,
+                    'confidence': float(confidence),
+                    'raw_label': str(label)
+                })
+            
+            return results
+            
+        except Exception as e:
+            logger.error(f"Batch table classification failed: {e}")
+            # 降级:返回默认值
+            return [
+                {
+                    'table_type': 'wireless',
+                    'confidence': 0.0,
+                    'raw_label': 'unknown',
+                    'error': str(e)
+                }
+                for _ in images
+            ]