Quellcode durchsuchen

feat: 新增表格识别个性化适配器,优化行内重叠合并逻辑,提升对不规则表格的处理能力

zhch158_admin vor 1 Monat
Ursprung
Commit
f3f55ffae3
2 geänderte Dateien mit 279 neuen und 0 gelöschten Zeilen
  1. 0 0
      zhch/adapters/__init__.py
  2. 279 0
      zhch/adapters/table_recognition_adapter.py

+ 0 - 0
zhch/adapters/__init__.py


+ 279 - 0
zhch/adapters/table_recognition_adapter.py

@@ -0,0 +1,279 @@
+"""
+表格识别个性化适配器 (v6 - 行内重叠合并修正版)
+
+核心思想:
+1. 废弃全局坐标聚类,改为按行分组和对齐,极大提升对倾斜、不规则表格的鲁棒性。
+2. 结构生成与内容填充彻底分离:
+   - `build_robust_html_from_cells`: 仅根据单元格几何位置,生成带`data-bbox`的HTML骨架。
+   - `fill_html_with_ocr_by_bbox`: 根据`data-bbox`从全局OCR结果中查找文本并填充。
+3. 通过适配器直接替换PaddleX Pipeline中的核心方法,实现无侵入式升级。
+"""
+import importlib
+from typing import Any, Dict, List
+import numpy as np
+
+from paddlex.inference.pipelines.table_recognition.result import SingleTableRecognitionResult
+from paddlex.inference.pipelines.table_recognition.pipeline_v2 import OCRResult
+
+# --- 1. 核心算法:基于排序和行分组的HTML结构生成 ---
+def filter_nested_boxes(boxes: List[list]) -> List[list]:
+    """
+    移除被其他框完全包含的框。
+    boxes: List[[x1, y1, x2, y2]]
+    """
+    if not boxes:
+        return []
+    
+    filtered = []
+    # 按面积从大到小排序,优先保留大框
+    boxes.sort(key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)
+    
+    for i, box in enumerate(boxes):
+        is_nested = False
+        for j in range(i): # 只需和排在前面的(更大的)框比较
+            outer_box = boxes[j]
+            # 判断 box 是否被 outer_box 包含
+            if outer_box[0] <= box[0] and outer_box[1] <= box[1] and \
+               outer_box[2] >= box[2] and outer_box[3] >= box[3]:
+                is_nested = True
+                break
+        if not is_nested:
+            filtered.append(box)
+    return filtered
+
+def merge_overlapping_cells_in_row(row_cells: List[list], iou_threshold: float = 0.5) -> List[list]:
+    """
+    合并单行内水平方向上高度重叠的单元格。
+    """
+    if not row_cells:
+        return []
+
+    # 按x坐标排序
+    cells = sorted(row_cells, key=lambda c: c[0])
+    
+    merged_cells = []
+    i = 0
+    while i < len(cells):
+        current_cell = list(cells[i]) # 使用副本
+        j = i + 1
+        while j < len(cells):
+            next_cell = cells[j]
+            
+            # 计算交集
+            inter_x1 = max(current_cell[0], next_cell[0])
+            inter_y1 = max(current_cell[1], next_cell[1])
+            inter_x2 = min(current_cell[2], next_cell[2])
+            inter_y2 = min(current_cell[3], next_cell[3])
+            
+            inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
+            
+            # 如果交集面积大于其中一个框面积的阈值,则认为是重叠
+            current_area = (current_cell[2] - current_cell[0]) * (current_cell[3] - current_cell[1])
+            next_area = (next_cell[2] - next_cell[0]) * (next_cell[3] - next_cell[1])
+            
+            if inter_area > min(current_area, next_area) * iou_threshold:
+                # 合并两个框,取外包围框
+                current_cell[0] = min(current_cell[0], next_cell[0])
+                current_cell[1] = min(current_cell[1], next_cell[1])
+                current_cell[2] = max(current_cell[2], next_cell[2])
+                current_cell[3] = max(current_cell[3], next_cell[3])
+                j += 1
+            else:
+                break # 不再与更远的单元格合并
+        
+        merged_cells.append(current_cell)
+        i = j
+        
+    return merged_cells
+
+def build_robust_html_from_cells(cells_det_results: List[list]) -> str:
+    """
+    通过按行排序、分组、合并和对齐,稳健地将单元格Bbox列表转换为带data-bbox的HTML结构。
+    """
+    if not cells_det_results:
+        return "<table><tbody></tbody></table>"
+
+    cells = filter_nested_boxes(cells_det_results)
+    cells.sort(key=lambda c: (c[1], c[0]))
+
+    rows = []
+    if cells:
+        current_row = [cells[0]]
+        row_anchor_y = (cells[0][1] + cells[0][3]) / 2
+        row_anchor_height = cells[0][3] - cells[0][1]
+
+        for cell in cells[1:]:
+            cell_y_center = (cell[1] + cell[3]) / 2
+            if abs(cell_y_center - row_anchor_y) < row_anchor_height * 0.7:
+                current_row.append(cell)
+            else:
+                rows.append(current_row)
+                current_row = [cell]
+                row_anchor_y = (cell[1] + cell[3]) / 2
+                row_anchor_height = cell[3] - cell[1]
+        rows.append(current_row)
+
+    html = "<table><tbody>"
+    for row_cells in rows:
+        # 🎯 核心修正:在生成HTML前,合并行内的重叠单元格
+        merged_row_cells = merge_overlapping_cells_in_row(row_cells)
+        
+        html += "<tr>"
+        for cell in merged_row_cells:
+            bbox_str = f"[{','.join(map(str, map(int, cell)))}]"
+            html += f'<td data-bbox="{bbox_str}"></td>'
+        html += "</tr>"
+    html += "</tbody></table>"
+    
+    return html
+
+# --- 2. 内容填充工具 ---
+
+def fill_html_with_ocr_by_bbox(html_skeleton: str, ocr_dt_boxes: list, ocr_texts: list) -> str:
+    """
+    根据带有 data-bbox 的 HTML 骨架和全局 OCR 结果填充表格内容。
+    """
+    try:
+        from bs4 import BeautifulSoup
+    except ImportError:
+        print("⚠️  BeautifulSoup not installed. Cannot fill table content. Returning skeleton.")
+        return html_skeleton
+
+    soup = BeautifulSoup(html_skeleton, 'html.parser')
+    # # ocr_dt_boxes = cells_ocr_res.get("rec_boxes", [])
+    # ocr_texts = cells_ocr_res.get("rec_texts", [])
+
+    # 为快速查找,将OCR结果组织起来
+    ocr_items = []
+    for box, text in zip(ocr_dt_boxes, ocr_texts):
+        center_x = (box[0] + box[2]) / 2
+        center_y = (box[1] + box[3]) / 2
+        ocr_items.append({'box': box, 'text': text, 'center': (center_x, center_y)})
+
+    for td in soup.find_all('td'):
+        if not td.has_attr('data-bbox'):
+            continue
+        
+        bbox_str = td['data-bbox'].strip('[]')
+        cell_box = list(map(float, bbox_str.split(',')))
+        cx1, cy1, cx2, cy2 = cell_box
+
+        cell_texts_with_pos = []
+        # 查找所有中心点在该单元格内的OCR文本
+        for item in ocr_items:
+            if cx1 <= item['center'][0] <= cx2 and cy1 <= item['center'][1] <= cy2:
+                # 记录文本和其y坐标,用于后续排序
+                cell_texts_with_pos.append((item['text'], item['box'][1]))
+        
+        if cell_texts_with_pos:
+            # 按y坐标排序,确保多行文本的顺序正确
+            cell_texts_with_pos.sort(key=lambda x: x[1])
+            # 合并文本
+            td.string = " ".join([text for text, y in cell_texts_with_pos])
+            
+    return str(soup)
+
+# --- 3. 适配器主函数和应用逻辑 ---
+
+# 保存原始方法的引用
+_original_predict_single = None
+
+def enhanced_predict_single_table_recognition_res(
+    # self, *args, **kwargs):
+    self,
+    image_array: np.ndarray,
+    overall_ocr_res: OCRResult,
+    table_box: list,
+    use_e2e_wired_table_rec_model: bool = False,
+    use_e2e_wireless_table_rec_model: bool = False,
+    use_wired_table_cells_trans_to_html: bool = False,
+    use_wireless_table_cells_trans_to_html: bool = False,
+    use_ocr_results_with_table_cells: bool = True,
+    flag_find_nei_text: bool = True,
+) -> SingleTableRecognitionResult:
+    """
+    这是将被注入到 _TableRecognitionPipelineV2 实例中的增强版方法。
+    它调用我们新的、解耦的结构生成和内容填充逻辑。
+    """
+    print(">>> [Adapter] enhanced_predict_single_table_recognition_res called")
+    
+    # 🎯 复用原始逻辑来获取 table_cells_result
+    table_cls_pred = list(self.table_cls_model(image_array))[0]
+    table_cls_result = self.extract_results(table_cls_pred, "cls")
+
+    if table_cls_result == "wired_table":
+        table_cells_pred = list(self.wired_table_cells_detection_model(image_array, threshold=0.3))[0]
+    else: # wireless_table
+        table_cells_pred = list(self.wireless_table_cells_detection_model(image_array, threshold=0.3))[0]
+    
+    table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+    table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+    cells_texts_list = self.gen_ocr_with_table_cells(image_array, table_cells_result)
+
+    # 🎯 注入我们的核心逻辑
+    # 只有当 use_wired_table_cells_trans_to_html 为 True 时,才使用我们的新方法
+    # 这样可以保持与原始行为的兼容性,并提供一个开关
+    if use_wired_table_cells_trans_to_html:
+        print(">>> [Adapter] Using robust HTML generation from cells.")
+        # 步骤1: 使用我们鲁棒的算法生成HTML骨架
+        html_skeleton = build_robust_html_from_cells(table_cells_result)
+
+        # 步骤2: 使用全局OCR结果和Bbox来填充内容
+        pred_html = fill_html_with_ocr_by_bbox(html_skeleton, table_cells_result, cells_texts_list)
+
+        single_img_res = {
+            "cell_box_list": table_cells_result,
+            "table_ocr_pred": {}, # 内容已填充,无需单独的 table_ocr_pred
+            "pred_html": pred_html,
+        }
+        # 构造并返回结果
+        res = SingleTableRecognitionResult(single_img_res)
+        res["neighbor_texts"] = "" # 保持字段存在
+        return res
+    else:
+        # 🎯 如果开关关闭,则调用原始的、未被补丁的方法
+        print(">>> [Adapter] Falling back to original predict_single_table_recognition_res.")
+        return _original_predict_single(self, *args, **kwargs)
+
+
+def apply_table_recognition_adapter():
+    """
+    应用表格识别适配器。
+    我们直接替换 _TableRecognitionPipelineV2 类中的 `predict_single_table_recognition_res` 方法。
+    """
+    global _original_predict_single
+    
+    try:
+        # 导入目标类
+        from paddlex.inference.pipelines.table_recognition.pipeline_v2 import _TableRecognitionPipelineV2
+        
+        # 保存原函数,防止重复应用补丁
+        if _original_predict_single is None:
+             _original_predict_single = _TableRecognitionPipelineV2.predict_single_table_recognition_res
+        
+        # 替换为增强版
+        _TableRecognitionPipelineV2.predict_single_table_recognition_res = enhanced_predict_single_table_recognition_res
+        
+        print("✅ Table recognition adapter applied successfully (v3 - corrected).")
+        return True
+        
+    except Exception as e:
+        print(f"❌ Failed to apply table recognition adapter: {e}")
+        return False
+
+
+def restore_original_function():
+    """恢复原始函数"""
+    global _original_predict_single
+    try:
+        from paddlex.inference.pipelines.table_recognition.pipeline_v2 import _TableRecognitionPipelineV2
+        
+        if _original_predict_single is not None:
+            _TableRecognitionPipelineV2.predict_single_table_recognition_res = _original_predict_single
+            _original_predict_single = None # 重置状态
+            print("✅ Original function restored.")
+            return True
+        return False
+    except Exception as e:
+        print(f"❌ Failed to restore original function: {e}")
+        return False