Browse Source

feat: 添加边缘线过滤功能,优化线段提取过程以减少噪声

zhch158_admin 2 ngày trước cách đây
mục cha
commit
f90f868f20

+ 3 - 2
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -334,7 +334,7 @@ class MinerUWiredTableRecognizer:
             # Step 2: 使用连通域法提取单元格 (替换了原来的投影法)
             debug_prefix = f"{dbg.prefix}_grid" if dbg.prefix else "grid"
             
-            # 传入原图的实际尺寸用于计算坐标缩放比例
+            # 传入原图的实际尺寸和裁剪padding用于计算坐标缩放比例和边缘过滤
             bboxes = self.grid_recovery.compute_cells_from_lines(
                 hpred_up, 
                 vpred_up, 
@@ -342,7 +342,8 @@ class MinerUWiredTableRecognizer:
                 orig_h=h,
                 orig_w=w,
                 debug_dir=debug_dir,
-                debug_prefix=debug_prefix
+                debug_prefix=debug_prefix,
+                crop_padding=10  # 传递 padding 值(与 element_processors.py 中的 crop_padding 保持一致)
             )
             # bboxes = self.grid_recovery.compute_cells_from_lines(hpred_up, vpred_up, upscale) # Original call
             if not bboxes:

+ 121 - 2
ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py

@@ -21,6 +21,7 @@ class GridRecovery:
         orig_w: Optional[int] = None,
         debug_dir: Optional[str] = None,
         debug_prefix: str = "",
+        crop_padding: int = 10,  # 新增:裁剪时的padding值(原图坐标系)
     ) -> List[List[float]]:
         """
         基于矢量重构的连通域分析 (Advanced Vector-based Recovery)
@@ -40,7 +41,13 @@ class GridRecovery:
             orig_w: 原图的实际宽度(用于计算真实的缩放比例)
             debug_dir: 调试输出目录 (Optional)
             debug_prefix: 调试文件名前缀 (Optional)
+            crop_padding: 裁剪时的padding值(原图坐标系,默认10px)
             
+        注意:
+            - hpred_up/vpred_up 是上采样后的mask,坐标系已经放大了 upscale 倍
+            - crop_padding 是原图坐标系的值,需要乘以 upscale 转换到mask坐标系
+            - edge_margin 用于过滤贴近图像边缘的线条(padding区域的噪声)
+        
         Returns:
             单元格bbox列表 [[x1, y1, x2, y2], ...]
         """
@@ -201,12 +208,124 @@ class GridRecovery:
         # 2. 提取矢量线段
         rowboxes = get_table_line(h_bin, axis=0, lineW=int(10))
         colboxes = get_table_line(v_bin, axis=1, lineW=int(10))
-        
+
         logger.debug(f"Initial lines -> Rows: {len(rowboxes)}, Cols: {len(colboxes)}")
         
         # Step 2 Debug
         save_debug_image("step02_raw_vectors", h_bin, is_lines=True, lines=rowboxes + colboxes)
