Przeglądaj źródła

feat: 增强表格识别适配器,新增缺失单元格推断功能并优化HTML生成逻辑

zhch158_admin 1 miesiąc temu
rodzic
commit
1f31948dc6
1 zmienionych plików z 222 dodań i 31 usunięć
  1. 222 31
      zhch/adapters/table_recognition_adapter.py

+ 222 - 31
zhch/adapters/table_recognition_adapter.py

@@ -15,6 +15,20 @@ import numpy as np
 from paddlex.inference.pipelines.table_recognition.result import SingleTableRecognitionResult
 from paddlex.inference.pipelines.table_recognition.pipeline_v2 import OCRResult
 
+def _normalize_bbox(box: list) -> list:
+    """
+    将8点坐标或4点坐标统一转换为 [x1, y1, x2, y2]
+    """
+    if len(box) == 8:
+        # 8点坐标:取最小和最大值
+        xs = [box[0], box[2], box[4], box[6]]
+        ys = [box[1], box[3], box[5], box[7]]
+        return [min(xs), min(ys), max(xs), max(ys)]
+    elif len(box) == 4:
+        return box[:4]
+    else:
+        raise ValueError(f"Unsupported bbox format: {box}")
+
 # --- 1. 核心算法:基于排序和行分组的HTML结构生成 ---
 def filter_nested_boxes(boxes: List[list]) -> List[list]:
     """
@@ -93,24 +107,41 @@ def build_robust_html_from_cells(cells_det_results: List[list]) -> str:
     if not cells_det_results:
         return "<table><tbody></tbody></table>"
 
-    cells = filter_nested_boxes(cells_det_results)
+    # ✅ 关键修复:使用副本防止修改原始列表
+    import copy
+    cells_copy = copy.deepcopy(cells_det_results)
+    cells = filter_nested_boxes(cells_copy)
     cells.sort(key=lambda c: (c[1], c[0]))
 
     rows = []
     if cells:
         current_row = [cells[0]]
-        row_anchor_y = (cells[0][1] + cells[0][3]) / 2
-        row_anchor_height = cells[0][3] - cells[0][1]
+        # ✅ 使用该行的Y范围而不是单个锚点
+        row_y1 = cells[0][1]
+        row_y2 = cells[0][3]
 
         for cell in cells[1:]:
-            cell_y_center = (cell[1] + cell[3]) / 2
-            if abs(cell_y_center - row_anchor_y) < row_anchor_height * 0.7:
+            # ✅ 计算垂直方向的重叠
+            overlap_y1 = max(row_y1, cell[1])
+            overlap_y2 = min(row_y2, cell[3])
+            overlap_height = max(0, overlap_y2 - overlap_y1)
+            
+            # 单元格和当前行的平均高度
+            cell_height = cell[3] - cell[1]
+            row_height = row_y2 - row_y1
+            avg_height = (cell_height + row_height) / 2
+            
+            # ✅ 重叠高度超过平均高度的50%,认为是同一行
+            if overlap_height > avg_height * 0.5:
                 current_row.append(cell)
+                # 更新该行的Y范围(扩展以包含新单元格)
+                row_y1 = min(row_y1, cell[1])
+                row_y2 = max(row_y2, cell[3])
             else:
                 rows.append(current_row)
                 current_row = [cell]
-                row_anchor_y = (cell[1] + cell[3]) / 2
-                row_anchor_height = cell[3] - cell[1]
+                row_y1 = cell[1]
+                row_y2 = cell[3]
         rows.append(current_row)
 
     html = "<table><tbody>"
@@ -178,8 +209,127 @@ def fill_html_with_ocr_by_bbox(html_skeleton: str, ocr_dt_boxes: list, ocr_texts
 # 保存原始方法的引用
 _original_predict_single = None
 
+def infer_missing_cells_from_ocr(
+    detected_cells: List[list],
+    cells_texts_list: List[str],
+    overall_ocr_boxes: List[list],
+    overall_ocr_texts: List[str],
+    table_box: list
+) -> tuple[List[list], List[str]]:
+    """
+    根据全局OCR结果推断缺失的单元格
+    
+    Args:
+        detected_cells: 已检测到的单元格坐标 [[x1,y1,x2,y2], ...]
+        overall_ocr_boxes: 全局OCR框坐标
+        overall_ocr_texts: 全局OCR文本
+        table_box: 表格区域 [x1,y1,x2,y2]
+    
+    Returns:
+        补全后的单元格列表
+    """
+    import copy
+    
+    # 1. 找出未被覆盖的OCR框
+    uncovered_ocr_boxes = []
+    uncovered_ocr_texts = []
+    
+    for ocr_box, ocr_text in zip(overall_ocr_boxes, overall_ocr_texts):
+        # 计算OCR框中心点
+        ocr_cx = (ocr_box[0] + ocr_box[2]) / 2
+        ocr_cy = (ocr_box[1] + ocr_box[3]) / 2
+        
+        # 检查是否被任何单元格覆盖
+        is_covered = False
+        for cell in detected_cells:
+            if cell[0] <= ocr_cx <= cell[2] and cell[1] <= ocr_cy <= cell[3]:
+                is_covered = True
+                break
+        
+        if not is_covered:
+            uncovered_ocr_boxes.append(ocr_box)
+            uncovered_ocr_texts.append(ocr_text)
+    
+    if not uncovered_ocr_boxes:
+        return detected_cells, cells_texts_list  # 没有漏检
+    
+    # 2. 按行分组已检测的单元格
+    cells_sorted = sorted(detected_cells, key=lambda c: (c[1], c[0]))
+    rows = []
+    if cells_sorted:
+        current_row = [cells_sorted[0]]
+        row_y = (cells_sorted[0][1] + cells_sorted[0][3]) / 2
+        row_height = cells_sorted[0][3] - cells_sorted[0][1]
+        
+        for cell in cells_sorted[1:]:
+            cell_y = (cell[1] + cell[3]) / 2
+            if abs(cell_y - row_y) < row_height * 0.7:
+                current_row.append(cell)
+            else:
+                rows.append(current_row)
+                current_row = [cell]
+                row_y = (cell[1] + cell[3]) / 2
+                row_height = cell[3] - cell[1]
+        rows.append(current_row)
+    
+    # 3. 为每个未覆盖的OCR框推断单元格
+    inferred_cells = []
+    inferred_texts = []
+    for ocr_box, ocr_text in zip(uncovered_ocr_boxes, uncovered_ocr_texts):
+        ocr_cy = (ocr_box[1] + ocr_box[3]) / 2
+        
+        # 找到OCR框所在的行
+        target_row_idx = None
+        for i, row_cells in enumerate(rows):
+            row_y1 = min(c[1] for c in row_cells)
+            row_y2 = max(c[3] for c in row_cells)
+            if row_y1 <= ocr_cy <= row_y2:
+                target_row_idx = i
+                break
+        
+        if target_row_idx is None:
+            # 无法确定所属行,跳过
+            print(f"⚠️  无法为OCR文本 '{ocr_text}' 确定所属行")
+            continue
+        
+        target_row = rows[target_row_idx]
+        
+        # 4. 推断单元格边界
+        # 上下边界:使用该行的统一高度
+        cell_y1 = min(c[1] for c in target_row)
+        cell_y2 = max(c[3] for c in target_row)
+        
+        # 左右边界:根据OCR框位置和相邻单元格推断
+        ocr_cx = (ocr_box[0] + ocr_box[2]) / 2
+        
+        # 找左边最近的单元格
+        left_cells = [c for c in target_row if c[2] < ocr_cx]
+        if left_cells:
+            cell_x1 = max(c[2] for c in left_cells)  # 左边单元格的右边界
+        else:
+            cell_x1 = table_box[0]  # 表格左边界
+        
+        # 找右边最近的单元格
+        right_cells = [c for c in target_row if c[0] > ocr_cx]
+        if right_cells:
+            cell_x2 = min(c[0] for c in right_cells)  # 右边单元格的左边界
+        else:
+            cell_x2 = table_box[2]  # 表格右边界
+        
+        # 创建推断的单元格
+        inferred_cell = [cell_x1, cell_y1, cell_x2, cell_y2]
+        inferred_cells.append(inferred_cell)
+        inferred_texts.append(ocr_text)
+
+        print(f"✅ 为OCR文本 '{ocr_text}' 推断单元格: {inferred_cell}")
+    
+    # 5. 合并检测到的和推断的单元格
+    all_cells = detected_cells + inferred_cells
+    all_texts = cells_texts_list + inferred_texts
+    return all_cells, all_texts
+
+
 def enhanced_predict_single_table_recognition_res(
-    # self, *args, **kwargs):
     self,
     image_array: np.ndarray,
     overall_ocr_res: OCRResult,
@@ -191,13 +341,10 @@ def enhanced_predict_single_table_recognition_res(
     use_ocr_results_with_table_cells: bool = True,
     flag_find_nei_text: bool = True,
 ) -> SingleTableRecognitionResult:
-    """
-    这是将被注入到 _TableRecognitionPipelineV2 实例中的增强版方法。
-    它调用我们新的、解耦的结构生成和内容填充逻辑。
-    """
+    """增强版方法 - 使用OCR引导的单元格补全"""
     print(">>> [Adapter] enhanced_predict_single_table_recognition_res called")
     
-    # 🎯 复用原始逻辑来获取 table_cells_result
+    # 🎯 Step 1: 获取table_cells_result (原始逻辑)
     table_cls_pred = list(self.table_cls_model(image_array))[0]
     table_cls_result = self.extract_results(table_cls_pred, "cls")
 
@@ -208,32 +355,76 @@ def enhanced_predict_single_table_recognition_res(
     
     table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
     table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
-    cells_texts_list = self.gen_ocr_with_table_cells(image_array, table_cells_result)
-
-    # 🎯 注入我们的核心逻辑
-    # 只有当 use_wired_table_cells_trans_to_html 为 True 时,才使用我们的新方法
-    # 这样可以保持与原始行为的兼容性,并提供一个开关
-    if use_wired_table_cells_trans_to_html:
-        print(">>> [Adapter] Using robust HTML generation from cells.")
-        # 步骤1: 使用我们鲁棒的算法生成HTML骨架
-        html_skeleton = build_robust_html_from_cells(table_cells_result)
+    table_cells_result.sort(key=lambda c: (c[1], c[0]))
+    
+    # 🎯 Step 2: 坐标转换
+    from paddlex.inference.pipelines.table_recognition.table_recognition_post_processing_v2 import (
+        convert_to_four_point_coordinates,
+        convert_table_structure_pred_bbox,
+        get_sub_regions_ocr_res
+    )
+    import numpy as np
+    
+    # 转换为4点坐标
+    table_cells_result_4pt = convert_to_four_point_coordinates(table_cells_result)
+    
+    # 准备坐标转换参数
+    table_box_array = np.array([table_box])
+    crop_start_point = [table_box[0], table_box[1]]
+    img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
+    
+    # 转换到原图坐标系
+    table_cells_result_orig = convert_table_structure_pred_bbox(
+        table_cells_result_4pt, crop_start_point, img_shape
+    )
+    # 处理NumPy数组
+    if isinstance(table_cells_result_orig, np.ndarray):
+        table_cells_result_orig = table_cells_result_orig.tolist()
+    table_cells_result_orig.sort(key=lambda c: (c[1], c[0]))
 
-        # 步骤2: 使用全局OCR结果和Bbox来填充内容
-        pred_html = fill_html_with_ocr_by_bbox(html_skeleton, table_cells_result, cells_texts_list)
+    # 🎯 Step 3: 获取表格区域的OCR结果
+    table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box_array)
+    
+    # 🎯 Step 4: **关键改进** - OCR引导的单元格补全
+    if use_wired_table_cells_trans_to_html and use_ocr_results_with_table_cells:
+        # ✅ 对每个单元格做OCR(使用裁剪前的坐标)
+        cells_texts_list = self.gen_ocr_with_table_cells(image_array, table_cells_result)
+        # ✅ 补全缺失的单元格
+        completed_cells, cells_texts_list = infer_missing_cells_from_ocr(
+            detected_cells=table_cells_result_orig,
+            cells_texts_list=cells_texts_list,
+            overall_ocr_boxes=table_ocr_pred["rec_boxes"],
+            overall_ocr_texts=table_ocr_pred["rec_texts"],
+            table_box=table_box
+        )
 
+        # ✅ 生成HTML骨架(使用转换后的原图坐标)
+        html_skeleton = build_robust_html_from_cells(completed_cells)
+        
+        # ✅ 填充内容(使用单元格中心点坐标和单元格OCR文本)
+        pred_html = fill_html_with_ocr_by_bbox(
+            html_skeleton,
+            completed_cells,      # ✅ 单元格bbox
+            cells_texts_list      # ✅ 单元格OCR文本
+        )
+        
         single_img_res = {
-            "cell_box_list": table_cells_result,
-            "table_ocr_pred": {}, # 内容已填充,无需单独的 table_ocr_pred
+            "cell_box_list": completed_cells,
+            "table_ocr_pred": table_ocr_pred,  # 保留完整OCR信息
             "pred_html": pred_html,
         }
-        # 构造并返回结果
+        
         res = SingleTableRecognitionResult(single_img_res)
-        res["neighbor_texts"] = "" # 保持字段存在
+        res["neighbor_texts"] = ""
         return res
     else:
-        # 🎯 如果开关关闭,则调用原始的、未被补丁的方法
-        print(">>> [Adapter] Falling back to original predict_single_table_recognition_res.")
-        return _original_predict_single(self, *args, **kwargs)
+        # 回退到原始实现
+        return _original_predict_single(
+            self, image_array, overall_ocr_res, table_box,
+            use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model,
+            use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html,
+            use_ocr_results_with_table_cells, flag_find_nei_text
+        )
 
 
 def apply_table_recognition_adapter():