|
|
@@ -0,0 +1,192 @@
|
|
|
+# 文件路径: /Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py
|
|
|
+
|
|
|
+"""
|
|
|
+PaddleOCR表格分类适配器
|
|
|
+
|
|
|
+适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
|
|
|
+"""
|
|
|
+import sys
|
|
|
+from typing import Dict, Any, Union
|
|
|
+from pathlib import Path
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+from loguru import logger
|
|
|
+
|
|
|
+from .base import BaseAdapter
|
|
|
+
|
|
|
+# # 确保 MinerU 库可导入
|
|
|
+# mineru_root = Path(__file__).parents[5] / "MinerU"
|
|
|
+# if str(mineru_root) not in sys.path:
|
|
|
+# sys.path.insert(0, str(mineru_root))
|
|
|
+
|
|
|
+try:
|
|
|
+ from mineru.model.table.cls.paddle_table_cls import PaddleTableClsModel
|
|
|
+ from mineru.backend.pipeline.model_list import AtomicModel
|
|
|
+ MINERU_TABLE_CLS_AVAILABLE = True
|
|
|
+except ImportError as e:
|
|
|
+ logger.warning(f"MinerU table classifier not available: {e}")
|
|
|
+ MINERU_TABLE_CLS_AVAILABLE = False
|
|
|
+ PaddleTableClsModel = None
|
|
|
+ AtomicModel = None
|
|
|
+
|
|
|
class PaddleTableClassifier(BaseAdapter):
    """Adapter around MinerU's ``PaddleTableClsModel`` table classifier.

    Maps the model's label onto a simplified table type:

    - ``'wired'``: the label equals ``AtomicModel.WiredTable``
      (table with visible borders).
    - ``'wireless'``: the label equals ``AtomicModel.WirelessTable``
      (borderless table). This is also the default for unrecognised
      labels and for failed predictions.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration; the model itself is loaded by ``initialize``.

        Args:
            config: Configuration dict. Keys read here:
                - ``confidence_threshold``: confidence threshold (default 0.5).
                  NOTE(review): stored on the instance but not read by any
                  method of this class.
                - ``batch_size``: batch size forwarded to the model's
                  ``batch_predict`` (default 16).
        """
        super().__init__(config)
        self.model = None
        self.confidence_threshold = config.get('confidence_threshold', 0.5)
        self.batch_size = config.get('batch_size', 16)

    def initialize(self):
        """Instantiate the underlying ``PaddleTableClsModel``.

        Raises:
            RuntimeError: If the MinerU classifier could not be imported.
            Exception: Any error raised while constructing the model is
                logged and re-raised unchanged.
        """
        if not MINERU_TABLE_CLS_AVAILABLE:
            raise RuntimeError("MinerU table classifier not available")

        try:
            self.model = PaddleTableClsModel()
            logger.info("✅ PaddleTableClsModel initialized successfully")
        except Exception as e:
            logger.error(f"❌ Failed to initialize PaddleTableClsModel: {e}")
            raise

    def cleanup(self):
        """Drop the reference to the model so it can be garbage-collected."""
        # Identity check: a loaded model counts as present even if the object
        # happens to evaluate as falsy.
        if self.model is not None:
            self.model = None
            logger.info("✅ PaddleTableClsModel cleaned up")

    @staticmethod
    def _to_table_type(label: Any) -> str:
        """Translate a raw model label into ``'wired'`` or ``'wireless'``."""
        if AtomicModel and label == AtomicModel.WiredTable:
            return 'wired'
        if AtomicModel and label == AtomicModel.WirelessTable:
            return 'wireless'
        # Fallback when the label matches neither constant: decide by substring.
        # 'wired' is not a substring of 'wireless', so the check is unambiguous.
        return 'wired' if 'wired' in str(label).lower() else 'wireless'

    @staticmethod
    def _build_result(label: Any, confidence: Any) -> Dict[str, Any]:
        """Build the result dict shared by single and batch classification."""
        return {
            'table_type': PaddleTableClassifier._to_table_type(label),
            'confidence': float(confidence),
            'raw_label': str(label),
        }

    @staticmethod
    def _degraded_result(error: Exception) -> Dict[str, Any]:
        """Build the default result returned when prediction fails."""
        return {
            'table_type': 'wireless',  # default to wireless (more general)
            'confidence': 0.0,
            'raw_label': 'unknown',
            'error': str(error),
        }

    def classify(
        self,
        image: Union[np.ndarray, Image.Image]
    ) -> Dict[str, Any]:
        """Classify a single table image.

        Args:
            image: Table image as a numpy array or PIL image.

        Returns:
            A dict of the form::

                {
                    'table_type': 'wired' | 'wireless',
                    'confidence': float,
                    'raw_label': str,  # str() of the label the model returned
                }

            If prediction raises, a degraded result is returned instead:
            ``table_type='wireless'``, ``confidence=0.0``,
            ``raw_label='unknown'`` plus an ``'error'`` key holding the
            exception message.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")

        try:
            # Delegate to MinerU's single-image prediction interface.
            label, confidence = self.model.predict(image)
            result = self._build_result(label, confidence)

            # Log the already-converted float so that formatting cannot turn
            # a successful prediction into the degraded path below.
            table_type = result['table_type']
            score = result['confidence']
            logger.debug(f"Table classified as '{table_type}' (confidence: {score:.3f})")
            return result

        except Exception as e:
            logger.error(f"Table classification failed: {e}")
            # Best-effort degradation: return the default instead of raising.
            return self._degraded_result(e)

    def batch_classify(
        self,
        images: list[Union[np.ndarray, Image.Image]]
    ) -> list[Dict[str, Any]]:
        """Classify a list of table images in batches.

        Args:
            images: Table images (numpy arrays or PIL images).

        Returns:
            One result dict per input image, in input order, with the same
            keys as ``classify``. An empty input yields an empty list. If the
            batch prediction raises, every image gets the degraded result
            (``'wireless'``, confidence 0.0, ``raw_label='unknown'``, plus
            ``'error'``).

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")

        if not images:
            return []

        try:
            # Input format expected by MinerU: one dict per image, whose
            # 'table_res' dict is filled in place by batch_predict.
            img_info_list = [
                {'wired_table_img': img, 'table_res': {}}
                for img in images
            ]

            self.model.batch_predict(img_info_list, batch_size=self.batch_size)

            # Read back the label and score written into each 'table_res';
            # missing entries fall back to '' and 0.0.
            return [
                self._build_result(
                    info['table_res'].get('cls_label', ''),
                    info['table_res'].get('cls_score', 0.0),
                )
                for info in img_info_list
            ]

        except Exception as e:
            logger.error(f"Batch table classification failed: {e}")
            # Best-effort degradation: one default result per input image.
            return [self._degraded_result(e) for _ in images]
|