Преглед изворни кода

feat(paddle_table_classifier): integrate debugging tools for enhanced table line detection visualization

- Added WiredTableVisualizer and WiredTableDebugUtils for improved debugging capabilities.
- Updated classify and _detect_table_lines methods to incorporate debug options for better visualization control.
- Merged debug options to allow for dynamic adjustments during classification.
zhch158_admin пре 3 дана
родитељ
комит
11dcec8769
1 измењених фајлова са 35 додато и 4 уклоњено
  1. 35 4
      ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py

+ 35 - 4
ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py

@@ -6,7 +6,7 @@ PaddleOCR表格分类适配器
 适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
 适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
 """
 """
 import sys
 import sys
-from typing import Dict, Any, Union
+from typing import Dict, Any, Union, Optional
 from pathlib import Path
 from pathlib import Path
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
@@ -14,6 +14,8 @@ import cv2
 from loguru import logger
 from loguru import logger
 
 
 from .base import BaseAdapter
 from .base import BaseAdapter
+from .wired_table.visualization import WiredTableVisualizer
+from .wired_table.debug_utils import WiredTableDebugUtils, WiredTableDebugOptions
 
 
 # # 确保 MinerU 库可导入
 # # 确保 MinerU 库可导入
 # mineru_root = Path(__file__).parents[5] / "MinerU"
 # mineru_root = Path(__file__).parents[5] / "MinerU"
@@ -53,6 +55,11 @@ class PaddleTableClassifier(BaseAdapter):
         self.confidence_threshold = config.get('confidence_threshold', 0.5)
         self.confidence_threshold = config.get('confidence_threshold', 0.5)
         self.batch_size = config.get('batch_size', 16)
         self.batch_size = config.get('batch_size', 16)
         
         
+        # 初始化调试工具
+        self.debug_utils = WiredTableDebugUtils()
+        self.debug_options = self.debug_utils.merge_debug_options(self.config)
+        self.visualizer = WiredTableVisualizer()
+        
     def initialize(self):
     def initialize(self):
         """初始化模型"""
         """初始化模型"""
         if not MINERU_TABLE_CLS_AVAILABLE:
         if not MINERU_TABLE_CLS_AVAILABLE:
@@ -75,7 +82,8 @@ class PaddleTableClassifier(BaseAdapter):
     def classify(
     def classify(
         self, 
         self, 
         image: Union[np.ndarray, Image.Image],
         image: Union[np.ndarray, Image.Image],
-        use_line_detection: bool = True
+        use_line_detection: bool = True,
+        debug_options: Optional[Dict[str, Any]] = None
     ) -> Dict[str, Any]:
     ) -> Dict[str, Any]:
         """
         """
         分类单个表格图像
         分类单个表格图像
@@ -95,6 +103,12 @@ class PaddleTableClassifier(BaseAdapter):
         """
         """
         if self.model is None:
         if self.model is None:
             raise RuntimeError("Model not initialized. Call initialize() first.")
             raise RuntimeError("Model not initialized. Call initialize() first.")
+            
+        # 合并调试选项
+        merged_debug_opts = self.debug_utils.merge_debug_options(
+            self.config, 
+            override=debug_options
+        )
         
         
         try:
         try:
             # Step 1: 调用 MinerU 的预测接口
             # Step 1: 调用 MinerU 的预测接口
@@ -117,7 +131,7 @@ class PaddleTableClassifier(BaseAdapter):
             
             
             # Step 2: 使用线条检测辅助判断(覆盖低置信度结果)
             # Step 2: 使用线条检测辅助判断(覆盖低置信度结果)
             if use_line_detection:
             if use_line_detection:
-                line_info = self._detect_table_lines(image)
+                line_info = self._detect_table_lines(image, merged_debug_opts)
                 result['line_detection'] = line_info
                 result['line_detection'] = line_info
                 
                 
                 # 🔑 关键逻辑:只有横线没有竖线 → 强制无线表格
                 # 🔑 关键逻辑:只有横线没有竖线 → 强制无线表格
@@ -158,7 +172,11 @@ class PaddleTableClassifier(BaseAdapter):
                 'error': str(e)
                 'error': str(e)
             }
             }
     
     
-    def _detect_table_lines(self, image: Union[np.ndarray, Image.Image]) -> Dict[str, int]:
+    def _detect_table_lines(
+        self, 
+        image: Union[np.ndarray, Image.Image],
+        debug_options: Optional[WiredTableDebugOptions] = None
+    ) -> Dict[str, int]:
         """
         """
         检测表格图像中的横线和竖线数量
         检测表格图像中的横线和竖线数量
         
         
@@ -195,6 +213,19 @@ class PaddleTableClassifier(BaseAdapter):
         vertical_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
         vertical_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
         vertical_lines = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
         vertical_lines = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
         
         
+        # 调试可视画
+        # 使用传入的 debug_options (包含了可能的 override)
+        opts = debug_options or self.debug_options
+        if self.debug_utils.debug_is_on("save_table_lines", opts):
+            out_path = self.debug_utils.debug_path("paddle_table_lines", opts)
+            if out_path:
+                self.visualizer.visualize_table_lines(
+                    img_array if len(img_array.shape) == 3 else cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR),
+                    horizontal_mask,  # morphologyEx result is already a binary mask
+                    vertical_mask,
+                    output_path=out_path
+                )
+
         return {
         return {
             'horizontal_lines': len(horizontal_lines),
             'horizontal_lines': len(horizontal_lines),
             'vertical_lines': len(vertical_lines)
             'vertical_lines': len(vertical_lines)