6 Commits 3a5b2ab300 ... 652b321bd6

Autore SHA1 Messaggio Data
  zhch158_admin 652b321bd6 feat: Update batch processing in main_v2.py to include output directory parameter for document processing, enhancing flexibility in file management. 2 giorni fa
  zhch158_admin 1bb438fba3 fix: Improve coordinate transformation accuracy in WiredTableVisualizer to reduce cumulative errors and enhance debugging with detailed logging of cell coordinates during visualization. 2 giorni fa
  zhch158_admin c628acd7b7 feat: Enhance text filling strategy in TextFiller class by introducing overlap ratio calculation for improved OCR box matching, optimizing performance with binary search, and refining text extraction logic. 2 giorni fa
  zhch158_admin 60aa86e4bf feat: Enhance GridRecovery class by adding optional parameters for original image dimensions and improving scaling logic for cell extraction, along with detailed debug logging for better traceability. 2 giorni fa
  zhch158_admin 26b500f344 feat: Add documentation for grid recovery scenarios in UNet, detailing causes of empty rows/columns and how `compress_grid` addresses them. 2 giorni fa
  zhch158_admin 6b063ced58 fix: Enhance UNet preprocessing and resizing logic in MinerUWiredTableRecognizer to ensure consistent prediction dimensions and improve debugging information 2 giorni fa

+ 79 - 0
docs/ocr_tools/universal_doc_parser/unet表格识别-grid_recovery.md

@@ -0,0 +1,79 @@
+
+## 产生空行/空列的情况
+> 针对 `ocr_platform/ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py:466` 的 `recover_grid_structure`:什么情况会产生空行或空列?
+
+### 1. 网格线检测产生的额外分割线
+```352:380:ocr_platform/ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py
+    def recover_grid_structure(bboxes: List[List[float]]) -> List[Dict]:
+        # ...
+        # 1. 识别行分割线 (Y轴)
+        y_coords = []
+        for b in bboxes:
+            y_coords.append(b[1])  # top
+            y_coords.append(b[3])  # bottom
+        
+        row_dividers = GridRecovery.find_grid_lines(y_coords, tolerance=5, min_support=2)
+```
+
+- 问题:`find_grid_lines` 收集所有单元格的 top/bottom(或 left/right)坐标,聚类后生成网格线。
+- 如果某些位置有 ≥2 个坐标对齐(满足 `min_support=2`),就会产生一条网格线。
+- 结果:可能产生比实际行/列更多的网格线,从而产生空的行区间或列区间。
+
+### 2. 单元格匹配时未覆盖某些行/列区间
+```411:428:ocr_platform/ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py
+            # 匹配行
+            matched_rows = []
+            for r in row_intervals:
+                inter_top = max(b_top, r["top"])
+                inter_bottom = min(b_bottom, r["bottom"])
+                inter_h = max(0, inter_bottom - inter_top)
+                
+                if r["height"] > 0 and (inter_h / r["height"] > 0.5 or inter_h / b_h > 0.5):
+                    matched_rows.append(r["index"])
+```
+
+- 问题:匹配条件为 `inter_h / r["height"] > 0.5` 或 `inter_h / b_h > 0.5`。
+- 如果某个行/列区间很小,且没有单元格满足该条件,该区间就不会被匹配。
+- 结果:该行/列区间没有单元格的起始位置,成为空行/空列。
+
+### 3. rowspan/colspan 跨越的间隙行/列
+```500:504:ocr_platform/ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py
+        for cell in cells:
+            if cell["row"] < max_row:
+                row_occupied[cell["row"]] = True  # 只标记起始行
+            if cell["col"] < max_col:
+                col_occupied[cell["col"]] = True  # 只标记起始列
+```
+
+- 问题:`compress_grid` 只标记单元格的起始行/列(`cell["row"]` 和 `cell["col"]`)。
+- 如果一个单元格跨越多个行/列(rowspan/colspan > 1),中间的行/列如果没有其他单元格起始,就不会被标记为占用。
+- 结果:这些中间行/列成为空行/空列。
+
+### 4. 表格边缘的额外网格线
+- 如果表格边缘的单元格边界对齐形成网格线,但该行/列区间内没有单元格,也会产生空行/空列。
+
+## 示例场景
+
+假设有以下单元格:
+- Cell A: row=0, rowspan=3(跨越行 0, 1, 2)
+- Cell B: row=3, rowspan=1(只在行 3)
+
+如果网格线检测产生了行 0, 1, 2, 3, 4 的分割线:
+- 行 0: 被 Cell A 占用(起始)
+- 行 1: 被 Cell A 跨越,但没有单元格起始 → 空行
+- 行 2: 被 Cell A 跨越,但没有单元格起始 → 空行
+- 行 3: 被 Cell B 占用(起始)
+- 行 4: 没有单元格 → 空行
+
+`compress_grid` 会移除这些空行,最终得到紧凑的网格。
+
+## 总结
+
+空行/空列主要来自:
+1. 网格线检测产生的额外分割线
+2. 单元格匹配条件未覆盖某些区间
+3. rowspan/colspan 跨越的中间行/列
+4. 表格边缘的额外网格线
+
+`compress_grid` 会移除这些空行/空列,确保最终网格紧凑。

