瀏覽代碼

feat: 增强有线表格识别功能,支持自定义后处理和坐标格式转换

zhch158_admin 1 周之前
父節點
當前提交
3ed2659206
共有 1 個文件被更改,包括 912 次插入和 30 次刪除
  1. 912 30
      ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

+ 912 - 30
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -1,9 +1,14 @@
 import sys
-from typing import Any, Dict, List, Tuple, cast
+import html
+import copy
+from typing import Any, Dict, List, Tuple, Optional, cast
+import ast
 
 import cv2
 import numpy as np
 
+from loguru import logger
+
 # 确保 mineru 库可导入
 mineru_path = str((__file__ and __file__) and __file__)
 # 使用已有 mineru_adapter 中的路径追加逻辑
@@ -16,7 +21,12 @@ from mineru.model.table.rec.unet_table.main import UnetTableModel
 
 
 class MinerUWiredTableRecognizer:
-    """有线表格识别封装:裁剪+放大→UNet→坐标回写+按中心点匹配OCR文本"""
+    """有线表格识别封装:裁剪+放大→UNet→坐标回写+按中心点匹配OCR文本
+    
+    支持两种后处理模式:
+    - recognize_legacy(): 原始流程,使用MinerU的plot_html_table
+    - recognize_v4(): 改进流程,使用自定义HTML生成和文本填充(支持data-bbox属性)
+    """
 
     def __init__(self, config: Dict[str, Any], ocr_engine: Any):
         self.config = config or {}
@@ -26,9 +36,28 @@ class MinerUWiredTableRecognizer:
         self.col_threshold: int = self.config.get("col_threshold", 15)
         self.ocr_conf_threshold: float = self.config.get("ocr_conf_threshold", 0.5)
         self.cell_crop_margin: int = self.config.get("cell_crop_margin", 2)
+        # Whether recognize() uses the custom post-processing path (recognize_v4); enabled by default
+        self.use_custom_postprocess: bool = self.config.get("use_custom_postprocess", True)
         self.table_model = UnetTableModel(ocr_engine)
         self.ocr_engine = ocr_engine
 
+    # ========== 坐标格式转换工具 ==========
+    
+    @staticmethod
+    def _normalize_bbox(box: List[float]) -> List[float]:
+        """Normalize a box to ``[x_min, y_min, x_max, y_max]``.
+
+        Accepted layouts:
+        - 8 flat values ``[x1, y1, x2, y2, x3, y3, x4, y4]`` (four corner points);
+        - 4 flat values ``[x1, y1, x2, y2]`` (two opposite corners, in any order);
+        - the same data nested as point pairs, e.g. ``[[x, y], [x, y], ...]``.
+
+        Args:
+            box: Box coordinates in one of the layouts above. ``None`` and empty
+                sequences (including empty numpy arrays) are accepted.
+
+        Returns:
+            ``[x_min, y_min, x_max, y_max]``, or ``[]`` when the input is empty
+            or does not hold 4 or 8 values.
+        """
+        # Use len() rather than truthiness: ``not box`` raises for multi-element numpy arrays.
+        if box is None or len(box) == 0:
+            return []
+        # Flatten nested point pairs so [[x, y], ...] is handled like the flat form;
+        # otherwise a 4-point polygon (len == 4) would be mistaken for [x1, y1, x2, y2].
+        if hasattr(box[0], "__len__"):
+            flat = [value for point in box for value in point]
+        else:
+            flat = list(box)
+        if len(flat) == 8:
+            xs, ys = flat[0::2], flat[1::2]
+            return [min(xs), min(ys), max(xs), max(ys)]
+        if len(flat) == 4:
+            # Corner order is not guaranteed, so sort each axis explicitly.
+            x1, y1, x2, y2 = flat
+            return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
+        return []
+    
     @staticmethod
     def _to_unet_ocr_format(ocr_boxes: List[Dict[str, Any]]) -> List[List[Any]]:
         """将OCR结果转成 UNet 期望格式 [[poly4,text,score], ...],坐标用浮点。"""
@@ -57,44 +86,263 @@ class MinerUWiredTableRecognizer:
 
     @staticmethod
     def _poly_to_bbox(poly: np.ndarray) -> List[float]:
