Sfoglia il codice sorgente

fix: 增强网格结构恢复中的坐标转换精度,添加调试信息以验证缩放比例和单元格覆盖情况

zhch158_admin 1 giorno fa
parent
commit
bb0acb2afc

+ 178 - 14
ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py

@@ -19,6 +19,8 @@ class GridRecovery:
         upscale: float = 1.0,
         orig_h: Optional[int] = None,
         orig_w: Optional[int] = None,
+        unet_w_scale: Optional[float] = None,
+        unet_h_scale: Optional[float] = None,
         debug_dir: Optional[str] = None,
         debug_prefix: str = "",
     ) -> List[List[float]]:
@@ -38,6 +40,8 @@ class GridRecovery:
             upscale: 上采样比例(用于向后兼容,如果提供了 orig_h/orig_w 则会被覆盖)
             orig_h: 原图的实际高度(用于计算真实的缩放比例)
             orig_w: 原图的实际宽度(用于计算真实的缩放比例)
+            unet_w_scale: UNet预处理时的宽度缩放因子(可选,用于更精确的坐标转换)
+            unet_h_scale: UNet预处理时的高度缩放因子(可选,用于更精确的坐标转换)
             debug_dir: 调试输出目录 (Optional)
             debug_prefix: 调试文件名前缀 (Optional)
             
@@ -269,7 +273,7 @@ class GridRecovery:
         
         # Step 5b Debug (After Dilation)
         save_debug_image("step05b_dilated", line_mask)
-        
+
         # 6. 反转图像
         inv_grid = cv2.bitwise_not(line_mask)
         
@@ -279,26 +283,72 @@ class GridRecovery:
         # 7. 连通域
         num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
         
-        # 计算真实的缩放比例
-        # 如果提供了原图尺寸,使用真实的缩放比例;否则使用 upscale(向后兼容)
+        # 计算从mask坐标到原图坐标的缩放比例
+        # 核心理解:
+        # 1. mask已经被resize回上采样尺寸,所以 mask坐标系 = 上采样坐标系
+        # 2. 上采样图像 = 原图 × upscale
+        # 3. 因此:scale = mask_size / orig_size = (orig_size × upscale) / orig_size = upscale
+        # 
+        # 实际计算时,我们使用实际的mask尺寸和原图尺寸来计算,这样可以:
+        # - 处理任何微小的舍入误差
+        # - 验证resize是否正确(scale应该≈upscale)
+        # 
+        # 注意:unet_w_scale和unet_h_scale是UNet预处理时的缩放因子(上采样→UNet输入),
+        # 它们不应该直接用于mask→原图的坐标转换,因为mask已经被resize回上采样尺寸。
+        # 这些参数保留仅用于调试和验证。
+        
         if orig_h is not None and orig_w is not None and orig_h > 0 and orig_w > 0:
+            # 使用实际的mask尺寸和原图尺寸计算缩放比例
             scale_h = h / orig_h
             scale_w = w / orig_w