+ 12 - 4
ocr_tools/universal_doc_parser/main_v2.py

@@ -169,7 +169,12 @@ def process_single_input(
                 }
             else:
                 # 批量处理模式(原有逻辑)
-                results = pipeline.process_document(str(input_path), page_range=page_range)
+                # 新增:传入 output_dir 参数,使文档处理结果可直接写入指定输出目录
+                results = pipeline.process_document(
+                    str(input_path), 
+                    page_range=page_range,
+                    output_dir=str(output_dir)
+                )
                 process_time = (datetime.now() - start_time).total_seconds()
                 
                 logger.info(f"⏱️ 处理耗时: {process_time:.2f}秒")
@@ -402,8 +407,8 @@ if __name__ == "__main__":
             # "input": "/Users/zhch158/workspace/data/流水分析/康强_北京农村商业银行.pdf",
             # "output_dir": "./output/康强_北京农村商业银行_bank_statement_v2",
 
-            "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/A用户_单元格扫描流水_page_002.png",
-            "output_dir": "./output/A用户_单元格扫描流水_bank_statement_wired_unet",
+            # "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/A用户_单元格扫描流水_page_002.png",
+            # "output_dir": "./output/A用户_单元格扫描流水_bank_statement_wired_unet",
             
             # "input": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水.pdf",
             # "output_dir": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/bank_statement_yusys_v2",
@@ -411,12 +416,15 @@ if __name__ == "__main__":
             # "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/2023年度报告母公司_page_005.png",
             # "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/2023年度报告母公司_page_003_270.png",
             # "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/2023年度报告母公司_page_003_270_skew(-0.4).png",
+            # "input": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司.pdf",
             # "output_dir": "./output/2023年度报告母公司/bank_statement_wired_unet",
 
             # "input": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司.pdf",
-            # "output_dir": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司/bank_statement_wired_unet",
             # "output_dir": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司/bank_statement_yusys_v2",
 
+            "input": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.pdf",
+            "output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/bank_statement_wired_unet",
+
             # "input": "/Users/zhch158/workspace/repository.git/ocr_platform/ocr_tools/universal_doc_parser/tests/600916_中国黄金_2022年报_page_096.png",
             # "output_dir": "./output/600916_中国黄金_2022年报/bank_statement_wired_unet",
             # "input": "/Users/zhch158/workspace/data/流水分析/600916_中国黄金_2022年报.pdf",

+ 52 - 5
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -203,15 +203,58 @@ class MinerUWiredTableRecognizer:
                 
                 wired_rec = self.table_model.wired_table_model
                 img_obj = wired_rec.load_img(img_up_)
+                
+                # 手动计算 UNet 预处理后的图像尺寸(模拟 preprocess 中的 resize_img 逻辑)
+                # UNet 使用 scale = (inp_height, inp_width) = (1024, 1024),keep_ratio=True
+                h_up_, w_up_ = img_up_.shape[:2]
+                inp_height = 1024
+                inp_width = 1024
+                max_long_edge = max(inp_height, inp_width)  # 1024
+                max_short_edge = min(inp_height, inp_width)  # 1024
+                # 计算缩放因子(保持长宽比)
+                scale_factor = min(max_long_edge / max(h_up_, w_up_), max_short_edge / min(h_up_, w_up_))
+                preprocessed_w = int(w_up_ * scale_factor + 0.5)
+                preprocessed_h = int(h_up_ * scale_factor + 0.5)
+                
                 img_info = wired_rec.table_structure.preprocess(img_obj)
                 pred_ = wired_rec.table_structure.infer(img_info)
                 
                 hpred_ = np.where(pred_ == 1, 255, 0).astype(np.uint8)
                 vpred_ = np.where(pred_ == 2, 255, 0).astype(np.uint8)
                 