+        """Return the axis-aligned bounds [x_min, y_min, x_max, y_max] of a polygon.
+
+        Column 0 of ``poly`` is read as x and column 1 as y; results are plain floats.
+        """
         xs = poly[:, 0]
         ys = poly[:, 1]
         return [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
 
-    def _match_text_by_center(
+    # ========== 行列分组与网格计算 (修复版) ==========
+    
+    def _group_cells_into_rows(self, bboxes: List[List[float]]) -> List[List[int]]:
+        """
+        Cluster cell boxes into table rows by their vertical centres.
+
+        Cells are visited top to bottom; a cell joins the current row while its
+        centre stays within 40% of the median cell height of the row's mean
+        centre, otherwise it opens a new row.
+
+        Args:
+            bboxes: Cell boxes as [x1, y1, x2, y2].
+
+        Returns:
+            One list of indices into ``bboxes`` per row, rows ordered top to
+            bottom and each row ordered by x1. If no box has a positive height,
+            all indices are returned as a single row.
+        """
+        if not bboxes:
+            return []
+
+        y_centers = [(box[1] + box[3]) / 2 for box in bboxes]
+        # Stable sort, so cells with equal centres keep their input order.
+        order = sorted(range(len(bboxes)), key=lambda i: y_centers[i])
+
+        # Adaptive threshold: only boxes with a positive height contribute to the median.
+        heights = [box[3] - box[1] for box in bboxes if box[3] - box[1] > 0]
+        if not heights:
+            return [list(range(len(bboxes)))]
+
+        median_height = sorted(heights)[len(heights) // 2]
+        row_thresh = median_height * 0.4
+
+        logger.debug(f"行分组: median_height={median_height:.1f}, row_thresh={row_thresh:.1f}")
+
+        rows: List[List[int]] = []
+        members = [order[0]]
+        member_centers = [y_centers[order[0]]]
+        for i in order[1:]:
+            row_mean = sum(member_centers) / len(member_centers)
+            if not (abs(y_centers[i] - row_mean) <= row_thresh):
+                # Too far from the running row centre: close this row, start the next.
+                rows.append(members)
+                members, member_centers = [], []
+            members.append(i)
+            member_centers.append(y_centers[i])
+        rows.append(members)
+
+        # Left-to-right order inside each row.
+        for members in rows:
+            members.sort(key=lambda i: bboxes[i][0])
+
+        logger.info(f"行分组结果: {len(rows)} 行, 每行单元格数: {[len(r) for r in rows[:10]]}...")
+
+        return rows
+    
+    def _find_grid_index(self, value: float, edges: List[float]) -> int:
+        """
+        Map a coordinate to the index of the grid interval that contains it.
+
+        N edges delimit N-1 intervals; interval ``i`` is ``[edges[i], edges[i + 1])``.
+        Values at or before the first edge map to interval 0, values at or
+        beyond the last edge map to the last interval. With fewer than two
+        edges the result is always 0.
+        """
+        last_interval = len(edges) - 2
+        if last_interval < 0:
+            return 0
+
+        # Clamp values lying outside the covered range.
+        if value <= edges[0]:
+            return 0
+        if value >= edges[-1]:
+            return last_interval
+
+        for idx, (lower, upper) in enumerate(zip(edges, edges[1:])):
+            if lower <= value < upper:
+                return idx
+
+        # Nothing matched (e.g. a non-comparable value): fall back to the last interval.
+        return last_interval
+    
+    # ========== HTML生成与文本填充 ==========
+    
+    def _plot_html_with_bbox(
         self,
-        cells_bbox: List[List[float]],
+        bboxes: List[List[float]],
+        logic_points: List[List[int]],
+        texts: List[str],
+        row_edges: List[float],
+        col_edges: List[float],
+    ) -> str:
+        """
+        Build an HTML table whose <td> elements carry a data-bbox attribute.
+
+        Each emitted cell gets data-bbox="[x1,y1,x2,y2]" built from the
+        integer-truncated entries of its bbox.
+
+        Args:
+            bboxes: Cell boxes (per the original note: already mapped back to
+                source-image coordinates).
+            logic_points: Logical spans [[row_start, row_end, col_start, col_end], ...]
+                with inclusive end indices.
+            texts: Cell texts indexed like ``bboxes``; HTML-escaped on output.
+            row_edges: Row boundaries; only their count (N edges -> N-1 rows) is used.
+            col_edges: Column boundaries; only their count is used.
+
+        Returns:
+            The HTML string, or "" when ``bboxes`` is empty or fewer than two
+            row or column edges are given.
+        """
+        if not bboxes or len(row_edges) < 2 or len(col_edges) < 2:
+            return ""
+
+        n_rows = len(row_edges) - 1
+        n_cols = len(col_edges) - 1
+
+        # Occupancy grid: each slot holds the index of the cell anchored there,
+        # -1 when covered by a merged cell, or None when empty.
+        grid: List[List[Optional[int]]] = [[None for _ in range(n_cols)] for _ in range(n_rows)]
+
+        for idx, lp in enumerate(logic_points):
+            r0, r1, c0, c1 = lp
+            # Cells whose span falls outside the grid are skipped.
+            if r0 < 0 or c0 < 0 or r1 >= n_rows or c1 >= n_cols:
+                continue
+            for rr in range(r0, r1 + 1):
+                for cc in range(c0, c1 + 1):
+                    if rr == r0 and cc == c0:
+                        grid[rr][cc] = idx  # anchor (top-left) slot
+                    else:
+                        grid[rr][cc] = -1  # placeholder covered by the span
+
+        # Emit the HTML row by row
+        html_parts = ["<table>", "<tbody>"]
+
+        for r in range(n_rows):
+            html_parts.append("<tr>")
+            c = 0
+            while c < n_cols:
+                cell_idx = grid[r][c]
+
+                if cell_idx is None:
+                    # Empty slot: emit an empty td
+                    html_parts.append("<td></td>")
+                    c += 1
+                elif cell_idx == -1:
+                    # Covered by a merged cell: emit nothing
+                    c += 1
+                else:
+                    # Anchor slot: emit the td, with span attributes when needed
+                    lp = logic_points[cell_idx]
+                    rowspan = lp[1] - lp[0] + 1
+                    colspan = lp[3] - lp[2] + 1
+                    bbox = bboxes[cell_idx]
+                    text = html.escape(texts[cell_idx]) if cell_idx < len(texts) else ""
+
+                    bbox_str = f"[{int(bbox[0])},{int(bbox[1])},{int(bbox[2])},{int(bbox[3])}]"
+
+                    if rowspan > 1 or colspan > 1:
+                        html_parts.append(
+                            f'<td data-bbox="{bbox_str}" rowspan="{rowspan}" colspan="{colspan}">{text}</td>'
+                        )
+                    else:
+                        html_parts.append(f'<td data-bbox="{bbox_str}">{text}</td>')
+
+                    # Jump past the columns this cell spans
+                    c += colspan
+
+            html_parts.append("</tr>")
+
+        html_parts.append("</tbody>")
+        html_parts.append("</table>")
+
+        return "".join(html_parts)
+    
+    def _fill_text_by_center_point(
+        self,
+        bboxes: List[List[float]],
         ocr_boxes: List[Dict[str, Any]],
     ) -> List[str]:
-        """使用中心点落格分配文本,行内按 y 排序后拼接。"""
-        texts_per_cell: List[str] = []
-        centers = []
+        """
+        Assign OCR text to cells by testing where each OCR box centre falls.
+
+        Modelled on fill_html_with_ocr_by_bbox (not shown in this file):
+        - an OCR item matches a cell when its centre lies inside the cell bbox
+          (edges inclusive);
+        - several matches in one cell are ordered by their top y and joined
+          with a single space.
+
+        Args:
+            bboxes: Cell boxes [[x1,y1,x2,y2], ...]
+            ocr_boxes: OCR results [{"bbox": [...], "text": "..."}, ...];
+                items whose bbox cannot be normalized are ignored.
+
+        Returns:
+            One text per cell, "" for cells without a match.
+        """
+        texts: List[str] = ["" for _ in bboxes]
+
+        if not ocr_boxes:
+            return texts
+
+        # Pre-compute the centre point of every OCR box
+        ocr_items: List[Dict[str, Any]] = []
         for item in ocr_boxes:
-            poly = item.get("bbox", [])
-            if not poly:
+            box = self._normalize_bbox(item.get("bbox", []))
+            if not box:
                 continue
-            if len(poly) == 8:
-                xs = [poly[i] for i in range(0, 8, 2)]
-                ys = [poly[i] for i in range(1, 8, 2)]
-                cx = (min(xs) + max(xs)) / 2
-                cy = (min(ys) + max(ys)) / 2
-            elif len(poly) == 4:
-                x1, y1, x2, y2 = poly
-                cx = (x1 + x2) / 2
-                cy = (y1 + y2) / 2
-            else:
-                continue
-            centers.append((cx, cy, item.get("text", ""), item.get("confidence", 0.0)))
-
-        for bbox in cells_bbox:
+            cx = (box[0] + box[2]) / 2
+            cy = (box[1] + box[3]) / 2
+            ocr_items.append({
+                "center_x": cx,
+                "center_y": cy,
+                "y1": box[1],
+                "text": item.get("text", ""),
+                "confidence": item.get("confidence", 0.0),
+            })
+
+        # Collect the OCR items whose centre falls inside each cell
+        for idx, bbox in enumerate(bboxes):
             x1, y1, x2, y2 = bbox
-            collected = [(t, cy) for cx, cy, t, conf in centers if x1 <= cx <= x2 and y1 <= cy <= y2]
-            collected.sort(key=lambda x: x[1])
-            cell_text = " ".join([t for t, _ in collected]) if collected else ""
-            texts_per_cell.append(cell_text)
-        return texts_per_cell
-
-    def recognize(
+            matched: List[Tuple[str, float]] = []
+
+            for ocr in ocr_items:
+                if x1 <= ocr["center_x"] <= x2 and y1 <= ocr["center_y"] <= y2:
+                    matched.append((ocr["text"], ocr["y1"]))
+
+            if matched:
+                # Sort by top y so multi-line text keeps its reading order
+                matched.sort(key=lambda x: x[1])
+                texts[idx] = " ".join([t for t, _ in matched])
+
+        return texts
+    
+    def _match_text_by_center(
+        self,
+        cells_bbox: List[List[float]],
+        ocr_boxes: List[Dict[str, Any]],
+    ) -> List[str]:
+        """Legacy-compatible name: delegates to ``_fill_text_by_center_point``."""
+        texts_per_cell = self._fill_text_by_center_point(
+            bboxes=cells_bbox,
+            ocr_boxes=ocr_boxes,
+        )
+        return texts_per_cell
+    
+    
+    def recognize_legacy(
         self,
         table_image: np.ndarray,
         ocr_boxes: List[Dict[str, Any]],
@@ -272,3 +520,637 @@ class MinerUWiredTableRecognizer:
                 col_idx += colspan
         
         return str(soup)
+    
+    # ========== 基于表格线交点的单元格计算 ==========
+    def _compute_cells_from_lines(
+        self,
+        hpred_up: np.ndarray,
+        vpred_up: np.ndarray,
+        upscale: float = 1.0,
+        debug_output_dir: Optional[str] = None
+    ) -> List[List[float]]:
+        """
+        Extract cell boxes from the table-line masks via connected components.
+
+        Both masks are binarized at 127, dilated along their own direction,
+        OR-ed into one grid mask and inverted; every remaining white component
+        is then taken as the blank interior of one cell.
+
+        Args:
+            hpred_up: Horizontal-line mask; values above 127 count as line.
+            vpred_up: Vertical-line mask, same convention.
+                NOTE(review): presumably the same size as ``hpred_up`` -
+                the image size is read from ``hpred_up`` only.
+            upscale: Every box coordinate is divided by this factor before
+                being returned.
+            debug_output_dir: When set, ``connected_components.png`` is
+                written into this directory (created if missing).
+
+        Returns:
+            Boxes [x1, y1, x2, y2], sorted by 10-unit row bucket of y1 and then
+            by x1. Components smaller than 50 px, or wider/taller than 95% of
+            the mask, are dropped.
+        """
+        import os  # local import: only needed for the optional debug output
+
+        h, w = hpred_up.shape[:2]
+
+        # 1. Binarize both masks.
+        _, h_bin = cv2.threshold(hpred_up, 127, 255, cv2.THRESH_BINARY)
+        _, v_bin = cv2.threshold(vpred_up, 127, 255, cv2.THRESH_BINARY)
+
+        # 2. Close small gaps: stretch horizontal lines sideways, vertical lines up/down.
+        kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 1))
+        kernel_v = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 5))
+        h_bin = cv2.dilate(h_bin, kernel_h, iterations=1)
+        v_bin = cv2.dilate(v_bin, kernel_v, iterations=1)
+
+        # 3. Combined grid (white lines on black), then 4. inverted so that
+        #    cell interiors become the white connected components.
+        grid_mask = cv2.bitwise_or(h_bin, v_bin)
+        inv_grid = cv2.bitwise_not(grid_mask)
+
+        # 5. Label the components; only the per-component stats are needed.
+        num_labels, _, stats, _ = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
+
+        min_area = 50  # components below this area are treated as noise
+        bboxes = []
+        # Label 0 is skipped.
+        for label in range(1, num_labels):
+            if stats[label, cv2.CC_STAT_AREA] < min_area:
+                continue
+
+            x = stats[label, cv2.CC_STAT_LEFT]
+            y = stats[label, cv2.CC_STAT_TOP]
+            w_cell = stats[label, cv2.CC_STAT_WIDTH]
+            h_cell = stats[label, cv2.CC_STAT_HEIGHT]
+
+            # Drop regions spanning (almost) the whole mask, e.g. blank margins.
+            if w_cell > w * 0.95 or h_cell > h * 0.95:
+                continue
+
+            # Scale back by ``upscale``. The component is the blank interior, so
+            # the box excludes the line thickness; the original author judged
+            # this sufficient for OCR matching.
+            bboxes.append([
+                x / upscale,
+                y / upscale,
+                (x + w_cell) / upscale,
+                (y + h_cell) / upscale
+            ])
+
+        # Reading order: top to bottom in 10-unit buckets (tolerates small
+        # vertical jitter within a row), then left to right.
+        bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
+
+        logger.info(f"连通域分析提取到 {len(bboxes)} 个单元格")
+
+        # Optional debug rendering.
+        if debug_output_dir:
+            # cv2.imwrite does not create directories; without this the debug
+            # image is silently lost when the directory does not exist yet.
+            os.makedirs(debug_output_dir, exist_ok=True)
+
+            vis = np.zeros((h, w, 3), dtype=np.uint8)
+            vis[grid_mask > 0] = [0, 0, 255]  # lines in red (BGR)
+
+            # Boxes are drawn in mask coordinates, hence the re-scaling.
+            for box in bboxes:
+                x1, y1, x2, y2 = [int(c * upscale) for c in box]
+                cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+
+            cv2.imwrite(f"{debug_output_dir}/connected_components.png", vis)
+
+        return bboxes
+    
+    def _visualize_detected_lines(
+        self,
+        hpred: np.ndarray,
+        vpred: np.ndarray,
+        h_lines_y: List[int],
+        v_lines_x: List[int],
+        output_path: str
+    ):
+        """
+        Render the line masks and the detected line positions to an image file.
+
+        Mask pixels above 128 are tinted as a background layer; each detected
+        horizontal/vertical position inside the image is drawn as a full-length
+        line on top, followed by a caption with the line counts.
+        """
+        height, width = hpred.shape[:2]
+        canvas = np.zeros((height, width, 3), dtype=np.uint8)
+
+        # Background layer: raw masks in pale colours (BGR).
+        canvas[hpred > 128] = [100, 100, 255]
+        canvas[vpred > 128] = [255, 100, 100]
+
+        # Detected horizontal lines in green, vertical lines in yellow.
+        green, yellow = (0, 255, 0), (0, 255, 255)
+        for y in (pos for pos in h_lines_y if 0 <= pos < height):
+            cv2.line(canvas, (0, y), (width, y), green, 2)
+        for x in (pos for pos in v_lines_x if 0 <= pos < width):
+            cv2.line(canvas, (x, 0), (x, height), yellow, 2)
+
+        caption = f"H-lines: {len(h_lines_y)}, V-lines: {len(v_lines_x)}"
+        cv2.putText(canvas, caption, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
+
+        cv2.imwrite(output_path, canvas)
+        logger.info(f"检测线可视化: {output_path}")
+    
+    
+    def _recover_grid_structure(
+        self,
+        bboxes: List[List[float]],
+        cluster_threshold: float = 10,
+    ) -> List[Dict]:
+        """
+        Recover the logical grid (row, col, rowspan, colspan) from loose cell boxes.
+
+        All top/bottom coordinates are clustered into row dividers and all
+        left/right coordinates into column dividers; each box edge is then
+        snapped to its nearest divider.
+
+        Args:
+            bboxes: Cell boxes as [x1, y1, x2, y2].
+            cluster_threshold: Sorted coordinates closer than this to their
+                predecessor are merged into one divider (their mean).
+
+        Returns:
+            One dict per box with keys ``bbox``, ``row``, ``col``, ``rowspan``
+            and ``colspan``, sorted by (row, col). Empty input gives ``[]``.
+        """
+        if not bboxes:
+            return []
+
+        def cluster(values):
+            """Merge sorted coordinates into divider positions (chained gaps below the threshold)."""
+            ordered = sorted(values)
+            dividers = []
+            group = [ordered[0]]
+            for value in ordered[1:]:
+                if value - group[-1] < cluster_threshold:
+                    group.append(value)
+                else:
+                    dividers.append(sum(group) / len(group))
+                    group = [value]
+            dividers.append(sum(group) / len(group))
+            return dividers
+
+        def nearest(dividers, value):
+            """Index of the divider closest to value (first one wins on ties)."""
+            return min(range(len(dividers)), key=lambda i: abs(dividers[i] - value))
+
+        # 1./2. Row dividers from every top/bottom, column dividers from every left/right.
+        row_dividers = cluster(c for b in bboxes for c in (b[1], b[3]))
+        col_dividers = cluster(c for b in bboxes for c in (b[0], b[2]))
+
+        # 3. Snap every box to the divider grid.
+        structured_cells = []
+        for bbox in bboxes:
+            x1, y1, x2, y2 = bbox
+
+            r1, r2 = nearest(row_dividers, y1), nearest(row_dividers, y2)
+            c1, c2 = nearest(col_dividers, x1), nearest(col_dividers, x2)
+
+            # Both edges snapped to the same divider: force a span of 1.
+            if r1 == r2:
+                r2 = r1 + 1
+            if c1 == c2:
+                c2 = c1 + 1
+
+            # Keep start before end.
+            if r1 > r2:
+                r1, r2 = r2, r1
+            if c1 > c2:
+                c1, c2 = c2, c1
+
+            structured_cells.append({
+                "bbox": bbox,
+                "row": r1,
+                "col": c1,
+                "rowspan": r2 - r1,
+                "colspan": c2 - c1
+            })
+
+        structured_cells.sort(key=lambda c: (c["row"], c["col"]))
+
+        return structured_cells
+
+    def _build_html_from_merged_cells(self, merged_cells: List[Dict]) -> str:
+        """
+        Render cells to HTML with an occupancy matrix so columns cannot shift.
+
+        Each cell dict needs ``row``, ``col`` and ``bbox``; ``rowspan``,
+        ``colspan`` (default 1) and ``text`` (default "") are optional. Slots
+        that no cell covers are emitted as empty ``<td>`` placeholders.
+        """
+        if not merged_cells:
+            return "<table><tbody></tbody></table>"
+
+        # Grid extent: the furthest row/column reached by any span.
+        n_rows = n_cols = 0
+        for cell in merged_cells:
+            n_rows = max(n_rows, cell["row"] + cell.get("rowspan", 1))
+            n_cols = max(n_cols, cell["col"] + cell.get("colspan", 1))
+
+        # taken[r][c] is True once the slot is covered by an emitted td.
+        taken = [[False] * n_cols for _ in range(n_rows)]
+
+        # Cells keyed by their anchor slot; a later duplicate replaces an earlier one.
+        origin_lookup = {(cell["row"], cell["col"]): cell for cell in merged_cells}
+
+        pieces = ["<table><tbody>"]
+        for r in range(n_rows):
+            pieces.append("<tr>")
+            for c in range(n_cols):
+                if taken[r][c]:
+                    # Already covered by an earlier rowspan/colspan.
+                    continue
+
+                cell = origin_lookup.get((r, c))
+                if not cell:
+                    # Hole (alignment gap or missed cell): placeholder keeps
+                    # the following cells in their columns.
+                    pieces.append("<td></td>")
+                    taken[r][c] = True
+                    continue
+
+                bbox = cell["bbox"]
+                colspan = cell.get("colspan", 1)
+                rowspan = cell.get("rowspan", 1)
+                text = html.escape(cell.get("text", ""))
+                bbox_str = f"[{int(bbox[0])},{int(bbox[1])},{int(bbox[2])},{int(bbox[3])}]"
+
+                attrs = [f'data-bbox="{bbox_str}"']
+                if colspan > 1:
+                    attrs.append(f'colspan="{colspan}"')
+                if rowspan > 1:
+                    attrs.append(f'rowspan="{rowspan}"')
+                pieces.append(f'<td {" ".join(attrs)}>{text}</td>')
+
+                # Mark the whole span as covered, clipped to the grid.
+                for rr in range(r, min(r + rowspan, n_rows)):
+                    for cc in range(c, min(c + colspan, n_cols)):
+                        taken[rr][cc] = True
+
+            pieces.append("</tr>")
+
+        pieces.append("</tbody></table>")
+        return "".join(pieces)
+
+    def _visualize_grid_structure(
+        self,
+        table_image: np.ndarray,
+        cells: List[Dict],
+        output_path: str
+    ):
+        """Draw every cell with its logical position (row, col, spans) and save the image."""
+        canvas = table_image.copy()
+        if len(canvas.shape) == 2:
+            canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
+
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        font_scale = 0.5
+        thickness = 1
+
+        for cell in cells:
+            x1, y1, x2, y2 = [int(c) for c in cell["bbox"]]
+            cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
+
+            # Label such as "R2C3 rs2 cs3"; spans are shown only when above 1.
+            label_parts = [f"R{cell['row']}C{cell['col']}"]
+            if cell.get('rowspan', 1) > 1:
+                label_parts.append(f" rs{cell['rowspan']}")
+            if cell.get('colspan', 1) > 1:
+                label_parts.append(f" cs{cell['colspan']}")
+            info = "".join(label_parts)
+
+            # Centre the label inside the cell.
+            (tw, th), _ = cv2.getTextSize(info, font, font_scale, thickness)
+            origin = (x1 + (x2 - x1 - tw) // 2, y1 + (y2 - y1 + th) // 2)
+
+            # Thick black pass first, then the coloured text on top, for readability.
+            cv2.putText(canvas, info, origin, font, font_scale, (0, 0, 0), thickness + 2)
+            cv2.putText(canvas, info, origin, font, font_scale, (0, 255, 255), thickness)
+
+        cv2.imwrite(output_path, canvas)
+        logger.info(f"表格结构可视化: {output_path}")
+
+    def recognize_v4(
+        self,
+        table_image: np.ndarray,
+        ocr_boxes: List[Dict[str, Any]],
+        debug_output_dir: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        V4 pipeline: derive the cells directly from the predicted table lines
+        (per the original note, bypassing MinerU's cal_region_boxes).
+
+        Steps: run the UNet line model on the (optionally upscaled) image,
+        extract cells by connected components, recover the row/column grid,
+        fill texts by OCR centre point, re-OCR the still-empty cells, and
+        build the HTML.
+
+        Args:
+            table_image: Table image; its first two shape entries are read as
+                height and width.
+            ocr_boxes: OCR results as dicts with "bbox" and "text"; None is
+                treated as an empty list.
+            debug_output_dir: When set, intermediate visualizations are written
+                under this directory.
+
+        Returns:
+            {"html": <table HTML>, "cells": [<dict per cell with bbox, row,
+            col, rowspan, colspan, text, matched_text, score>]}.
+
+        Raises:
+            RuntimeError: If no cell could be extracted from the line masks.
+        """
+        upscale = self.upscale_ratio if self.upscale_ratio and self.upscale_ratio > 0 else 1.0
+        h, w = table_image.shape[:2]
+
+        # Step 1: get the horizontal/vertical line masks predicted by the UNet
+        if upscale != 1.0:
+            img_up = cv2.resize(table_image, (int(w * upscale), int(h * upscale)))
+        else:
+            img_up = table_image
+
+        wired_rec = self.table_model.wired_table_model
+        img = wired_rec.load_img(img_up)
+        img_info = wired_rec.table_structure.preprocess(img)
+        pred = wired_rec.table_structure.infer(img_info)
+
+        # Label 1 is used as the horizontal-line mask, label 2 as the vertical one.
+        hpred = np.where(pred == 1, 255, 0).astype(np.uint8)
+        vpred = np.where(pred == 2, 255, 0).astype(np.uint8)
+
+        # Bring both masks to the size of the upscaled image.
+        h_up, w_up = img_up.shape[:2]
+        hpred_up = cv2.resize(hpred, (w_up, h_up), interpolation=cv2.INTER_NEAREST)
+        vpred_up = cv2.resize(vpred, (w_up, h_up), interpolation=cv2.INTER_NEAREST)
+
+        # Step 1.5: visualize the table lines (debug only) - masks are scaled back to the source size
+        if debug_output_dir:
+            hpred_orig = cv2.resize(hpred_up, (w, h), interpolation=cv2.INTER_NEAREST)
+            vpred_orig = cv2.resize(vpred_up, (w, h), interpolation=cv2.INTER_NEAREST)
+            self._visualize_table_lines(
+                table_image,
+                hpred_orig,
+                vpred_orig,
+                output_path=f"{debug_output_dir}/unet_table_lines.png"
+            )
+
+        # Step 2: extract cells with the connected-component method (replaces the former projection method)
+        bboxes = self._compute_cells_from_lines(hpred_up, vpred_up, upscale, debug_output_dir)
+
+        if not bboxes:
+            raise RuntimeError("未能提取出单元格")
+
+        # Step 3: rebuild the grid structure (row, col, rowspan, colspan)
+        # Per the original note, this replaces the former _merge_cells_without_separator
+        merged_cells = self._recover_grid_structure(bboxes)
+
+        # Step 3.5: visualize the logical structure (debug only)
+        if debug_output_dir:
+            self._visualize_grid_structure(
+                table_image, merged_cells,
+                output_path=f"{debug_output_dir}/grid_structure.png"
+            )
+
+        # Step 4: fill the texts for all cells in one pass
+        bboxes_merged = [cell["bbox"] for cell in merged_cells]
+        texts = self._fill_text_by_center_point(bboxes_merged, ocr_boxes or [])
+
+        # Step 4.5: second OCR pass on cells that are still empty
+        # Targets missed detections (notably small vertical text) by cropping and enlarging each cell
+        if hasattr(self, 'ocr_engine') and self.ocr_engine and any(not t for t in texts):
+            crop_list = []
+            crop_indices = []
+            h_img, w_img = table_image.shape[:2]
+            margin = self.cell_crop_margin
+
+            for i, text in enumerate(texts):
+                # Cells that already have non-blank text are left alone.
+                if text.strip():
+                    continue
+
+                bbox = bboxes_merged[i]
+                x1, y1, x2, y2 = map(int, bbox)
+
+                # Expand by a small margin, clamped to the image bounds
+                x1 = max(0, x1 - margin)
+                y1 = max(0, y1 - margin)
+                x2 = min(w_img, x2 + margin)
+                y2 = min(h_img, y2 + margin)
+
+                if x2 <= x1 or y2 <= y1:
+                    continue
+
+                cell_img = table_image[y1:y2, x1:x2]
+                if cell_img.size == 0:
+                    continue
+
+                # --- Key improvement: enlarge and rotate ---
+                # Note: cell_h / cell_w are the crop size before any resizing below.
+                cell_h, cell_w = cell_img.shape[:2]
+
+                # 1. Enlarge: the original author reports that enlarging small
+                # text raises the recognition rate; 2x is the chosen factor
+                scale = 2.0
+                if cell_h < 64 or cell_w < 64: # enlarge small crops only, so large crops do not grow further
+                     cell_img = cv2.resize(cell_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+
+                # 2. Vertical text: a crop more than twice as tall as wide is assumed
+                # to be a vertical header, so it is rotated 90 degrees to read horizontally
+                # (the original note says general OCR models usually handle horizontal text only)
+                if cell_h > cell_w * 2:
+                    cell_img = cv2.rotate(cell_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                # -------------------------
+
+                crop_list.append(cell_img)
+                crop_indices.append(i)
+
+            if crop_list:
+                try:
+                    # Batch recognition; det=False skips detection (each crop is assumed to be one text line)
+                    ocr_res = self.ocr_engine.ocr(crop_list, det=False)
+
+                    # Parse the result (PaddleOCR-compatible formats, per the original note)
+                    # Usually [(text, score), (text, score), ...], one entry per crop,
+                    # but the list is sometimes wrapped in an extra list
+                    results = ocr_res
+                    if isinstance(ocr_res, list) and len(ocr_res) == 1 and isinstance(ocr_res[0], list) and len(ocr_res[0]) == len(crop_list):
+                         # Unwrap the [[(t,s), (t,s)...]] shape seen in the legacy code
+                         results = ocr_res[0]
+
+                    # Results are used only when there is exactly one per crop.
+                    if len(results) == len(crop_list):
+                        for idx, res in enumerate(results):
+                            # res may be (text, score), [(text, score)] or None
+                            if not res: continue
+
+                            text = ""
+                            score = 0.0
+
+                            if isinstance(res, tuple):
+                                text, score = res
+                            elif isinstance(res, list) and len(res) > 0:
+                                text, score = res[0]
+
+                            # Accept only non-empty text at or above the confidence threshold.
+                            if score >= self.ocr_conf_threshold and text:
+                                texts[crop_indices[idx]] = text
+
+                except Exception as e:
+                    # Best effort: a failed second pass leaves the cells empty.
+                    logger.warning(f"二次OCR失败: {e}")
+
+        # Write the texts into merged_cells
+        for i, cell in enumerate(merged_cells):
+            cell["text"] = texts[i] if i < len(texts) else ""
+
+        # Step 5: build the HTML with text and colspan/rowspan
+        html_filled = self._build_html_from_merged_cells(merged_cells)
+
+        # Step 6: visualize the text filling (debug only)
+        if debug_output_dir:
+            self._visualize_with_text(
+                table_image, bboxes_merged, texts,
+                output_path=f"{debug_output_dir}/text_filled_v4.png"
+            )
+
+        # Step 7: assemble the cells output
+        cells = []
+        for idx, cell in enumerate(merged_cells):
+            cells.append({
+                "bbox": cell["bbox"],
+                "row": cell.get("row", 0),
+                "col": cell.get("col", 0),
+                "rowspan": cell.get("rowspan", 1),
+                "colspan": cell.get("colspan", 1),
+                "text": cell["text"],
+                "matched_text": cell["text"],
+                "score": 100.0,  # constant placeholder; not derived from OCR confidence
+            })
+
+        return {
+            "html": html_filled,
+            "cells": cells,
+        }
+    
+
+    # ========== 调试可视化 ==========
+    def _visualize_table_lines(
+        self,
+        table_image: np.ndarray,
+        hpred: np.ndarray,
+        vpred: np.ndarray,
+        output_path: str
+    ) -> np.ndarray:
+        """
+        Overlay the detected table lines on a copy of the image and save it.
+
+        Args:
+            table_image: Source image (grayscale input is converted to BGR).
+            hpred: Horizontal-line mask, used as a pixel mask on the image.
+            vpred: Vertical-line mask, used the same way.
+            output_path: Where the overlay is written.
+
+        Returns:
+            The annotated image.
+        """
+        overlay = table_image.copy()
+        if len(overlay.shape) == 2:
+            overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
+
+        # Horizontal lines in red, then vertical lines in blue (BGR order);
+        # where both masks are set the vertical colour wins.
+        for mask, bgr in ((hpred, [0, 0, 255]), (vpred, [255, 0, 0])):
+            overlay[mask > 128] = bgr
+
+        cv2.imwrite(output_path, overlay)
+        logger.info(f"表格线可视化: {output_path}")
+
+        return overlay
+
+    def _visualize_table_structure(
+        self,
+        image: np.ndarray,
+        bboxes: List[List[float]],
+        output_path: Optional[str] = None,
+        title: str = "Table Structure"
+    ) -> np.ndarray:
+        """
+        Draw the detected cells, each in a random colour with its index.
+
+        Args:
+            image: Source image (grayscale input is converted to BGR).
+            bboxes: Cell boxes [[x1,y1,x2,y2], ...].
+            output_path: Optional file to save the result to.
+            title: Caption prefix drawn in the top-left corner.
+
+        Returns:
+            The annotated image.
+        """
+        import random
+
+        canvas = image.copy()
+        if len(canvas.shape) == 2:
+            canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
+
+        # One random colour per cell, all drawn before anything is rendered.
+        palette = [
+            (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
+            for _ in bboxes
+        ]
+
+        for idx, (bbox, color) in enumerate(zip(bboxes, palette)):
+            x1, y1, x2, y2 = map(int, bbox)
+            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
+            # Cell index just inside the top-left corner.
+            cv2.putText(
+                canvas, str(idx),
+                (x1 + 2, y1 + 15),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1
+            )
+
+        caption = f"{title} ({len(bboxes)} cells)"
+        cv2.putText(canvas, caption, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
+
+        if output_path:
+            cv2.imwrite(output_path, canvas)
+            logger.info(f"表格结构可视化已保存: {output_path}")
+
+        return canvas
+    
+    def _visualize_with_text(
+        self,
+        image: np.ndarray,
+        bboxes: List[List[float]],
+        texts: List[str],
+        output_path: Optional[str] = None
+    ) -> np.ndarray:
+        """
+        Draw each cell with a preview of its text.
+
+        Cells with text get a green frame, empty cells a red one. Returns the
+        annotated image and saves it when ``output_path`` is given.
+        """
+        canvas = image.copy()
+        if len(canvas.shape) == 2:
+            canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
+
+        for bbox, text in zip(bboxes, texts):
+            x1, y1, x2, y2 = map(int, bbox)
+
+            frame_color = (0, 255, 0) if text else (0, 0, 255)
+            cv2.rectangle(canvas, (x1, y1), (x2, y2), frame_color, 2)
+
+            # Preview: at most 10 characters, with an ellipsis when truncated.
+            if len(text) > 10:
+                preview = text[:10] + "..."
+            else:
+                preview = text
+            if not preview:
+                continue
+            cv2.putText(
+                canvas, preview,
+                (x1 + 2, y1 + 15),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 0, 0), 1
+            )
+
+        if output_path:
+            cv2.imwrite(output_path, canvas)
+            logger.info(f"文本填充可视化已保存: {output_path}")
+
+        return canvas
+    
+    def recognize(
+        self,
+        table_image: np.ndarray,
+        ocr_boxes: List[Dict[str, Any]],
+    ) -> Dict[str, Any]:
+        """
+        Unified entry point: dispatch to recognize_v4() or recognize_legacy().
+
+        Config key ``use_custom_postprocess``:
+        - True: run recognize_v4() (custom post-processing); if it raises, the
+          failure is logged and recognize_legacy() is used instead.
+        - False: run recognize_legacy() (original flow).
+
+        Config key ``debug_output_dir`` (default "./output") is the directory
+        handed to recognize_v4() for its debug images.
+
+        Args:
+            table_image: Table image passed through to the selected recognizer.
+            ocr_boxes: OCR results passed through to the selected recognizer.
+
+        Returns:
+            The result dict of whichever recognizer produced the output.
+        """
+        if not self.use_custom_postprocess:
+            return self.recognize_legacy(table_image, ocr_boxes)
+
+        # Tolerate instances without a config dict; keep the historical default directory.
+        config = getattr(self, "config", None) or {}
+        debug_output_dir = config.get("debug_output_dir", "./output")
+
+        try:
+            return self.recognize_v4(table_image, ocr_boxes, debug_output_dir=debug_output_dir)
+        except Exception as exc:
+            # Deliberate best-effort fallback, but never a silent one.
+            logger.warning(f"recognize_v4 failed, falling back to recognize_legacy: {exc}")
+            return self.recognize_legacy(table_image, ocr_boxes)