-            logger.debug(
-                f"连通域分析: mask尺寸=[{h}, {w}], 原图尺寸=[{orig_h}, {orig_w}], "
-                f"真实缩放比例=[{scale_h:.3f}, {scale_w:.3f}], upscale={upscale:.3f}"
+            
+            # 验证:scale应该非常接近upscale(允许<0.1%的误差)
+            scale_diff_h = abs(scale_h - upscale)
+            scale_diff_w = abs(scale_w - upscale)
+            scale_diff_ratio_h = scale_diff_h / upscale if upscale > 0 else 0
+            scale_diff_ratio_w = scale_diff_w / upscale if upscale > 0 else 0
+            
+            logger.info(
+                f"🔍 连通域坐标转换参数:\n"
+                f"  - Mask尺寸: [{h}, {w}]\n"
+                f"  - 原图尺寸: [{orig_h}, {orig_w}]\n"
+                f"  - 计算scale: h={scale_h:.6f}, w={scale_w:.6f}\n"
+                f"  - 理论upscale: {upscale:.6f}\n"
+                f"  - 差异: h={scale_diff_h:.6f} ({scale_diff_ratio_h*100:.3f}%), w={scale_diff_w:.6f} ({scale_diff_ratio_w*100:.3f}%)"
             )
+            
+            # 如果差异过大,可能表明resize有问题
+            if scale_diff_ratio_h > 0.01 or scale_diff_ratio_w > 0.01:  # >1%差异
+                logger.warning(
+                    f"⚠️ 计算的scale ([{scale_h:.3f}, {scale_w:.3f}]) 与理论upscale ({upscale:.3f}) 差异超过1%!"
+                    f"这可能表明mask尺寸不正确或resize有问题。"
+                )
+            
+            # 记录UNet缩放因子(仅用于调试,不参与坐标转换)
+            if unet_w_scale is not None and unet_h_scale is not None:
+                logger.debug(
+                    f"  (调试信息) UNet预处理缩放因子: h_scale={unet_h_scale:.6f}, w_scale={unet_w_scale:.6f}, "
+                    f"upscale/unet_scale = [{upscale/unet_h_scale:.3f}, {upscale/unet_w_scale:.3f}]"
+                )
         else:
+            # 如果没有提供原图尺寸,回退到使用upscale
             scale_h = upscale
             scale_w = upscale
-            logger.debug(
-                f"连通域分析: mask尺寸=[{h}, {w}], upscale={upscale:.3f}, "
-                f"预期原图尺寸≈[{h/upscale:.1f}, {w/upscale:.1f}] (使用 upscale,未提供原图尺寸)"
+            logger.info(
+                f"🔍 连通域坐标转换参数:\n"
+                f"  - Mask尺寸: [{h}, {w}]\n"
+                f"  - 使用upscale: {upscale:.3f}\n"
+                f"  - 预期原图尺寸≈[{h/upscale:.1f}, {w/upscale:.1f}]"
             )
         
         bboxes = []
         
         # 8. 过滤
+        # 8. 过滤(增强版:添加贴边连通域过滤)
+        # 由于裁剪时添加了 padding=10,表格真实边框应该距离图像边缘至少 10 像素
+        # 因此,任何直接贴着图像边缘的连通域都是 padding 区域的背景噪声
+        edge_threshold = 5  # 距离边缘小于5px视为"贴边"
+
         for i in range(1, num_labels):
             x = stats[i, cv2.CC_STAT_LEFT]
             y = stats[i, cv2.CC_STAT_TOP]
@@ -306,15 +356,28 @@ class GridRecovery:
             h_cell = stats[i, cv2.CC_STAT_HEIGHT]
             area = stats[i, cv2.CC_STAT_AREA]
             
+            # 过滤1:整图大小的连通域(背景)
             if w_cell > w * 0.98 and h_cell > h * 0.98:
                 continue
+            
+            # 过滤2:面积过小的噪点
             if area < 50:
                 continue
-                
-            # 使用真实的缩放比例转换为原图坐标
+            
+            # 过滤3:贴边连通域(padding区域的背景噪声)
+            # 判断连通域是否贴着图像边缘
+            is_touching_edge = (
+                x < edge_threshold or  # 左边缘
+                y < edge_threshold or  # 上边缘
+                (x + w_cell) > (w - edge_threshold) or  # 右边缘
+                (y + h_cell) > (h - edge_threshold)     # 下边缘
+            )
+            if is_touching_edge:
+                continue  # 过滤掉贴边的连通域
+            
+            # 过滤4:原图坐标下尺寸过小的单元格
             cell_orig_h = h_cell / scale_h
             cell_orig_w = w_cell / scale_w
-            
             if cell_orig_h < 4.0 or cell_orig_w < 4.0:
                 continue
             
@@ -326,10 +389,111 @@ class GridRecovery:
             ])
         
         bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
+
+        # 添加详细验证
+        if len(bboxes) > 0:
+            min_y = min(b[1] for b in bboxes)
+            max_y = max(b[3] for b in bboxes)
+            coverage_h = max_y - min_y
+            expected_h = orig_h if orig_h else h / upscale
+            
+            logger.info(
+                f"📏 单元格Y轴覆盖验证:\n"
+                f"  - 最小Y: {min_y:.1f}\n"
+                f"  - 最大Y: {max_y:.1f}\n"
+                f"  - 覆盖高度: {coverage_h:.1f}\n"
+                f"  - 原图高度: {expected_h:.1f}\n"
+                f"  - 覆盖率: {coverage_h/expected_h*100:.1f}%\n"
+                f"  - 顶部空白: {min_y:.1f}px ({min_y/expected_h*100:.1f}%)\n"
+                f"  - 底部空白: {expected_h - max_y:.1f}px ({(expected_h-max_y)/expected_h*100:.1f}%)"
+            )
         