-                h_up_, w_up_ = img_up_.shape[:2]
-                hpred_up_ = cv2.resize(hpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
-                vpred_up_ = cv2.resize(vpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                # 调试:记录尺寸信息
+                pred_h, pred_w = pred_.shape[:2]
+                logger.debug(
+                    f"UNet 推理: 上采样图像尺寸=[{h_up_}, {w_up_}], "
+                    f"预处理后尺寸=[{preprocessed_h}, {preprocessed_w}], "
+                    f"预测结果尺寸=[{pred_h}, {pred_w}], "
+                    f"缩放因子={scale_factor:.6f}, upscale={upscale:.3f}"
+                )
+                
+                # 关键修复:正确地将预测结果 resize 回上采样尺寸
+                # UNet 的 postprocess 使用 ori_shape = img.shape 来 resize 预测结果
+                # 在我们的情况下,img 是 img_up_(上采样后的图像)
+                # 所以我们应该使用 img_up_.shape 来 resize 预测结果
+                # 但是,由于预处理时改变了图像尺寸(保持长宽比),我们需要确保 resize 是正确的
+                
+                # 验证:检查预测结果尺寸是否与预处理后的尺寸一致
+                if pred_h != preprocessed_h or pred_w != preprocessed_w:
+                    logger.warning(
+                        f"⚠️ 预测结果尺寸 [{pred_h}, {pred_w}] 与预处理后尺寸 [{preprocessed_h}, {preprocessed_w}] 不一致!"
+                        f"这可能导致坐标偏移。使用预处理后尺寸进行 resize。"
+                    )
+                    # 如果尺寸不一致,先 resize 到预处理后尺寸,再 resize 到上采样尺寸
+                    # 但实际上,预测结果应该就是预处理后的尺寸,所以这个警告不应该出现
+                    hpred_temp = cv2.resize(hpred_, (preprocessed_w, preprocessed_h), interpolation=cv2.INTER_NEAREST)
+                    vpred_temp = cv2.resize(vpred_, (preprocessed_w, preprocessed_h), interpolation=cv2.INTER_NEAREST)
+                    hpred_up_ = cv2.resize(hpred_temp, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                    vpred_up_ = cv2.resize(vpred_temp, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                else:
+                    # 正常情况:预测结果就是预处理后的尺寸,直接 resize 到上采样尺寸
+                    # 这相当于 UNet postprocess 中的逻辑:cv2.resize(pred, (ori_shape[1], ori_shape[0]))
+                    hpred_up_ = cv2.resize(hpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                    vpred_up_ = cv2.resize(vpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                
                 return hpred_up_, vpred_up_, img_up_
 
             # Step 1: 首次运行 UNet 获取初步 mask
@@ -258,11 +301,15 @@ class MinerUWiredTableRecognizer:
             if dbg and dbg.enabled and dbg.output_dir:
                 debug_dir = dbg.output_dir
                 debug_prefix = f"{dbg.prefix}_grid" if dbg.prefix else "grid"
-                
+            
+            # 传入原图的实际尺寸,用于计算真实的缩放比例
+            # 这样可以正确处理 UNet 预处理改变图像尺寸的情况
             bboxes = self.grid_recovery.compute_cells_from_lines(
                 hpred_up, 
                 vpred_up, 
                 upscale,
+                orig_h=h,
+                orig_w=w,
                 debug_dir=debug_dir,
                 debug_prefix=debug_prefix
             )
@@ -306,7 +353,7 @@ class MinerUWiredTableRecognizer:
                 texts = self.text_filler.second_pass_ocr_fill(
                     table_image, bboxes_merged, texts, scores, 
                     need_reocr_indices=need_reocr_indices,
-                    force_all=True  # Force Per-Cell OCR
+                    force_all=False  # Do not force per-cell OCR; only re-run OCR for flagged cells
                 )
 
             for i, cell in enumerate(merged_cells):

+ 77 - 38
ocr_tools/universal_doc_parser/models/adapters/wired_table/grid_recovery.py

@@ -3,7 +3,7 @@
 
 提供从表格线提取单元格和恢复网格结构的功能。
 """
-from typing import List, Dict
+from typing import List, Dict, Optional
 import cv2
 import numpy as np
 from loguru import logger
@@ -17,7 +17,9 @@ class GridRecovery:
         hpred_up: np.ndarray,
         vpred_up: np.ndarray,
         upscale: float = 1.0,
-        debug_dir: str = None,
+        orig_h: Optional[int] = None,
+        orig_w: Optional[int] = None,
+        debug_dir: Optional[str] = None,
         debug_prefix: str = "",
     ) -> List[List[float]]:
         """
@@ -33,7 +35,9 @@ class GridRecovery:
         Args:
             hpred_up: 横线预测mask(上采样后)
             vpred_up: 竖线预测mask(上采样后)
-            upscale: 上采样比例
+            upscale: 上采样比例(用于向后兼容,如果提供了 orig_h/orig_w 则会被覆盖)
+            orig_h: 原图的实际高度(用于计算真实的缩放比例)
+            orig_w: 原图的实际宽度(用于计算真实的缩放比例)
             debug_dir: 调试输出目录 (Optional)
             debug_prefix: 调试文件名前缀 (Optional)
             
@@ -145,7 +149,7 @@ class GridRecovery:
             # Pre-calculate Max Allowed Lengths (Original Length * Multiplier)
             # Multiplier = 2.0 means we allow the line to double in size, but not more.
             # This effectively stops short noise from becoming page-height lines.
-            extension_multiplier = 3.0 
+            extension_multiplier = 2.0 
             
             row_max_lens = [dist_sqrt(b[:2], b[2:]) * extension_multiplier for b in rowboxes]
             col_max_lens = [dist_sqrt(b[:2], b[2:]) * extension_multiplier for b in colboxes]
@@ -275,6 +279,23 @@ class GridRecovery:
         # 7. 连通域
         num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
         
+        # 计算真实的缩放比例
+        # 如果提供了原图尺寸,使用真实的缩放比例;否则使用 upscale(向后兼容)
+        if orig_h is not None and orig_w is not None and orig_h > 0 and orig_w > 0:
+            scale_h = h / orig_h
+            scale_w = w / orig_w
+            logger.debug(
+                f"连通域分析: mask尺寸=[{h}, {w}], 原图尺寸=[{orig_h}, {orig_w}], "
+                f"真实缩放比例=[{scale_h:.3f}, {scale_w:.3f}], upscale={upscale:.3f}"
+            )
+        else:
+            scale_h = upscale
+            scale_w = upscale
+            logger.debug(
+                f"连通域分析: mask尺寸=[{h}, {w}], upscale={upscale:.3f}, "
+                f"预期原图尺寸≈[{h/upscale:.1f}, {w/upscale:.1f}] (使用 upscale,未提供原图尺寸)"
+            )
+        
         bboxes = []
         
         # 8. 过滤
@@ -290,22 +311,35 @@ class GridRecovery:
             if area < 50:
                 continue
                 
-            orig_h = h_cell / upscale
-            orig_w = w_cell / upscale
+            # 使用真实的缩放比例转换为原图坐标
+            cell_orig_h = h_cell / scale_h
+            cell_orig_w = w_cell / scale_w
             
-            if orig_h < 4.0 or orig_w < 4.0:
+            if cell_orig_h < 4.0 or cell_orig_w < 4.0:
                 continue
             
             bboxes.append([
-                x / upscale,
-                y / upscale,
-                (x + w_cell) / upscale,
-                (y + h_cell) / upscale
+                x / scale_w,
+                y / scale_h,
+                (x + w_cell) / scale_w,
+                (y + h_cell) / scale_h
             ])
         
         bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
         
-        logger.info(f"矢量重构分析提取到 {len(bboxes)} 个单元格 (Dynamic Alpha: {dynamic_alpha})")
+        # 调试日志:输出样本 bbox 坐标信息
+        if len(bboxes) > 0:
+            logger.debug(f"样本 bbox (原图坐标): 前3个 = {bboxes[:3]}, 后3个 = {bboxes[-3:]}")
+            logger.debug(f"bbox 坐标范围: x=[{min(b[0] for b in bboxes):.1f}, {max(b[2] for b in bboxes):.1f}], "
+                        f"y=[{min(b[1] for b in bboxes):.1f}, {max(b[3] for b in bboxes):.1f}]")
+        
+        if orig_h is not None and orig_w is not None:
+            logger.info(
+                f"矢量重构分析提取到 {len(bboxes)} 个单元格 "
+                f"(Dynamic Alpha: {dynamic_alpha}, 真实缩放比例=[{scale_h:.3f}, {scale_w:.3f}])"
+            )
+        else:
+            logger.info(f"矢量重构分析提取到 {len(bboxes)} 个单元格 (Dynamic Alpha: {dynamic_alpha}, upscale={upscale:.3f})")
         
         return bboxes
 
@@ -466,33 +500,43 @@ class GridRecovery:
     def compress_grid(cells: List[Dict]) -> List[Dict]:
         """
         压缩网格索引,移除空行和空列
-        
+
+        算法解释:
+        1. 遍历所有单元格,查找每个单元格的起始行/列和其跨度(rowspan/colspan),计算出原始网格的最大行数和列数。
+        2. 使用布尔列表记录每一行和每一列是否有单元格占用(row_occupied/col_occupied)。如果某一行或列没有任何单元格,则为"空"行/列。
+        3. 根据占用情况遍历并为原网格的每一行/列建立到新索引的映射表(row_map/col_map)。未被占用(即为空)的行或列不会增加新的索引。
+        4. 根据刚刚生成的映射表,对所有单元格的row/col索引进行压缩更新,使它们在新的"紧凑"网格中连续、不含空行空列。
+        5. 返回更新后的单元格列表。
+
+        该算法本质上是:保留有效的单元格行列,剔除全空的行列,并将单元格的行列索引重新编号,使得新网格紧凑无缝隙。
+
         Args:
             cells: 单元格列表
-            
+
         Returns:
             压缩后的单元格列表
         """
+        # 源码实现见下方
         if not cells:
             return []
-        
+
         # 1. 计算当前最大行列
         max_row = 0
         max_col = 0
         for cell in cells:
             max_row = max(max_row, cell["row"] + cell.get("rowspan", 1))
             max_col = max(max_col, cell["col"] + cell.get("colspan", 1))
-        
+
         # 2. 标记占用情况
         row_occupied = [False] * max_row
         col_occupied = [False] * max_col
-        
+
         for cell in cells:
             if cell["row"] < max_row:
                 row_occupied[cell["row"]] = True
             if cell["col"] < max_col:
                 col_occupied[cell["col"]] = True
-        
+
         # 3. 构建映射表
         row_map = [0] * (max_row + 1)
         current_row = 0
@@ -500,38 +544,33 @@ class GridRecovery:
             if row_occupied[r]:
                 current_row += 1
             row_map[r + 1] = current_row
-        
+
         col_map = [0] * (max_col + 1)
         current_col = 0
         for c in range(max_col):
             if col_occupied[c]:
                 current_col += 1
             col_map[c + 1] = current_col
-        
+
         # 4. 更新单元格索引
         new_cells = []
         for cell in cells:
             new_cell = cell.copy()
-            
+
             old_r1 = cell["row"]
-            old_r2 = old_r1 + cell.get("rowspan", 1)
-            new_r1 = row_map[old_r1]
-            new_r2 = row_map[old_r2]
-            
             old_c1 = cell["col"]
+            old_r2 = old_r1 + cell.get("rowspan", 1)
             old_c2 = old_c1 + cell.get("colspan", 1)
-            new_c1 = col_map[old_c1]
-            new_c2 = col_map[old_c2]
-            
-            new_span_r = max(1, new_r2 - new_r1)
-            new_span_c = max(1, new_c2 - new_c1)
-            
-            new_cell["row"] = new_r1
-            new_cell["col"] = new_c1
-            new_cell["rowspan"] = new_span_r
-            new_cell["colspan"] = new_span_c
-            
+
+            new_row = row_map[old_r1]
+            new_col = col_map[old_c1]
+            new_rowspan = row_map[old_r2] - row_map[old_r1]
+            new_colspan = col_map[old_c2] - col_map[old_c1]
+
+            new_cell["row"] = new_row
+            new_cell["col"] = new_col
+            new_cell["rowspan"] = new_rowspan
+            new_cell["colspan"] = new_colspan
             new_cells.append(new_cell)
-        
-        return new_cells
 
+        return new_cells

+ 120 - 68
ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

@@ -4,6 +4,7 @@
 提供表格单元格文本填充功能,包括OCR文本匹配和二次OCR填充。
 """
 from typing import List, Dict, Any, Tuple, Optional
+import bisect
 import cv2
 import numpy as np
 from loguru import logger
@@ -26,17 +27,58 @@ class TextFiller:
         self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
         self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.5)
     
+    @staticmethod
+    def calculate_overlap_ratio(ocr_bbox: List[float], cell_bbox: List[float]) -> float:
+        """
+        计算 OCR box 与单元格的重叠比例(重叠面积 / OCR box 面积)
+        
+        这个比例表示 OCR box 有多少部分在单元格内,用于判断 OCR box 是否主要属于该单元格。
+        
+        Args:
+            ocr_bbox: OCR box 坐标 [x1, y1, x2, y2]
+            cell_bbox: 单元格坐标 [x1, y1, x2, y2]
+            
+        Returns:
+            重叠比例 (0.0 ~ 1.0),表示 OCR box 有多少部分在单元格内
+        """
+        if not ocr_bbox or not cell_bbox or len(ocr_bbox) < 4 or len(cell_bbox) < 4:
+            return 0.0
+        
+        # 计算交集
+        inter_x1 = max(ocr_bbox[0], cell_bbox[0])
+        inter_y1 = max(ocr_bbox[1], cell_bbox[1])
+        inter_x2 = min(ocr_bbox[2], cell_bbox[2])
+        inter_y2 = min(ocr_bbox[3], cell_bbox[3])
+        
+        if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
+            return 0.0
+        
+        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
+        ocr_area = (ocr_bbox[2] - ocr_bbox[0]) * (ocr_bbox[3] - ocr_bbox[1])
+        
+        if ocr_area <= 0:
+            return 0.0
+        
+        return inter_area / ocr_area
+    
     def fill_text_by_center_point(
         self,
         bboxes: List[List[float]],
         ocr_boxes: List[Dict[str, Any]],
     ) -> Tuple[List[str], List[float], List[List[Dict[str, Any]]], List[int]]:
         """
-        使用中心点落格策略填充文本。
+        使用混合匹配策略填充文本:中心点 + 重叠比例。
+        
+        策略说明:
+        1. 首先用中心点快速筛选:OCR box 的中心点在单元格内
+        2. 然后检查重叠比例:OCR box 与单元格的重叠面积 / OCR box 面积 >= 0.5
+           (这确保 OCR box 主要属于该单元格,避免跨单元格匹配)
+        3. 如果多个单元格都满足条件,选择重叠比例最高的
         
-        参考 fill_html_with_ocr_by_bbox:
-        - OCR文本中心点落入单元格bbox内则匹配
-        - 多行文本按y坐标排序拼接
+        优点:
+        - 比纯 IOU 更宽松,能匹配到更多 OCR box
+        - 比纯中心点更准确,能过滤跨单元格的 OCR box
+        - 适合表格场景,OCR box 通常比单元格小或部分重叠
         
         Args:
             bboxes: 单元格坐标 [[x1,y1,x2,y2], ...]
@@ -45,7 +87,7 @@ class TextFiller:
         Returns:
             每个单元格的文本列表
             每个单元格的置信度列表
-            每个单元格匹配到的 OCR boxes 列表
+            每个单元格匹配到的 OCR boxes 列表(已过滤跨单元格的 OCR box)
             需要二次 OCR 的单元格索引列表(OCR box 跨多个单元格或过大)
         """
         texts: List[str] = ["" for _ in bboxes]
@@ -56,61 +98,96 @@ class TextFiller:
         if not ocr_boxes:
             return texts, scores, matched_boxes_list, need_reocr_indices
         
-        # 预处理OCR结果:计算中心点
+        # 预处理OCR结果:转换为 bbox 格式,并计算中心点
         ocr_items: List[Dict[str, Any]] = []
         for item in ocr_boxes:
             # 使用 CoordinateUtils.poly_to_bbox() 替换 _normalize_bbox()
             box = CoordinateUtils.poly_to_bbox(item.get("bbox", []))
-            if not box:
+            if not box or len(box) < 4:
                 continue
             cx = (box[0] + box[2]) / 2
             cy = (box[1] + box[3]) / 2
             ocr_items.append({
+                "bbox": box,
                 "center_x": cx,
                 "center_y": cy,
-                "y1": box[1],
-                "bbox": box,  # 保存 bbox 用于跨单元格检测
                 "text": item.get("text", ""),
                 "confidence": float(item.get("confidence", item.get("score", 1.0))),
                 "original_box": item,  # 保存完整的 OCR box 对象
             })
         
-        # 为每个单元格匹配OCR文本
-        for idx, bbox in enumerate(bboxes):
-            x1, y1, x2, y2 = bbox
-            matched: List[Tuple[str, float, float, Dict[str, Any]]] = [] # (text, y1, score, original_box)
+        # 按 (y1, x1) 排序,便于后续二分查找和提前退出
+        # 排序只需要一次,对整体性能影响很小(O(n log n))
+        ocr_items.sort(key=lambda item: (item["bbox"][1], item["bbox"][0]))
+        
+        # 重叠比例阈值:OCR box 与单元格的重叠面积必须 >= OCR box 面积的 50%
+        # 这确保 OCR box 主要属于该单元格
+        overlap_ratio_threshold = 0.5
+        
+        # 为每个单元格匹配OCR文本(使用中心点 + 重叠比例)
+        # 优化:使用二分查找和提前退出机制,减少遍历次数
+        # 创建一个 y1 值的列表用于二分查找(兼容 Python < 3.10)
+        ocr_y1_list = [item["bbox"][1] for item in ocr_items]
+        
+        for idx, cell_bbox in enumerate(bboxes):
+            cell_x1, cell_y1, cell_x2, cell_y2 = cell_bbox
+            matched: List[Tuple[str, float, float, float, float, Dict[str, Any]]] = [] # (text, y1, x1, overlap_ratio, score, original_box)
             
-            for ocr in ocr_items:
-                if x1 <= ocr["center_x"] <= x2 and y1 <= ocr["center_y"] <= y2:
-                    matched.append((ocr["text"], ocr["y1"], ocr["confidence"], ocr["original_box"]))
+            # 使用二分查找找到第一个 y1 >= cell_y1 的 OCR item
+            # 由于 ocr_items 已按 (y1, x1) 排序,可以使用 bisect_left
+            start_idx = bisect.bisect_left(ocr_y1_list, cell_y1)
+            
+            # 关键优化:OCR box 的 y1 可能 < cell_y1,但 y2 >= cell_y1(跨越单元格上边界)
+            # 为了不遗漏这种情况,我们需要向前查找一些 items
+            # 向前查找的最大数量:假设 OCR box 最大高度不超过 100 像素(可根据实际情况调整)
+            max_lookback = 20  # 向前查找最多 20 个 items
+            actual_start_idx = max(0, start_idx - max_lookback)
+            
+            # 从 actual_start_idx 开始遍历,当 y1 > cell_y2 时提前退出
+            for i in range(actual_start_idx, len(ocr_items)):
+                ocr_item = ocr_items[i]
+                ocr_bbox = ocr_item["bbox"]
+                
+                # 提前退出:如果 y1 > cell_y2,后续的 items 都不可能在单元格内
+                if ocr_bbox[1] > cell_y2:
+                    break
+                
+                # 快速过滤:如果 OCR box 的 y2 < cell_y1,说明它完全在单元格上方,跳过
+                if ocr_bbox[3] < cell_y1:
+                    continue
+                
+                cx = ocr_item["center_x"]
+                cy = ocr_item["center_y"]
+                
+                # 第一步:中心点必须在单元格内
+                if not (cell_x1 <= cx <= cell_x2 and cell_y1 <= cy <= cell_y2):
+                    continue
+                
+                # 第二步:检查重叠比例(OCR box 有多少部分在单元格内)
+                overlap_ratio = self.calculate_overlap_ratio(ocr_bbox, cell_bbox)
+                if overlap_ratio >= overlap_ratio_threshold:
+                    matched.append((
+                        ocr_item["text"], 
+                        ocr_bbox[1],  # y1 坐标
+                        ocr_bbox[0],  # 添加 x1 坐标
+                        overlap_ratio,
+                        ocr_item["confidence"], 
+                        ocr_item["original_box"]
+                    ))
             
             if matched:
-                # 按y坐标排序,确保多行文本顺序正确
-                matched.sort(key=lambda x: x[1])
-                texts[idx] = "".join([t for t, _, _, _ in matched])
+                # 直接按 y1 和 x1 排序,确保文本顺序正确
+                # y_tolerance 用于将相近的 y1 归为同一行(容差范围内视为同一行)
+                # 同一行内按 x1 从左到右排序
+                y_tolerance = 5
+                matched.sort(key=lambda x: (round(x[1] / y_tolerance), x[2]))  # 先按 y_group,再按 x1
+                
+                texts[idx] = "".join([t for t, _, _, _, _, _ in matched])
                 # 计算平均置信度
-                avg_score = sum([s for _, _, s, _ in matched]) / len(matched)
+                avg_score = sum([s for _, _, _, _, s, _ in matched]) / len(matched)
                 scores[idx] = avg_score
                 # 保存匹配到的 OCR boxes
-                matched_boxes_list[idx] = [box for _, _, _, box in matched]
-                
-                # 检测 OCR box 是否跨多个单元格或过大
-                for ocr_item in ocr_items:
-                    ocr_bbox = ocr_item["bbox"]
-                    # 检测是否跨多个单元格
-                    overlapping_cells = self.detect_ocr_box_spanning_cells(ocr_bbox, bboxes, overlap_threshold=0.3)
-                    if len(overlapping_cells) >= 2:
-                        # OCR box 跨多个单元格,标记所有相关单元格需要二次 OCR
-                        for cell_idx in overlapping_cells:
-                            if cell_idx not in need_reocr_indices:
-                                need_reocr_indices.append(cell_idx)
-                        logger.debug(f"检测到 OCR box 跨 {len(overlapping_cells)} 个单元格: {ocr_item['text'][:20]}...")
-                    
-                    # 检测 OCR box 是否相对于当前单元格过大
-                    if self.is_ocr_box_too_large(ocr_bbox, bbox, size_ratio_threshold=1.5):
-                        if idx not in need_reocr_indices:
-                            need_reocr_indices.append(idx)
-                        logger.debug(f"检测到 OCR box 相对于单元格过大 (单元格 {idx}): {ocr_item['text'][:20]}...")
+                matched_boxes_list[idx] = [box for _, _, _, _, _, box in matched]
             else:
                 scores[idx] = 0.0 # 无匹配文本,置信度为0
         
@@ -189,35 +266,6 @@ class TextFiller:
         
         return overlapping_cells
     
-    @staticmethod
-    def is_ocr_box_too_large(
-        ocr_bbox: List[float],
-        cell_bbox: List[float],
-        size_ratio_threshold: float = 1.5
-    ) -> bool:
-        """
-        检测 OCR box 是否相对于单元格过大
-        
-        Args:
-            ocr_bbox: OCR box 坐标 [x1, y1, x2, y2]
-            cell_bbox: 单元格坐标 [x1, y1, x2, y2]
-            size_ratio_threshold: 面积比阈值,如果 OCR box 面积 > 单元格面积 * 阈值,则认为过大
-            
-        Returns:
-            是否过大
-        """
-        if not ocr_bbox or len(ocr_bbox) < 4 or not cell_bbox or len(cell_bbox) < 4:
-            return False
-        
-        ocr_area = (ocr_bbox[2] - ocr_bbox[0]) * (ocr_bbox[3] - ocr_bbox[1])
-        cell_area = (cell_bbox[2] - cell_bbox[0]) * (cell_bbox[3] - cell_bbox[1])
-        
-        if cell_area <= 0:
-            return False
-        
-        size_ratio = ocr_area / cell_area
-        return size_ratio > size_ratio_threshold
-    
     def second_pass_ocr_fill(
         self,
         table_image: np.ndarray,
@@ -377,6 +425,10 @@ class TextFiller:
                     
                     if x2 > x1 and y2 > y1:
                         cropped = cell_img[y1:y2, x1:x2]
+                        ch, cw = cropped.shape[:2]
+                        # 小图放大
+                        if ch < 64 or cw < 64:
+                            cropped = cv2.resize(cropped, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
                         if cropped.size > 0:
                             rec_img_list.append(cropped)
                             rec_indices.append((cell_idx, box_idx))

+ 34 - 5
ocr_tools/universal_doc_parser/models/adapters/wired_table/visualization.py

@@ -59,11 +59,20 @@ class WiredTableVisualizer:
         Args:
             hpred_up: 横线预测mask(上采样后)
             vpred_up: 竖线预测mask(上采样后)
-            bboxes: 单元格bbox列表
+            bboxes: 单元格bbox列表(原图坐标)
             upscale: 上采样比例
             output_path: 输出路径
         """
         h, w = hpred_up.shape[:2]
+        
+        # 调试:验证上采样图像尺寸
+        expected_h = int(bboxes[-1][3] * upscale + 0.5) if bboxes else 0
+        expected_w = int(bboxes[-1][2] * upscale + 0.5) if bboxes else 0
+        logger.debug(
+            f"上采样图像尺寸: 实际=[{h}, {w}], "
+            f"预期(基于最大bbox)≈[{expected_h}, {expected_w}], "
+            f"upscale={upscale:.3f}"
+        )
 
         # 与连通域提取相同的预处理,以获得直观的网格线背景
         _, h_bin = cv2.threshold(hpred_up, 127, 255, cv2.THRESH_BINARY)
@@ -78,12 +87,32 @@ class WiredTableVisualizer:
         vis[grid_mask > 0] = [0, 0, 255]  # 红色线条
 
         # 在上采样坐标系上绘制单元格框
-        for box in bboxes:
-            x1, y1, x2, y2 = [int(c * upscale) for c in box]
-            cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+        # 修复:使用更精确的坐标转换,避免累积误差
+        for idx, box in enumerate(bboxes):
+            # 使用四舍五入而不是直接截断,提高精度
+            x1 = int(box[0] * upscale + 0.5)
+            y1 = int(box[1] * upscale + 0.5)
+            x2 = int(box[2] * upscale + 0.5)
+            y2 = int(box[3] * upscale + 0.5)
+            
+            # 确保坐标在图像范围内
+            x1 = max(0, min(x1, w - 1))
+            y1 = max(0, min(y1, h - 1))
+            x2 = max(0, min(x2, w - 1))
+            y2 = max(0, min(y2, h - 1))
+            
+            if x2 > x1 and y2 > y1:
+                cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            
+            # 调试日志:输出前几个和后几个单元格的坐标转换信息
+            if idx < 3 or idx >= len(bboxes) - 3:
+                logger.debug(
+                    f"单元格 {idx}: 原图坐标 [{box[0]:.1f}, {box[1]:.1f}, {box[2]:.1f}, {box[3]:.1f}] "
+                    f"-> 上采样坐标 [{x1}, {y1}, {x2}, {y2}] (upscale={upscale:.3f})"
+                )
 
         cv2.imwrite(output_path, vis)
-        logger.info(f"连通域可视化: {output_path}")
+        logger.info(f"连通域可视化: {output_path} (共 {len(bboxes)} 个单元格)")
     
     @staticmethod
     def visualize_grid_structure(