""" 表格识别个性化适配器 (v6 - 行内重叠合并修正版) 核心思想: 1. 废弃全局坐标聚类,改为按行分组和对齐,极大提升对倾斜、不规则表格的鲁棒性。 2. 结构生成与内容填充彻底分离: - `build_robust_html_from_cells`: 仅根据单元格几何位置,生成带`data-bbox`的HTML骨架。 - `fill_html_with_ocr_by_bbox`: 根据`data-bbox`从全局OCR结果中查找文本并填充。 3. 通过适配器直接替换PaddleX Pipeline中的核心方法,实现无侵入式升级。 """ import importlib from typing import Any, Dict, List import numpy as np from paddlex.inference.pipelines.table_recognition.result import SingleTableRecognitionResult from paddlex.inference.pipelines.table_recognition.pipeline_v2 import OCRResult def _normalize_bbox(box: list) -> list: """ 将8点坐标或4点坐标统一转换为 [x1, y1, x2, y2] """ if len(box) == 8: # 8点坐标:取最小和最大值 xs = [box[0], box[2], box[4], box[6]] ys = [box[1], box[3], box[5], box[7]] return [min(xs), min(ys), max(xs), max(ys)] elif len(box) == 4: return box[:4] else: raise ValueError(f"Unsupported bbox format: {box}") # --- 1. 核心算法:基于排序和行分组的HTML结构生成 --- def filter_nested_boxes(boxes: List[list]) -> List[list]: """ 移除被其他框完全包含的框。 boxes: List[[x1, y1, x2, y2]] """ if not boxes: return [] filtered = [] # 按面积从大到小排序,优先保留大框 boxes.sort(key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True) for i, box in enumerate(boxes): is_nested = False for j in range(i): # 只需和排在前面的(更大的)框比较 outer_box = boxes[j] # 判断 box 是否被 outer_box 包含 if outer_box[0] <= box[0] and outer_box[1] <= box[1] and \ outer_box[2] >= box[2] and outer_box[3] >= box[3]: is_nested = True break if not is_nested: filtered.append(box) return filtered def merge_overlapping_cells_in_row(row_cells: List[list], iou_threshold: float = 0.5) -> List[list]: """ 合并单行内水平方向上高度重叠的单元格。 """ if not row_cells: return [] # 按x坐标排序 cells = sorted(row_cells, key=lambda c: c[0]) merged_cells = [] i = 0 while i < len(cells): current_cell = list(cells[i]) # 使用副本 j = i + 1 while j < len(cells): next_cell = cells[j] # 计算交集 inter_x1 = max(current_cell[0], next_cell[0]) inter_y1 = max(current_cell[1], next_cell[1]) inter_x2 = min(current_cell[2], next_cell[2]) inter_y2 = min(current_cell[3], next_cell[3]) inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) # 如果交集面积大于其中一个框面积的阈值,则认为是重叠 current_area = (current_cell[2] - current_cell[0]) * (current_cell[3] - current_cell[1]) next_area = (next_cell[2] - next_cell[0]) * (next_cell[3] - next_cell[1]) if inter_area > min(current_area, next_area) * iou_threshold: # 合并两个框,取外包围框 current_cell[0] = min(current_cell[0], next_cell[0]) current_cell[1] = min(current_cell[1], next_cell[1]) current_cell[2] = max(current_cell[2], next_cell[2]) current_cell[3] = max(current_cell[3], next_cell[3]) j += 1 else: break # 不再与更远的单元格合并 merged_cells.append(current_cell) i = j return merged_cells def build_robust_html_from_cells(cells_det_results: List[list]) -> str: """ 通过按行排序、分组、合并和对齐,稳健地将单元格Bbox列表转换为带data-bbox的HTML结构。 """ if not cells_det_results: return "
" # ✅ 关键修复:使用副本防止修改原始列表 import copy cells_copy = copy.deepcopy(cells_det_results) cells = filter_nested_boxes(cells_copy) cells.sort(key=lambda c: (c[1], c[0])) rows = [] if cells: current_row = [cells[0]] # ✅ 使用该行的Y范围而不是单个锚点 row_y1 = cells[0][1] row_y2 = cells[0][3] for cell in cells[1:]: # ✅ 计算垂直方向的重叠 overlap_y1 = max(row_y1, cell[1]) overlap_y2 = min(row_y2, cell[3]) overlap_height = max(0, overlap_y2 - overlap_y1) # 单元格和当前行的平均高度 cell_height = cell[3] - cell[1] row_height = row_y2 - row_y1 avg_height = (cell_height + row_height) / 2 # ✅ 重叠高度超过平均高度的50%,认为是同一行 if overlap_height > avg_height * 0.5: current_row.append(cell) # 更新该行的Y范围(扩展以包含新单元格) row_y1 = min(row_y1, cell[1]) row_y2 = max(row_y2, cell[3]) else: rows.append(current_row) current_row = [cell] row_y1 = cell[1] row_y2 = cell[3] rows.append(current_row) html = "" for row_cells in rows: # 🎯 核心修正:在生成HTML前,合并行内的重叠单元格 merged_row_cells = merge_overlapping_cells_in_row(row_cells) html += "" for cell in merged_row_cells: bbox_str = f"[{','.join(map(str, map(int, cell)))}]" html += f'' html += "" html += "
" return html # --- 2. 内容填充工具 --- def fill_html_with_ocr_by_bbox(html_skeleton: str, ocr_dt_boxes: list, ocr_texts: list) -> str: """ 根据带有 data-bbox 的 HTML 骨架和全局 OCR 结果填充表格内容。 """ try: from bs4 import BeautifulSoup except ImportError: print("⚠️ BeautifulSoup not installed. Cannot fill table content. Returning skeleton.") return html_skeleton soup = BeautifulSoup(html_skeleton, 'html.parser') # # ocr_dt_boxes = cells_ocr_res.get("rec_boxes", []) # ocr_texts = cells_ocr_res.get("rec_texts", []) # 为快速查找,将OCR结果组织起来 ocr_items = [] for box, text in zip(ocr_dt_boxes, ocr_texts): center_x = (box[0] + box[2]) / 2 center_y = (box[1] + box[3]) / 2 ocr_items.append({'box': box, 'text': text, 'center': (center_x, center_y)}) for td in soup.find_all('td'): if not td.has_attr('data-bbox'): continue bbox_str = td['data-bbox'].strip('[]') cell_box = list(map(float, bbox_str.split(','))) cx1, cy1, cx2, cy2 = cell_box cell_texts_with_pos = [] # 查找所有中心点在该单元格内的OCR文本 for item in ocr_items: if cx1 <= item['center'][0] <= cx2 and cy1 <= item['center'][1] <= cy2: # 记录文本和其y坐标,用于后续排序 cell_texts_with_pos.append((item['text'], item['box'][1])) if cell_texts_with_pos: # 按y坐标排序,确保多行文本的顺序正确 cell_texts_with_pos.sort(key=lambda x: x[1]) # 合并文本 td.string = " ".join([text for text, y in cell_texts_with_pos]) return str(soup) # --- 3. 适配器主函数和应用逻辑 --- # 保存原始方法的引用 _original_predict_single = None def infer_missing_cells_from_ocr( detected_cells: List[list], cells_texts_list: List[str], overall_ocr_boxes: List[list], overall_ocr_texts: List[str], table_box: list ) -> tuple[List[list], List[str]]: """ 根据全局OCR结果推断缺失的单元格 Args: detected_cells: 已检测到的单元格坐标 [[x1,y1,x2,y2], ...] overall_ocr_boxes: 全局OCR框坐标 overall_ocr_texts: 全局OCR文本 table_box: 表格区域 [x1,y1,x2,y2] Returns: 补全后的单元格列表 """ import copy # 1. 找出未被覆盖的OCR框 uncovered_ocr_boxes = [] uncovered_ocr_texts = [] for ocr_box, ocr_text in zip(overall_ocr_boxes, overall_ocr_texts): # 计算OCR框中心点 ocr_cx = (ocr_box[0] + ocr_box[2]) / 2 ocr_cy = (ocr_box[1] + ocr_box[3]) / 2 # 检查是否被任何单元格覆盖 is_covered = False for cell in detected_cells: if cell[0] <= ocr_cx <= cell[2] and cell[1] <= ocr_cy <= cell[3]: is_covered = True break if not is_covered: uncovered_ocr_boxes.append(ocr_box) uncovered_ocr_texts.append(ocr_text) if not uncovered_ocr_boxes: return detected_cells, cells_texts_list # 没有漏检 # 2. 按行分组已检测的单元格 cells_sorted = sorted(detected_cells, key=lambda c: (c[1], c[0])) rows = [] if cells_sorted: current_row = [cells_sorted[0]] row_y = (cells_sorted[0][1] + cells_sorted[0][3]) / 2 row_height = cells_sorted[0][3] - cells_sorted[0][1] for cell in cells_sorted[1:]: cell_y = (cell[1] + cell[3]) / 2 if abs(cell_y - row_y) < row_height * 0.7: current_row.append(cell) else: rows.append(current_row) current_row = [cell] row_y = (cell[1] + cell[3]) / 2 row_height = cell[3] - cell[1] rows.append(current_row) # 3. 为每个未覆盖的OCR框推断单元格 inferred_cells = [] inferred_texts = [] for ocr_box, ocr_text in zip(uncovered_ocr_boxes, uncovered_ocr_texts): ocr_cy = (ocr_box[1] + ocr_box[3]) / 2 # 找到OCR框所在的行 target_row_idx = None for i, row_cells in enumerate(rows): row_y1 = min(c[1] for c in row_cells) row_y2 = max(c[3] for c in row_cells) if row_y1 <= ocr_cy <= row_y2: target_row_idx = i break if target_row_idx is None: # 无法确定所属行,跳过 print(f"⚠️ 无法为OCR文本 '{ocr_text}' 确定所属行") continue target_row = rows[target_row_idx] # 4. 

# --- 3. Adapter entry point and patching logic ---

# Reference to the original method
_original_predict_single = None


def infer_missing_cells_from_ocr(
    detected_cells: List[list],
    cells_texts_list: List[str],
    overall_ocr_boxes: List[list],
    overall_ocr_texts: List[str],
    table_box: list
) -> tuple[List[list], List[str]]:
    """Infer missing cells from the global OCR results.

    Args:
        detected_cells: detected cell boxes [[x1, y1, x2, y2], ...]
        cells_texts_list: OCR text already associated with each detected cell
        overall_ocr_boxes: global OCR box coordinates
        overall_ocr_texts: global OCR texts
        table_box: table region [x1, y1, x2, y2]

    Returns:
        The completed cell list and the matching text list.
    """
    # 1. Find OCR boxes that are not covered by any detected cell
    uncovered_ocr_boxes = []
    uncovered_ocr_texts = []

    for ocr_box, ocr_text in zip(overall_ocr_boxes, overall_ocr_texts):
        # Center point of the OCR box
        ocr_cx = (ocr_box[0] + ocr_box[2]) / 2
        ocr_cy = (ocr_box[1] + ocr_box[3]) / 2

        # Check whether any cell covers this center point
        is_covered = False
        for cell in detected_cells:
            if cell[0] <= ocr_cx <= cell[2] and cell[1] <= ocr_cy <= cell[3]:
                is_covered = True
                break

        if not is_covered:
            uncovered_ocr_boxes.append(ocr_box)
            uncovered_ocr_texts.append(ocr_text)

    if not uncovered_ocr_boxes:
        return detected_cells, cells_texts_list  # nothing was missed

    # 2. Group the detected cells into rows
    cells_sorted = sorted(detected_cells, key=lambda c: (c[1], c[0]))
    rows = []
    if cells_sorted:
        current_row = [cells_sorted[0]]
        row_y = (cells_sorted[0][1] + cells_sorted[0][3]) / 2
        row_height = cells_sorted[0][3] - cells_sorted[0][1]
        for cell in cells_sorted[1:]:
            cell_y = (cell[1] + cell[3]) / 2
            if abs(cell_y - row_y) < row_height * 0.7:
                current_row.append(cell)
            else:
                rows.append(current_row)
                current_row = [cell]
                row_y = (cell[1] + cell[3]) / 2
                row_height = cell[3] - cell[1]
        rows.append(current_row)

    # 3. Infer a cell for each uncovered OCR box
    inferred_cells = []
    inferred_texts = []

    for ocr_box, ocr_text in zip(uncovered_ocr_boxes, uncovered_ocr_texts):
        ocr_cy = (ocr_box[1] + ocr_box[3]) / 2

        # Find the row that contains the OCR box
        target_row_idx = None
        for i, row_cells in enumerate(rows):
            row_y1 = min(c[1] for c in row_cells)
            row_y2 = max(c[3] for c in row_cells)
            if row_y1 <= ocr_cy <= row_y2:
                target_row_idx = i
                break

        if target_row_idx is None:
            # Cannot decide which row the text belongs to; skip it
            print(f"⚠️ Could not determine a row for OCR text '{ocr_text}'")
            continue

        target_row = rows[target_row_idx]

        # 4. Infer the cell boundaries
        # Top/bottom: use the row's unified extent
        cell_y1 = min(c[1] for c in target_row)
        cell_y2 = max(c[3] for c in target_row)

        # Left/right: derive from the OCR box position and the neighboring cells
        ocr_cx = (ocr_box[0] + ocr_box[2]) / 2

        # Closest cell to the left
        left_cells = [c for c in target_row if c[2] < ocr_cx]
        if left_cells:
            cell_x1 = max(c[2] for c in left_cells)  # right edge of the left neighbor
        else:
            cell_x1 = table_box[0]  # left border of the table

        # Closest cell to the right
        right_cells = [c for c in target_row if c[0] > ocr_cx]
        if right_cells:
            cell_x2 = min(c[0] for c in right_cells)  # left edge of the right neighbor
        else:
            cell_x2 = table_box[2]  # right border of the table

        # Build the inferred cell
        inferred_cell = [cell_x1, cell_y1, cell_x2, cell_y2]
        inferred_cells.append(inferred_cell)
        inferred_texts.append(ocr_text)
        print(f"✅ Inferred cell for OCR text '{ocr_text}': {inferred_cell}")

    # 5. Merge detected and inferred cells
    all_cells = detected_cells + inferred_cells
    all_texts = cells_texts_list + inferred_texts
    return all_cells, all_texts
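
# Illustrative sketch (not called by the adapter): one detected row with a missing
# third column. The OCR box whose center falls in the gap gets a cell inferred
# between the last detected cell and the table's right border. Numbers are made up.
def _demo_infer_missing_cells() -> None:
    detected = [[0, 0, 100, 30], [100, 0, 200, 30]]  # detector missed the third column
    cell_texts = ["Item", "Qty"]
    ocr_boxes = [[5, 5, 95, 25], [105, 5, 195, 25], [210, 5, 290, 25]]
    ocr_texts = ["Item", "Qty", "Price"]
    cells, texts = infer_missing_cells_from_ocr(
        detected, cell_texts, ocr_boxes, ocr_texts, table_box=[0, 0, 300, 30]
    )
    print(cells)  # [[0, 0, 100, 30], [100, 0, 200, 30], [200, 0, 300, 30]]
    print(texts)  # ['Item', 'Qty', 'Price']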

def enhanced_predict_single_table_recognition_res(
    self,
    image_array: np.ndarray,
    overall_ocr_res: OCRResult,
    table_box: list,
    use_e2e_wired_table_rec_model: bool = False,
    use_e2e_wireless_table_rec_model: bool = False,
    use_wired_table_cells_trans_to_html: bool = False,
    use_wireless_table_cells_trans_to_html: bool = False,
    use_ocr_results_with_table_cells: bool = True,
    flag_find_nei_text: bool = True,
) -> SingleTableRecognitionResult:
    """Enhanced method: OCR-guided cell completion."""
    print(">>> [Adapter] enhanced_predict_single_table_recognition_res called")

    # 🎯 Step 1: obtain table_cells_result (original logic)
    table_cls_pred = list(self.table_cls_model(image_array))[0]
    table_cls_result = self.extract_results(table_cls_pred, "cls")

    if table_cls_result == "wired_table":
        table_cells_pred = list(self.wired_table_cells_detection_model(image_array, threshold=0.3))[0]
    else:  # wireless_table
        table_cells_pred = list(self.wireless_table_cells_detection_model(image_array, threshold=0.3))[0]

    table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
    table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
    table_cells_result.sort(key=lambda c: (c[1], c[0]))

    # 🎯 Step 2: coordinate conversion
    from paddlex.inference.pipelines.table_recognition.table_recognition_post_processing_v2 import (
        convert_to_four_point_coordinates,
        convert_table_structure_pred_bbox,
        get_sub_regions_ocr_res,
    )

    # Convert to 4-point coordinates
    table_cells_result_4pt = convert_to_four_point_coordinates(table_cells_result)

    # Prepare the conversion parameters
    table_box_array = np.array([table_box])
    crop_start_point = [table_box[0], table_box[1]]
    img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]

    # Map back to the original image coordinate system
    table_cells_result_orig = convert_table_structure_pred_bbox(
        table_cells_result_4pt, crop_start_point, img_shape
    )
    # Handle NumPy arrays
    if isinstance(table_cells_result_orig, np.ndarray):
        table_cells_result_orig = table_cells_result_orig.tolist()
    table_cells_result_orig.sort(key=lambda c: (c[1], c[0]))

    # 🎯 Step 3: OCR results restricted to the table region
    table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box_array)

    # 🎯 Step 4: key improvement - OCR-guided cell completion
    if (use_wired_table_cells_trans_to_html or use_wireless_table_cells_trans_to_html) and use_ocr_results_with_table_cells:
        # ✅ Fix: make sure general_ocr_pipeline is initialized
        if self.general_ocr_pipeline is None:
            if hasattr(self, 'general_ocr_config_bak') and self.general_ocr_config_bak is not None:
                print("🔧 [Adapter] Initializing general_ocr_pipeline from backup config")
                self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)
            else:
                print("⚠️ [Adapter] No OCR pipeline available, falling back to original implementation")
                return _original_predict_single(
                    self,
                    image_array, overall_ocr_res, table_box,
                    use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model,
                    use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html,
                    use_ocr_results_with_table_cells, flag_find_nei_text
                )

        # ✅ Run OCR per cell (using the pre-crop coordinates)
        cells_texts_list = self.gen_ocr_with_table_cells(image_array, table_cells_result)

        # ✅ Complete the missing cells
        completed_cells, cells_texts_list = infer_missing_cells_from_ocr(
            detected_cells=table_cells_result_orig,
            cells_texts_list=cells_texts_list,
            overall_ocr_boxes=table_ocr_pred["rec_boxes"],
            overall_ocr_texts=table_ocr_pred["rec_texts"],
            table_box=table_box
        )

        # ✅ Build the HTML skeleton (using the converted original-image coordinates)
        html_skeleton = build_robust_html_from_cells(completed_cells)

        # ✅ Fill the content (using the cell bboxes and the per-cell OCR texts)
        pred_html = fill_html_with_ocr_by_bbox(
            html_skeleton,
            completed_cells,   # ✅ cell bboxes
            cells_texts_list   # ✅ per-cell OCR texts
        )

        single_img_res = {
            "cell_box_list": completed_cells,
            "table_ocr_pred": table_ocr_pred,  # keep the full OCR information
            "pred_html": pred_html,
        }
        res = SingleTableRecognitionResult(single_img_res)
        res["neighbor_texts"] = ""
        return res
    else:
        print(f"⚠️ Fallback to original implementation: {table_cls_result}")
        return _original_predict_single(
            self,
            image_array, overall_ocr_res, table_box,
            use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model,
            use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html,
            use_ocr_results_with_table_cells, flag_find_nei_text
        )


def apply_table_recognition_adapter():
    """Apply the table recognition adapter.

    The `predict_single_table_recognition_res` method of the
    _TableRecognitionPipelineV2 class is replaced directly.
    """
    global _original_predict_single
    try:
        # Import the target class
        from paddlex.inference.pipelines.table_recognition.pipeline_v2 import _TableRecognitionPipelineV2

        # Keep a reference to the original method so the patch is not applied twice
        if _original_predict_single is None:
            _original_predict_single = _TableRecognitionPipelineV2.predict_single_table_recognition_res

        # Swap in the enhanced version
        _TableRecognitionPipelineV2.predict_single_table_recognition_res = enhanced_predict_single_table_recognition_res

        print("✅ Table recognition adapter applied successfully (v6).")
        return True
    except Exception as e:
        print(f"❌ Failed to apply table recognition adapter: {e}")
        return False


def restore_original_function():
    """Restore the original method."""
    global _original_predict_single
    try:
        from paddlex.inference.pipelines.table_recognition.pipeline_v2 import _TableRecognitionPipelineV2

        if _original_predict_single is not None:
            _TableRecognitionPipelineV2.predict_single_table_recognition_res = _original_predict_single
            _original_predict_single = None  # reset the state
            print("✅ Original function restored.")
            return True
        return False
    except Exception as e:
        print(f"❌ Failed to restore original function: {e}")
        return False
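
# Usage sketch: apply the monkey patch, run the PaddleX table recognition pipeline,
# then restore the original method. The pipeline name, predict() call and input
# image below are assumptions about the usual PaddleX 3.x API, not taken from this
# module; adjust them to your environment.
if __name__ == "__main__":
    from paddlex import create_pipeline  # assumed PaddleX 3.x entry point

    apply_table_recognition_adapter()
    try:
        pipeline = create_pipeline(pipeline="table_recognition_v2")  # assumed pipeline name
        for res in pipeline.predict("table_sample.jpg"):             # hypothetical input image
            print(res)
    finally:
        restore_original_function()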