Kaynağa Gözat

feat(paddle_table_classifier): integrate debugging tools for enhanced table line detection visualization

- Added WiredTableVisualizer and WiredTableDebugUtils for improved debugging capabilities.
- Updated classify and _detect_table_lines methods to incorporate debug options for better visualization control.
- Merged debug options to allow for dynamic adjustments during classification.
zhch158_admin 3 gün önce
ebeveyn
işleme
11dcec8769

+ 35 - 4
ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py

@@ -6,7 +6,7 @@ PaddleOCR表格分类适配器
 适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
 """
 import sys
-from typing import Dict, Any, Union
+from typing import Dict, Any, Union, Optional
 from pathlib import Path
 import numpy as np
 from PIL import Image
@@ -14,6 +14,8 @@ import cv2
 from loguru import logger
 
 from .base import BaseAdapter
+from .wired_table.visualization import WiredTableVisualizer
+from .wired_table.debug_utils import WiredTableDebugUtils, WiredTableDebugOptions
 
 # # 确保 MinerU 库可导入
 # mineru_root = Path(__file__).parents[5] / "MinerU"
@@ -53,6 +55,11 @@ class PaddleTableClassifier(BaseAdapter):
         self.confidence_threshold = config.get('confidence_threshold', 0.5)
         self.batch_size = config.get('batch_size', 16)
         
+        # 初始化调试工具
+        self.debug_utils = WiredTableDebugUtils()
+        self.debug_options = self.debug_utils.merge_debug_options(self.config)
+        self.visualizer = WiredTableVisualizer()
+        
     def initialize(self):
         """初始化模型"""
         if not MINERU_TABLE_CLS_AVAILABLE:
@@ -75,7 +82,8 @@ class PaddleTableClassifier(BaseAdapter):
     def classify(
         self, 
         image: Union[np.ndarray, Image.Image],
-        use_line_detection: bool = True
+        use_line_detection: bool = True,
+        debug_options: Optional[Dict[str, Any]] = None
     ) -> Dict[str, Any]:
         """
         分类单个表格图像
@@ -95,6 +103,12 @@ class PaddleTableClassifier(BaseAdapter):
         """
         if self.model is None:
             raise RuntimeError("Model not initialized. Call initialize() first.")
+            
+        # 合并调试选项
+        merged_debug_opts = self.debug_utils.merge_debug_options(
+            self.config, 
+            override=debug_options
+        )
         
         try:
             # Step 1: 调用 MinerU 的预测接口
@@ -117,7 +131,7 @@ class PaddleTableClassifier(BaseAdapter):
             
             # Step 2: 使用线条检测辅助判断(覆盖低置信度结果)
             if use_line_detection:
-                line_info = self._detect_table_lines(image)
+                line_info = self._detect_table_lines(image, merged_debug_opts)
                 result['line_detection'] = line_info
                 
                 # 🔑 关键逻辑:只有横线没有竖线 → 强制无线表格
@@ -158,7 +172,11 @@ class PaddleTableClassifier(BaseAdapter):
                 'error': str(e)
             }
     
-    def _detect_table_lines(self, image: Union[np.ndarray, Image.Image]) -> Dict[str, int]:
+    def _detect_table_lines(
+        self, 
+        image: Union[np.ndarray, Image.Image],
+        debug_options: Optional[WiredTableDebugOptions] = None
+    ) -> Dict[str, int]:
         """
         检测表格图像中的横线和竖线数量
         
@@ -195,6 +213,19 @@ class PaddleTableClassifier(BaseAdapter):
         vertical_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
         vertical_lines = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
         
+        # 调试可视化
+        # 使用传入的 debug_options (包含了可能的 override)
+        opts = debug_options or self.debug_options
+        if self.debug_utils.debug_is_on("save_table_lines", opts):
+            out_path = self.debug_utils.debug_path("paddle_table_lines", opts)
+            if out_path:
+                self.visualizer.visualize_table_lines(
+                    img_array if len(img_array.shape) == 3 else cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR),
+                    horizontal_mask,  # morphologyEx result is already a binary mask
+                    vertical_mask,
+                    output_path=out_path
+                )
+
         return {
             'horizontal_lines': len(horizontal_lines),
             'vertical_lines': len(vertical_lines)