Переглянути джерело

feat(paddle_table_classifier): 优化表格线检测,使用自适应阈值和线段过滤

zhch158_admin 1 тиждень тому
батько
коміт
d2258858b5

+ 36 - 6
ocr_tools/universal_doc_parser/models/adapters/paddle_table_classifier.py

@@ -198,20 +198,50 @@ class PaddleTableClassifier(BaseAdapter):
         else:
             gray = img_array
         
-        # 二值化
-        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+        # 二值化:自适应阈值更适合浅色表格线
+        binary = cv2.adaptiveThreshold(
+            gray,
+            255,
+            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+            cv2.THRESH_BINARY_INV,
+            25,
+            10
+        )
         
         h, w = binary.shape
         
         # 检测横线
-        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(20, w//30), 1))
+        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(20, w // 30), 1))
         horizontal_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)
-        horizontal_lines = cv2.findContours(horizontal_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        horizontal_contours = cv2.findContours(horizontal_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
         
         # 检测竖线
-        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(20, h//30)))
+        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(20, h // 30)))
         vertical_mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
-        vertical_lines = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        vertical_contours = cv2.findContours(vertical_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+
+        # Filter contours by segment length and aspect ratio to reduce false detections from vertical text strokes
+        def filter_lines(contours, orientation):  # orientation "h" selects the horizontal rules; any other value takes the vertical branch
+            filtered = []
+            for cnt in contours:
+                x, y, cw, ch = cv2.boundingRect(cnt)  # cw / ch: width / height of the contour's bounding box
+                if cw <= 0 or ch <= 0:  # skip degenerate boxes
+                    continue
+                if orientation == "h":
+                    if cw < w * 0.15:  # narrower than 15% of the image width (w comes from binary.shape in the enclosing scope)
+                        continue
+                    if cw / max(ch, 1) < 5.0:  # width-to-height ratio below 5: not elongated enough
+                        continue
+                else:
+                    if ch < h * 0.15:  # shorter than 15% of the image height (h comes from binary.shape in the enclosing scope)
+                        continue
+                    if ch / max(cw, 1) < 5.0:  # height-to-width ratio below 5: not elongated enough
+                        continue
+                filtered.append(cnt)  # keep the original contour, not its bounding box
+            return filtered
+
+        horizontal_lines = filter_lines(horizontal_contours, "h")
+        vertical_lines = filter_lines(vertical_contours, "v")
         
         # 调试可视化
         # 使用传入的 debug_options (包含了可能的 override)