-        # 调试日志:输出样本 bbox 坐标信息
+        # 可视化验证:保存调试图像,显示上采样mask上的连通域bbox和转换后的原图坐标
+        if debug_dir and len(bboxes) > 0:
+            try:
+                os.makedirs(debug_dir, exist_ok=True)
+                
+                # 创建可视化图像:上采样mask上的连通域bbox(绿色)
+                vis_mask = np.zeros((h, w, 3), dtype=np.uint8)
+                vis_mask[:, :, 0] = inv_grid  # 背景用反转的grid
+                vis_mask[:, :, 1] = inv_grid
+                vis_mask[:, :, 2] = inv_grid
+                
+                # 在上采样mask上绘制连通域bbox(使用上采样坐标)
+                for idx, bbox_orig in enumerate(bboxes[:min(20, len(bboxes))]):  # 只绘制前20个,避免太密集
+                    # 反推上采样坐标
+                    x_up = int(bbox_orig[0] * scale_w)
+                    y_up = int(bbox_orig[1] * scale_h)
+                    x2_up = int(bbox_orig[2] * scale_w)
+                    y2_up = int(bbox_orig[3] * scale_h)
+                    
+                    # 确保坐标在范围内
+                    x_up = max(0, min(x_up, w - 1))
+                    y_up = max(0, min(y_up, h - 1))
+                    x2_up = max(0, min(x2_up, w - 1))
+                    y2_up = max(0, min(y2_up, h - 1))
+                    
+                    if x2_up > x_up and y2_up > y_up:
+                        cv2.rectangle(vis_mask, (x_up, y_up), (x2_up, y2_up), (0, 255, 0), 2)
+                        # 标注单元格索引
+                        cv2.putText(vis_mask, str(idx), (x_up + 2, y_up + 15), 
+                                   cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
+                
+                name = f"{debug_prefix}_coordinate_verification_mask.png" if debug_prefix else "coordinate_verification_mask.png"
+                path = os.path.join(debug_dir, name)
+                cv2.imwrite(path, vis_mask)
+                logger.info(f"保存坐标验证图像(上采样mask): {path}")
+                
+            except Exception as e:
+                logger.warning(f"保存坐标验证图像失败: {e}")
+        
+        # 详细的坐标转换调试日志
         if len(bboxes) > 0:
-            logger.debug(f"样本 bbox (原图坐标): 前3个 = {bboxes[:3]}, 后3个 = {bboxes[-3:]}")
+            logger.info(
+                f"🔍 坐标转换验证:\n"
+                f"  - mask尺寸: [{h}, {w}]\n"
+                f"  - 原图尺寸: [{orig_h}, {orig_w}]\n"
+                f"  - 缩放比例: scale_h={scale_h:.6f}, scale_w={scale_w:.6f}\n"
+                f"  - 缩放比例差异: {abs(scale_h - scale_w):.6f}\n"
+                f"  - 提取到 {len(bboxes)} 个单元格"
+            )
+            
+            # 记录前几个和后几个单元格的详细坐标转换过程
+            sample_indices = [0, 1, 2] + [len(bboxes) - 3, len(bboxes) - 2, len(bboxes) - 1]
+            sample_indices = [i for i in sample_indices if 0 <= i < len(bboxes)]
+            
+            logger.info("🔍 样本单元格坐标转换详情:")
+            for idx in sample_indices:
+                bbox_orig = bboxes[idx]
+                # 反推上采样坐标(用于验证)
+                x_up = bbox_orig[0] * scale_w
+                y_up = bbox_orig[1] * scale_h
+                w_up = (bbox_orig[2] - bbox_orig[0]) * scale_w
+                h_up = (bbox_orig[3] - bbox_orig[1]) * scale_h
+                
+                logger.info(
+                    f"  单元格 {idx}: 原图坐标 [{bbox_orig[0]:.1f}, {bbox_orig[1]:.1f}, "
+                    f"{bbox_orig[2]:.1f}, {bbox_orig[3]:.1f}] "
+                    f"(尺寸: {bbox_orig[2]-bbox_orig[0]:.1f}x{bbox_orig[3]-bbox_orig[1]:.1f}) "
+                    f"-> 反推上采样坐标 [{x_up:.1f}, {y_up:.1f}, {x_up+w_up:.1f}, {y_up+h_up:.1f}] "
+                    f"(尺寸: {w_up:.1f}x{h_up:.1f})"
+                )
+            
+            # 检查是否有系统性偏移
+            if len(bboxes) >= 2:
+                first_y = bboxes[0][1]
+                last_y = bboxes[-1][3]
+                expected_height = last_y - first_y
+                actual_image_height = orig_h if orig_h else h / upscale
+                logger.info(
+                    f"🔍 系统性偏移检查:\n"
+                    f"  - 第一个单元格y1: {first_y:.1f}\n"
+                    f"  - 最后一个单元格y2: {last_y:.1f}\n"
+                    f"  - 单元格覆盖高度: {expected_height:.1f}\n"
+                    f"  - 原图实际高度: {actual_image_height:.1f}\n"
+                    f"  - 高度差异: {abs(expected_height - actual_image_height):.1f}"
+                )
+            
             logger.debug(f"bbox 坐标范围: x=[{min(b[0] for b in bboxes):.1f}, {max(b[2] for b in bboxes):.1f}], "
                         f"y=[{min(b[1] for b in bboxes):.1f}, {max(b[3] for b in bboxes):.1f}]")