Kaynağa Gözat

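feat: Enhance the GridRecovery class by adding optional `orig_h`/`orig_w` parameters for the original image dimensions, deriving per-axis scale factors from them for cell extraction (falling back to `upscale` when they are not provided), and adding detailed debug logging for better traceability. Also lowers `extension_multiplier` from 3.0 to 2.0 and removes the `max(1, …)` clamp on the recomputed rowspan/colspan in `compress_grid`.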

zhch158_admin 2 gün önce
ebeveyn
işleme
60aa86e4bf

+ 77 - 38
ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py

@@ -3,7 +3,7 @@
 
 提供从表格线提取单元格和恢复网格结构的功能。
 """
-from typing import List, Dict
+from typing import List, Dict, Optional
 import cv2
 import numpy as np
 from loguru import logger
@@ -17,7 +17,9 @@ class GridRecovery:
         hpred_up: np.ndarray,
         vpred_up: np.ndarray,
         upscale: float = 1.0,
-        debug_dir: str = None,
+        orig_h: Optional[int] = None,
+        orig_w: Optional[int] = None,
+        debug_dir: Optional[str] = None,
         debug_prefix: str = "",
     ) -> List[List[float]]:
         """
@@ -33,7 +35,9 @@ class GridRecovery:
         Args:
             hpred_up: 横线预测mask(上采样后)
             vpred_up: 竖线预测mask(上采样后)
-            upscale: 上采样比例
+            upscale: 上采样比例(用于向后兼容,如果提供了 orig_h/orig_w 则会被覆盖)
+            orig_h: 原图的实际高度(用于计算真实的缩放比例)
+            orig_w: 原图的实际宽度(用于计算真实的缩放比例)
             debug_dir: 调试输出目录 (Optional)
             debug_prefix: 调试文件名前缀 (Optional)
             
@@ -145,7 +149,7 @@ class GridRecovery:
             # Pre-calculate Max Allowed Lengths (Original Length * Multiplier)
             # Multiplier = 2.0 means we allow the line to double in size, but not more.
             # This effectively stops short noise from becoming page-height lines.
-            extension_multiplier = 3.0 
+            extension_multiplier = 2.0 
             
             row_max_lens = [dist_sqrt(b[:2], b[2:]) * extension_multiplier for b in rowboxes]
             col_max_lens = [dist_sqrt(b[:2], b[2:]) * extension_multiplier for b in colboxes]
@@ -275,6 +279,23 @@ class GridRecovery:
         # 7. 连通域
         num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
         
+        # 计算真实的缩放比例
+        # 如果提供了原图尺寸,使用真实的缩放比例;否则使用 upscale(向后兼容)
+        if orig_h is not None and orig_w is not None and orig_h > 0 and orig_w > 0:
+            scale_h = h / orig_h
+            scale_w = w / orig_w
+            logger.debug(
+                f"连通域分析: mask尺寸=[{h}, {w}], 原图尺寸=[{orig_h}, {orig_w}], "
+                f"真实缩放比例=[{scale_h:.3f}, {scale_w:.3f}], upscale={upscale:.3f}"
+            )
+        else:
+            scale_h = upscale
+            scale_w = upscale
+            logger.debug(
+                f"连通域分析: mask尺寸=[{h}, {w}], upscale={upscale:.3f}, "
+                f"预期原图尺寸≈[{h/upscale:.1f}, {w/upscale:.1f}] (使用 upscale,未提供原图尺寸)"
+            )
+        
         bboxes = []
         
         # 8. 过滤
@@ -290,22 +311,35 @@ class GridRecovery:
             if area < 50:
                 continue
                 
-            orig_h = h_cell / upscale
-            orig_w = w_cell / upscale
+            # 使用真实的缩放比例转换为原图坐标
+            cell_orig_h = h_cell / scale_h
+            cell_orig_w = w_cell / scale_w
             
-            if orig_h < 4.0 or orig_w < 4.0:
+            if cell_orig_h < 4.0 or cell_orig_w < 4.0:
                 continue
             
             bboxes.append([
-                x / upscale,
-                y / upscale,
-                (x + w_cell) / upscale,
-                (y + h_cell) / upscale
+                x / scale_w,
+                y / scale_h,
+                (x + w_cell) / scale_w,
+                (y + h_cell) / scale_h
             ])
         
         bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
         
-        logger.info(f"矢量重构分析提取到 {len(bboxes)} 个单元格 (Dynamic Alpha: {dynamic_alpha})")
+        # 调试日志:输出样本 bbox 坐标信息
+        if len(bboxes) > 0:
+            logger.debug(f"样本 bbox (原图坐标): 前3个 = {bboxes[:3]}, 后3个 = {bboxes[-3:]}")
+            logger.debug(f"bbox 坐标范围: x=[{min(b[0] for b in bboxes):.1f}, {max(b[2] for b in bboxes):.1f}], "
+                        f"y=[{min(b[1] for b in bboxes):.1f}, {max(b[3] for b in bboxes):.1f}]")
+        
+        if orig_h is not None and orig_w is not None:
+            logger.info(
+                f"矢量重构分析提取到 {len(bboxes)} 个单元格 "
+                f"(Dynamic Alpha: {dynamic_alpha}, 真实缩放比例=[{scale_h:.3f}, {scale_w:.3f}])"
+            )
+        else:
+            logger.info(f"矢量重构分析提取到 {len(bboxes)} 个单元格 (Dynamic Alpha: {dynamic_alpha}, upscale={upscale:.3f})")
         
         return bboxes
 
@@ -466,33 +500,43 @@ class GridRecovery:
     def compress_grid(cells: List[Dict]) -> List[Dict]:
         """
         压缩网格索引,移除空行和空列
-        
+
+        算法解释:
+        1. 遍历所有单元格,查找每个单元格的起始行/列和其跨度(rowspan/colspan),计算出原始网格的最大行数和列数。
+        2. 使用布尔列表记录每一行和每一列是否有单元格占用(row_occupied/col_occupied)。如果某一行或列没有任何单元格,则为"空"行/列。
+        3. 根据占用情况遍历并为原网格的每一行/列建立到新索引的映射表(row_map/col_map)。未被占用(即为空)的行或列不会增加新的索引。
+        4. 根据刚刚生成的映射表,对所有单元格的row/col索引进行压缩更新,使它们在新的"紧凑"网格中连续、不含空行空列。
+        5. 返回更新后的单元格列表。
+
+        该算法本质上是:保留有效的单元格行列,剔除全空的行列,并将单元格的行列索引重新编号,使得新网格紧凑无缝隙。
+
         Args:
             cells: 单元格列表
-            
+
         Returns:
             压缩后的单元格列表
         """
+        # 源码实现见下方
         if not cells:
             return []
-        
+
         # 1. 计算当前最大行列
         max_row = 0
         max_col = 0
         for cell in cells:
             max_row = max(max_row, cell["row"] + cell.get("rowspan", 1))
             max_col = max(max_col, cell["col"] + cell.get("colspan", 1))
-        
+
         # 2. 标记占用情况
         row_occupied = [False] * max_row
         col_occupied = [False] * max_col
-        
+
         for cell in cells:
             if cell["row"] < max_row:
                 row_occupied[cell["row"]] = True
             if cell["col"] < max_col:
                 col_occupied[cell["col"]] = True
-        
+
         # 3. 构建映射表
         row_map = [0] * (max_row + 1)
         current_row = 0
@@ -500,38 +544,33 @@ class GridRecovery:
             if row_occupied[r]:
                 current_row += 1
             row_map[r + 1] = current_row
-        
+
         col_map = [0] * (max_col + 1)
         current_col = 0
         for c in range(max_col):
             if col_occupied[c]:
                 current_col += 1
             col_map[c + 1] = current_col
-        
+
         # 4. 更新单元格索引
         new_cells = []
         for cell in cells:
             new_cell = cell.copy()
-            
+
             old_r1 = cell["row"]
-            old_r2 = old_r1 + cell.get("rowspan", 1)
-            new_r1 = row_map[old_r1]
-            new_r2 = row_map[old_r2]
-            
             old_c1 = cell["col"]
+            old_r2 = old_r1 + cell.get("rowspan", 1)
             old_c2 = old_c1 + cell.get("colspan", 1)
-            new_c1 = col_map[old_c1]
-            new_c2 = col_map[old_c2]
-            
-            new_span_r = max(1, new_r2 - new_r1)
-            new_span_c = max(1, new_c2 - new_c1)
-            
-            new_cell["row"] = new_r1
-            new_cell["col"] = new_c1
-            new_cell["rowspan"] = new_span_r
-            new_cell["colspan"] = new_span_c
-            
+
+            new_row = row_map[old_r1]
+            new_col = col_map[old_c1]
+            new_rowspan = row_map[old_r2] - row_map[old_r1]
+            new_colspan = col_map[old_c2] - col_map[old_c1]
+
+            new_cell["row"] = new_row
+            new_cell["col"] = new_col
+            new_cell["rowspan"] = new_rowspan
+            new_cell["colspan"] = new_colspan
             new_cells.append(new_cell)
-        
-        return new_cells
 
+        return new_cells