-        
+
+        # ==================== 新增:边缘线过滤 ====================
+        # 2.5 过滤边缘线条(padding 区域的噪声)
+        # 
+        # 关键理解:
+        # - crop_padding = 10px 是原图坐标系的值
+        # - hpred_up/vpred_up 是上采样后的mask,坐标系已放大 upscale 倍
+        # - 因此 mask 坐标系中的 padding = crop_padding × upscale
+        # - edge_margin 用于过滤边缘噪声;注意下方实际取 (crop_padding - 1) × upscale,即略小于(而非略大于)mask 坐标系中的 padding
+        #
+        edge_margin = int((crop_padding-1) * upscale * 1)  # padding
+        logger.info(
+            f"   🔧 边缘过滤参数: crop_padding={crop_padding}px (原图坐标系), "
+            f"upscale={upscale:.3f}, "
+            f"edge_margin={edge_margin}px (mask坐标系)"
+        )
+
+        def filter_edge_lines(lines, img_h, img_w, margin):
+            """
+            Split line segments into those kept and those removed for lying too close to the image border.
+            
+            Args:
+                lines: list of segments [[x1, y1, x2, y2], ...] (mask coordinate system)
+                img_h: height of the mask image
+                img_w: width of the mask image
+                margin: border threshold in pixels (mask coordinate system)
+            
+            Returns:
+                (list of kept segments, list of (segment, reason) tuples for the removed segments)
+            """
+            filtered = []
+            removed = []
+            
+            for line in lines:
+                x1, y1, x2, y2 = line
+                
+                # Axis-aligned bounding box of the segment
+                min_x = min(x1, x2)
+                max_x = max(x1, x2)
+                min_y = min(y1, y2)
+                max_y = max(y1, y2)
+                
+                # Orientation: horizontal when |dy| is strictly smaller than |dx|; anything else is handled as vertical
+                is_horizontal = abs(y2 - y1) < abs(x2 - x1)
+                
+                should_remove = False
+                reason = ""
+                
+                if is_horizontal:
+                    # Horizontal segment: only proximity to the top/bottom borders is checked
+                    if min_y < margin:
+                        should_remove = True
+                        reason = f"贴近上边缘 (min_y={min_y:.1f} < {margin})"
+                    elif max_y > (img_h - margin):
+                        should_remove = True
+                        reason = f"贴近下边缘 (max_y={max_y:.1f} > {img_h - margin:.1f})"
+                else:
+                    # Vertical segment: only proximity to the left/right borders is checked
+                    if min_x < margin:
+                        should_remove = True
+                        reason = f"贴近左边缘 (min_x={min_x:.1f} < {margin})"
+                    elif max_x > (img_w - margin):
+                        should_remove = True
+                        reason = f"贴近右边缘 (max_x={max_x:.1f} > {img_w - margin:.1f})"
+                
+                if should_remove:
+                    removed.append((line, reason))
+                else:
+                    filtered.append(line)
+            
+            return filtered, removed
+
+        # 执行边缘过滤
+        len_row_before = len(rowboxes)
+        len_col_before = len(colboxes)
+
+        rowboxes_filtered, rowboxes_removed = filter_edge_lines(rowboxes, h, w, edge_margin)
+        colboxes_filtered, colboxes_removed = filter_edge_lines(colboxes, h, w, edge_margin)
+
+        # 详细日志
+        if rowboxes_removed or colboxes_removed:
+            logger.info(
+                f"   🧹 边缘过滤结果: "
+                f"横线 {len_row_before} → {len(rowboxes_filtered)} (-{len(rowboxes_removed)}), "
+                f"竖线 {len_col_before} → {len(colboxes_filtered)} (-{len(colboxes_removed)})"
+            )
+            
+            # 详细列出被过滤的竖线(通常是噪声的主要来源)
+            if colboxes_removed:
+                logger.debug(f"   被过滤的竖线 ({len(colboxes_removed)}条):")
+                for line, reason in colboxes_removed[:5]:  # 只显示前5条
+                    logger.debug(f"     - [{line[0]:.1f}, {line[1]:.1f}, {line[2]:.1f}, {line[3]:.1f}] → {reason}")
+                if len(colboxes_removed) > 5:
+                    logger.debug(f"     ... 还有 {len(colboxes_removed) - 5} 条")
+            
+            # 横线过滤信息
+            if rowboxes_removed:
+                logger.debug(f"   被过滤的横线 ({len(rowboxes_removed)}条):")
+                for line, reason in rowboxes_removed[:5]:
+                    logger.debug(f"     - [{line[0]:.1f}, {line[1]:.1f}, {line[2]:.1f}, {line[3]:.1f}] → {reason}")
+                if len(rowboxes_removed) > 5:
+                    logger.debug(f"     ... 还有 {len(rowboxes_removed) - 5} 条")
+        else:
+            logger.debug(f"   ✓ 边缘过滤: 无噪声线条被移除")
+
+        # 更新线段列表
+        rowboxes = rowboxes_filtered
+        colboxes = colboxes_filtered
+
+        # Step 2.5 Debug(过滤后的干净线条)
+        save_debug_image("step02b_edge_filtered", h_bin, is_lines=True, lines=rowboxes + colboxes)
+        # ==================== 边缘线过滤结束 ====================
+
         # 3. 线段合并 (adjust_lines)
         rboxes_row_ = adjust_lines(rowboxes, alph=100, angle=50)
         rboxes_col_ = adjust_lines(colboxes, alph=15, angle=50)