Pārlūkot izejas kodu

feat: 重构单元格计算与网格恢复逻辑,增强对复杂表格的处理能力

zhch158_admin 6 dienas atpakaļ
vecāks
revīzija
60f761a6b5

+ 427 - 11
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -544,7 +544,7 @@ class MinerUWiredTableRecognizer:
         return str(soup)
     
     # ========== 基于表格线交点的单元格计算 ==========
-    def _compute_cells_from_lines(
+    def _compute_cells_from_lines_4_1(
         self,
         hpred_up: np.ndarray,
         vpred_up: np.ndarray,
@@ -626,18 +626,28 @@ class MinerUWiredTableRecognizer:
         for bbox in bboxes:
             h_cell = bbox[3] - bbox[1]
             w_cell = bbox[2] - bbox[0]
-            # 1. 绝对高度过滤:过滤极矮的噪点 (例如 < 6px)
+
+            # 1. 绝对高度过滤:过滤极矮的噪点
+            # 降低阈值到 6px,防止漏掉极小的字号
             if h_cell < 6.0:
                 continue
 
-            # 2. 相对高度+形态过滤:过滤扁长形的缝隙
-            # 场景:表格底部双线或粗线产生的缝隙,通常高度显著小于正常行,且宽度较大
-            # 阈值:高度小于中位数的 0.6 且 宽高比 > 5
-            # 例如你的case: h=12.9, median~28. h < 16.8 且 ratio=51 > 5 -> 过滤
-            elif median_h > 0 and h_cell < median_h * 0.6 and w_cell > h_cell * 5:
-                continue
-            elif median_h > 0 and h_cell < median_h * 0.33:
-                continue
+            # 2. 相对高度过滤 (更保守的策略)
+            # 仅当高度同时满足 "相对极小" AND "绝对较小" 时才过滤
+            # 这样可以防止在 median_h 很大(如160px)时误删正常的小行(如25px)
+            if median_h > 0:
+                ratio = h_cell / median_h
+                
+                # 策略A: 极矮行过滤
+                # 高度 < 10% median 且 绝对高度 < 10px
+                # (你的case: 25/164 = 0.15 > 0.1, 且 25 > 10, 故保留)
+                if ratio < 0.1 and h_cell < 10.0:
+                    continue
+                
+                # 策略B: 扁长缝隙过滤 (通常是双线造成的)
+                # 高度 < 20% median 且 宽高比 > 5 且 绝对高度 < 15px
+                if ratio < 0.2 and w_cell > h_cell * 5 and h_cell < 15.0:
+                    continue
 
             final_bboxes.append(bbox)
 
@@ -652,6 +662,104 @@ class MinerUWiredTableRecognizer:
             
         return bboxes
     
+    def _compute_cells_from_lines(
+        self,
+        hpred_up: np.ndarray,
+        vpred_up: np.ndarray,
+        upscale: float = 1.0,
+    ) -> List[List[float]]:
+        """
+        基于连通域分析从表格线 Mask 提取单元格 (健壮版)
+        
+        改进点:
+        1. 使用形态学闭运算(Closing)修复断线,而非简单膨胀。
+        2. 移除基于中位数的统计过滤,改为基于几何特征(长宽比、绝对尺寸)的过滤。
+        3. 专门处理双线表格产生的细长缝隙噪声。
+        """
+        h, w = hpred_up.shape[:2]
+        
+        # 1. 预处理:二值化
+        _, h_bin = cv2.threshold(hpred_up, 127, 255, cv2.THRESH_BINARY)
+        _, v_bin = cv2.threshold(vpred_up, 127, 255, cv2.THRESH_BINARY)
+        
+        # 2. 形态学修复:使用闭运算 (Closing) 连接断线
+        # 闭运算 = 先膨胀后腐蚀,能填补小黑洞(断点)而不改变整体轮廓大小
+        # 横线:横向连接能力强;竖线:竖向连接能力强
+        kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 1))
+        kernel_v = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 7))
+        
+        h_bin = cv2.morphologyEx(h_bin, cv2.MORPH_CLOSE, kernel_h)
+        v_bin = cv2.morphologyEx(v_bin, cv2.MORPH_CLOSE, kernel_v)
+        
+        # 3. 合成网格图
+        grid_mask = cv2.bitwise_or(h_bin, v_bin)
+        
+        # 4. 反转图像 (黑线白底),提取白色连通域
+        inv_grid = cv2.bitwise_not(grid_mask)
+        
+        # 5. 提取连通域
+        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
+        
+        bboxes = []
+        
+        # 6. 几何特征过滤 (不依赖全局统计,只看个体特征)
+        for i in range(1, num_labels):
+            x = stats[i, cv2.CC_STAT_LEFT]
+            y = stats[i, cv2.CC_STAT_TOP]
+            w_cell = stats[i, cv2.CC_STAT_WIDTH]
+            h_cell = stats[i, cv2.CC_STAT_HEIGHT]
+            area = stats[i, cv2.CC_STAT_AREA]
+            
+            # --- 过滤规则 ---
+            
+            # A. 边界与背景过滤
+            # 过滤掉几乎占据全图的背景,或极小的噪点
+            if area < 50 or w_cell > w * 0.95 or h_cell > h * 0.95:
+                continue
+
+            # 转换到原图尺度进行判断
+            orig_h = h_cell / upscale
+            orig_w = w_cell / upscale
+            
+            # B. 绝对尺寸过滤 (物理极限)
+            # 任何小于 4px 的东西都不可能是有效的文本单元格
+            if orig_h < 4.0 or orig_w < 4.0:
+                continue
+                
+            # C. 形态过滤:双线缝隙 (Sliver)
+            # 特征:长宽比极大,且短边极短
+            # 例如:宽100,高5 -> 比例20,且高<12 -> 判定为横向缝隙
+            # 例如:宽5,高100 -> 比例0.05,且宽<12 -> 判定为纵向缝隙
+            
+            ratio = w_cell / h_cell
+            
+            # 横向缝隙 (极扁)
+            if ratio > 10.0 and orig_h < 15.0:
+                continue
+                
+            # 纵向缝隙 (极细)
+            if ratio < 0.1 and orig_w < 12.0:
+                continue
+            
+            # D. 包含关系过滤 (可选,但 CCA 通常不会产生重叠框)
+            # 如果需要处理嵌套表格,这里需要更复杂的逻辑,但一般表格不需要。
+
+            # 通过所有检查,加入结果
+            bboxes.append([
+                x / upscale,
+                y / upscale,
+                (x + w_cell) / upscale,
+                (y + h_cell) / upscale
+            ])
+            
+        # 按阅读顺序排序 (先上后下,再左后右)
+        # 允许 10px 的行误差,防止轻微歪斜导致的排序混乱
+        bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
+        
+        logger.info(f"连通域分析提取到 {len(bboxes)} 个单元格")
+            
+        return bboxes
+
     def _visualize_detected_lines(
         self,
         hpred: np.ndarray,
@@ -796,7 +904,7 @@ class MinerUWiredTableRecognizer:
             
         return new_cells
 
-    def _recover_grid_structure(self, bboxes: List[List[float]]) -> List[Dict]:
+    def _recover_grid_structure_4_1(self, bboxes: List[List[float]]) -> List[Dict]:
         """
         从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
         改进版:使用边界投影聚类,更稳健
@@ -883,6 +991,314 @@ class MinerUWiredTableRecognizer:
         
         return structured_cells
 
+    def _recover_grid_structure_4_2(self, bboxes: List[List[float]]) -> List[Dict]:
+        """
+        从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
+        重构版:基于标准行骨架的匹配,解决密集行与跨行单元格混合的问题
+        """
+        if not bboxes:
+            return []
+
+        # --- 1. 识别行结构 (Row Structure) ---
+        
+        # 计算高度中位数,用于区分"标准行"和"跨行单元格"
+        heights = [b[3] - b[1] for b in bboxes]
+        median_h = np.median(heights) if heights else 0
+        
+        # 定义标准行单元格:高度在 [0.5, 1.5] 倍中位数之间
+        # 这样可以排除跨行的大单元格,也可以排除极小的噪点
+        standard_cells = []
+        for i, bbox in enumerate(bboxes):
+            h = bbox[3] - bbox[1]
+            if median_h > 0 and 0.5 * median_h < h < 1.8 * median_h:
+                standard_cells.append({"bbox": bbox, "index": i})
+        
+        # 兜底:如果找不到标准行(比如表格全是奇怪的单元格),则使用所有单元格
+        if not standard_cells:
+            standard_cells = [{"bbox": b, "index": i} for i, b in enumerate(bboxes)]
+
+        # 对标准单元格按 Y 中心排序
+        standard_cells.sort(key=lambda x: (x["bbox"][1] + x["bbox"][3]) / 2)
+        
+        # 贪心聚类生成"行骨架"
+        # rows_defs 存储每一行的垂直范围 {'top': y1, 'bottom': y2, 'center': yc}
+        rows_defs = []
+        
+        for item in standard_cells:
+            box = item["bbox"]
+            cy = (box[1] + box[3]) / 2
+            
+            # 尝试匹配已有的行
+            matched = False
+            for r_def in rows_defs:
+                # 判断条件:中心点距离小于行高的一半 (假设行高近似 median_h)
+                # 或者:垂直重叠率高
+                r_h = r_def['bottom'] - r_def['top']
+                ref_h = max(r_h, median_h) # 参考高度
+                
+                if abs(cy - r_def['center']) < ref_h * 0.6:
+                    # 匹配成功,更新行范围
+                    r_def['top'] = min(r_def['top'], box[1])
+                    r_def['bottom'] = max(r_def['bottom'], box[3])
+                    r_def['center'] = (r_def['top'] + r_def['bottom']) / 2
+                    matched = True
+                    break
+            
+            if not matched:
+                rows_defs.append({
+                    'top': box[1],
+                    'bottom': box[3],
+                    'center': cy
+                })
+        
+        # 对行骨架按位置排序
+        rows_defs.sort(key=lambda x: x['center'])
+        
+        # 合并靠得太近的行骨架 (防止过度切分)
+        # 阈值:0.5 * median_h
+        merged_rows = []
+        if rows_defs:
+            curr = rows_defs[0]
+            for next_row in rows_defs[1:]:
+                if next_row['center'] - curr['center'] < median_h * 0.5:
+                    # 合并
+                    curr['top'] = min(curr['top'], next_row['top'])
+                    curr['bottom'] = max(curr['bottom'], next_row['bottom'])
+                    curr['center'] = (curr['top'] + curr['bottom']) / 2
+                else:
+                    merged_rows.append(curr)
+                    curr = next_row
+            merged_rows.append(curr)
+        rows_defs = merged_rows
+
+        # --- 2. 识别列结构 (Col Structure) ---
+        # 列分割线逻辑保持不变,通常列比较规整
+        x_coords = []
+        for b in bboxes:
+            x_coords.append(b[0])
+            x_coords.append(b[2])
+        x_coords.sort()
+        
+        col_dividers = []
+        if x_coords:
+            thresh = 5 # 列间隙阈值
+            curr_cluster = [x_coords[0]]
+            for x in x_coords[1:]:
+                if x - curr_cluster[-1] < thresh:
+                    curr_cluster.append(x)
+                else:
+                    col_dividers.append(sum(curr_cluster)/len(curr_cluster))
+                    curr_cluster = [x]
+            col_dividers.append(sum(curr_cluster)/len(curr_cluster))
+            
+        # --- 3. 匹配单元格到网格 ---
+        structured_cells = []
+        for bbox in bboxes:
+            # --- 匹配行 (Row) ---
+            b_top, b_bottom = bbox[1], bbox[3]
+            b_h = b_bottom - b_top
+            
+            matched_row_indices = []
+            
+            for r_idx, r_def in enumerate(rows_defs):
+                # 计算 Y 轴重叠
+                inter_top = max(b_top, r_def['top'])
+                inter_bottom = min(b_bottom, r_def['bottom'])
+                inter_h = max(0, inter_bottom - inter_top)
+                
+                r_h = r_def['bottom'] - r_def['top']
+                
+                # 判定覆盖:
+                # 1. 单元格覆盖了该行的大部分 (跨行情况) -> inter_h / r_h > 0.5
+                # 2. 该行覆盖了单元格的大部分 (小单元格情况) -> inter_h / b_h > 0.5
+                if r_h > 0 and (inter_h / r_h > 0.5 or inter_h / b_h > 0.5):
+                    matched_row_indices.append(r_idx)
+            
+            if not matched_row_indices:
+                # 兜底:找中心点最近的行
+                cy = (b_top + b_bottom) / 2
+                closest_r = min(range(len(rows_defs)), key=lambda i: abs(rows_defs[i]['center'] - cy))
+                matched_row_indices = [closest_r]
+            
+            row_start = min(matched_row_indices)
+            row_end = max(matched_row_indices)
+            rowspan = row_end - row_start + 1
+            
+            # --- 匹配列 (Col) ---
+            # 找左右边界最近的 divider
+            c1 = 0
+            c2 = 0
+            if len(col_dividers) >= 2:
+                c1 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - bbox[0]))
+                c2 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - bbox[2]))
+                if c1 > c2: c1, c2 = c2, c1
+            
+            colspan = max(1, c2 - c1)
+            
+            structured_cells.append({
+                "bbox": bbox,
+                "row": row_start,
+                "col": c1,
+                "rowspan": rowspan,
+                "colspan": colspan
+            })
+
+        # 按行列排序
+        structured_cells.sort(key=lambda c: (c["row"], c["col"]))
+
+        # 压缩网格,移除空行空列
+        structured_cells = self._compress_grid(structured_cells)
+        
+        return structured_cells
+
+    def _recover_grid_structure(self, bboxes: List[List[float]]) -> List[Dict]:
+        """
+        从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
+        重构版:基于投影网格线 (Projected Grid Lines) 的算法
+        适用于行高差异巨大、存在密集小行的复杂表格
+        """
+        if not bboxes:
+            return []
+
+        # 1. 识别行分割线 (Y轴)
+        # 收集所有单元格的 top 和 bottom
+        y_coords = []
+        for b in bboxes:
+            y_coords.append(b[1])
+            y_coords.append(b[3])
+        
+        # 聚类并筛选有效的行网格线
+        # 阈值:5像素容差,至少对齐 2 个单元格 (防止噪点)
+        row_dividers = self._find_grid_lines(y_coords, tolerance=5, min_support=2)
+        
+        # 2. 识别列分割线 (X轴)
+        x_coords = []
+        for b in bboxes:
+            x_coords.append(b[0])
+            x_coords.append(b[2])
+        col_dividers = self._find_grid_lines(x_coords, tolerance=5, min_support=2)
+
+        # 3. 构建网格结构
+        structured_cells = []
+        
+        # 定义行区间 (Row Intervals)
+        row_intervals = []
+        for i in range(len(row_dividers) - 1):
+            row_intervals.append({
+                "top": row_dividers[i],
+                "bottom": row_dividers[i+1],
+                "height": row_dividers[i+1] - row_dividers[i],
+                "index": i
+            })
+            
+        # 定义列区间 (Col Intervals)
+        col_intervals = []
+        for i in range(len(col_dividers) - 1):
+            col_intervals.append({
+                "left": col_dividers[i],
+                "right": col_dividers[i+1],
+                "width": col_dividers[i+1] - col_dividers[i],
+                "index": i
+            })
+
+        for bbox in bboxes:
+            b_top, b_bottom = bbox[1], bbox[3]
+            b_left, b_right = bbox[0], bbox[2]
+            b_h = b_bottom - b_top
+            b_w = b_right - b_left
+
+            # --- 匹配行 (Row) ---
+            matched_rows = []
+            for r in row_intervals:
+                # 计算垂直重叠
+                inter_top = max(b_top, r["top"])
+                inter_bottom = min(b_bottom, r["bottom"])
+                inter_h = max(0, inter_bottom - inter_top)
+                
+                # 判定属于该行的条件:
+                # 1. 单元格覆盖了该行的大部分 (inter_h / r_height > 0.5) -> 适用于跨行单元格覆盖矮行
+                # 2. 该行覆盖了单元格的大部分 (inter_h / b_h > 0.5) -> 适用于单元格完全在行内
+                if r["height"] > 0 and (inter_h / r["height"] > 0.5 or inter_h / b_h > 0.5):
+                    matched_rows.append(r["index"])
+            
+            if not matched_rows:
+                # 兜底:找中心点所在的行
+                cy = (b_top + b_bottom) / 2
+                closest_r = min(row_intervals, key=lambda r: abs((r["top"]+r["bottom"])/2 - cy))
+                matched_rows = [closest_r["index"]]
+
+            row_start = min(matched_rows)
+            row_end = max(matched_rows)
+            rowspan = row_end - row_start + 1
+
+            # --- 匹配列 (Col) ---
+            matched_cols = []
+            for c in col_intervals:
+                inter_left = max(b_left, c["left"])
+                inter_right = min(b_right, c["right"])
+                inter_w = max(0, inter_right - inter_left)
+                
+                if c["width"] > 0 and (inter_w / c["width"] > 0.5 or inter_w / b_w > 0.5):
+                    matched_cols.append(c["index"])
+            
+            if not matched_cols:
+                cx = (b_left + b_right) / 2
+                closest_c = min(col_intervals, key=lambda c: abs((c["left"]+c["right"])/2 - cx))
+                matched_cols = [closest_c["index"]]
+
+            col_start = min(matched_cols)
+            col_end = max(matched_cols)
+            colspan = col_end - col_start + 1
+
+            structured_cells.append({
+                "bbox": bbox,
+                "row": row_start,
+                "col": col_start,
+                "rowspan": rowspan,
+                "colspan": colspan
+            })
+
+        # 按行列排序
+        structured_cells.sort(key=lambda c: (c["row"], c["col"]))
+        
+        # 压缩网格 (移除空行空列)
+        structured_cells = self._compress_grid(structured_cells)
+        
+        return structured_cells
+
+    def _find_grid_lines(self, coords: List[float], tolerance: float = 5.0, min_support: int = 2) -> List[float]:
+        """
+        聚类坐标点并筛选出高支持度的网格线
+        """
+        if not coords:
+            return []
+        
+        coords.sort()
+        
+        # 1. 简单聚类
+        clusters = []
+        if coords:
+            curr_cluster = [coords[0]]
+            for x in coords[1:]:
+                if x - curr_cluster[-1] < tolerance:
+                    curr_cluster.append(x)
+                else:
+                    clusters.append(curr_cluster)
+                    curr_cluster = [x]
+            clusters.append(curr_cluster)
+        
+        # 2. 计算聚类中心和支持度
+        grid_lines = []
+        for cluster in clusters:
+            # 支持度 = 该位置出现的坐标点数量
+            # 注意:这里传入的是所有box的边,所以支持度直接反映了有多少个单元格对齐到了这条线
+            if len(cluster) >= min_support:
+                center = sum(cluster) / len(cluster)
+                grid_lines.append(center)
+        
+        return grid_lines
+
+
     def _build_html_from_merged_cells(self, merged_cells: List[Dict]) -> str:
         """
         基于矩阵填充法生成 HTML,防止错位