|
@@ -6,7 +6,7 @@ PaddleOCR表格分类适配器
|
|
|
适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
|
|
适配 MinerU 的 PaddleTableClsModel,用于区分有线表格和无线表格。
|
|
|
"""
|
|
"""
|
|
|
import sys
|
|
import sys
|
|
|
-from typing import Dict, Any, Union
|
|
|
|
|
|
|
+from typing import Dict, Any, Union, Optional
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
from PIL import Image
|
|
@@ -14,6 +14,8 @@ import cv2
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
|
|
|
|
|
from .base import BaseAdapter
|
|
from .base import BaseAdapter
|
|
|
|
|
+from .wired_table.visualization import WiredTableVisualizer
|
|
|
|
|
+from .wired_table.debug_utils import WiredTableDebugUtils, WiredTableDebugOptions
|
|
|
|
|
|
|
|
# # 确保 MinerU 库可导入
|
|
# # 确保 MinerU 库可导入
|
|
|
# mineru_root = Path(__file__).parents[5] / "MinerU"
|
|
# mineru_root = Path(__file__).parents[5] / "MinerU"
|
|
@@ -53,6 +55,11 @@ class PaddleTableClassifier(BaseAdapter):
|
|
|
self.confidence_threshold = config.get('confidence_threshold', 0.5)
|
|
self.confidence_threshold = config.get('confidence_threshold', 0.5)
|
|
|
self.batch_size = config.get('batch_size', 16)
|
|
self.batch_size = config.get('batch_size', 16)
|
|
|
|
|
|
|
|
|
|
+ # 初始化调试工具
|
|
|
|
|
+ self.debug_utils = WiredTableDebugUtils()
|
|
|
|
|
+ self.debug_options = self.debug_utils.merge_debug_options(self.config)
|
|
|
|
|
+ self.visualizer = WiredTableVisualizer()
|
|
|
|
|
+
|
|
|
def initialize(self):
|
|
def initialize(self):
|
|
|
"""初始化模型"""
|
|
"""初始化模型"""
|
|
|
if not MINERU_TABLE_CLS_AVAILABLE:
|
|
if not MINERU_TABLE_CLS_AVAILABLE:
|
|
@@ -75,7 +82,8 @@ class PaddleTableClassifier(BaseAdapter):
|
|
|
def classify(
|
|
def classify(
|
|
|
self,
|
|
self,
|
|
|
image: Union[np.ndarray, Image.Image],
|
|
image: Union[np.ndarray, Image.Image],
|
|
|
- use_line_detection: bool = True
|
|
|
|
|
|
|
+ use_line_detection: bool = True,
|
|
|
|
|
+ debug_options: Optional[Dict[str, Any]] = None
|
|
|
) -> Dict[str, Any]:
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
"""
|
|
|
分类单个表格图像
|
|
分类单个表格图像
|
|
@@ -95,6 +103,12 @@ class PaddleTableClassifier(BaseAdapter):
|
|
|
"""
|
|
"""
|
|
|
if self.model is None:
|
|
if self.model is None:
|
|
|
raise RuntimeError("Model not initialized. Call initialize() first.")
|
|
raise RuntimeError("Model not initialized. Call initialize() first.")
|
|
|
|
|
+
|
|
|
|
|
+ # 合并调试选项
|
|
|
|
|
+ merged_debug_opts = self.debug_utils.merge_debug_options(
|
|
|
|
|
+ self.config,
|
|
|
|
|
+ override=debug_options
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
# Step 1: 调用 MinerU 的预测接口
|
|
# Step 1: 调用 MinerU 的预测接口
|
|
@@ -117,7 +131,7 @@ class PaddleTableClassifier(BaseAdapter):
|
|
|
|
|
|
|
|
# Step 2: 使用线条检测辅助判断(覆盖低置信度结果)
|
|
# Step 2: 使用线条检测辅助判断(覆盖低置信度结果)
|
|
|
if use_line_detection:
|
|
if use_line_detection:
|
|
|
- line_info = self._detect_table_lines(image)
|
|
|
|
|
|
|
+ line_info = self._detect_table_lines(image, merged_debug_opts)
|
|
|
result['line_detection'] = line_info
|
|
result['line_detection'] = line_info
|
|
|
|
|
|
|
|
# 🔑 关键逻辑:只有横线没有竖线 → 强制无线表格
|
|
# 🔑 关键逻辑:只有横线没有竖线 → 强制无线表格
|
|
@@ -158,7 +172,11 @@ class PaddleTableClassifier(BaseAdapter):
|
|
|
'error': str(e)
|
|
'error': str(e)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- def _detect_table_lines(self, image: Union[np.ndarray, Image.Image]) -> Dict[str, int]:
|
|
|
|
|
|
|
+ def _detect_table_lines(
|
|
|
|
|
+ self,
|
|
|
|
|
+ image: Union[np.ndarray, Image.Image],
|
|
|
|
|
+ debug_options: Optional[WiredTableDebugOptions] = None
|
|
|
|
|
+ ) -> Dict[str, int]:
|
|
|
"""
|
|
"""
|
|
|
检测表格图像中的横线和竖线数量
|
|
检测表格图像中的横线和竖线数量
|
|
|
|
|
|
|
@@ -195,6 +213,19 @@ class PaddleTableClassifier(BaseAdapter):
|
|
|
vertical_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
|
|
vertical_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
|
|
|
vertical_lines = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
|
|
vertical_lines = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
|
|
|
|
|
|
|
|
|
|
+ # 调试可视化
|
|
|
|
|
+ # 使用传入的 debug_options (包含了可能的 override)
|
|
|
|
|
+ opts = debug_options or self.debug_options
|
|
|
|
|
+ if self.debug_utils.debug_is_on("save_table_lines", opts):
|
|
|
|
|
+ out_path = self.debug_utils.debug_path("paddle_table_lines", opts)
|
|
|
|
|
+ if out_path:
|
|
|
|
|
+ self.visualizer.visualize_table_lines(
|
|
|
|
|
+ img_array if len(img_array.shape) == 3 else cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR),
|
|
|
|
|
+ horizontal_mask, # morphologyEx result is already a binary mask
|
|
|
|
|
+ vertical_mask,
|
|
|
|
|
+ output_path=out_path
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
return {
|
|
return {
|
|
|
'horizontal_lines': len(horizontal_lines),
|
|
'horizontal_lines': len(horizontal_lines),
|
|
|
'vertical_lines': len(vertical_lines)
|
|
'vertical_lines': len(vertical_lines)
|