# table_recognition_adapter.py
  1. """
  2. 表格识别个性化适配器 (v6 - 行内重叠合并修正版)
  3. 核心思想:
  4. 1. 废弃全局坐标聚类,改为按行分组和对齐,极大提升对倾斜、不规则表格的鲁棒性。
  5. 2. 结构生成与内容填充彻底分离:
  6. - `build_robust_html_from_cells`: 仅根据单元格几何位置,生成带`data-bbox`的HTML骨架。
  7. - `fill_html_with_ocr_by_bbox`: 根据`data-bbox`从全局OCR结果中查找文本并填充。
  8. 3. 通过适配器直接替换PaddleX Pipeline中的核心方法,实现无侵入式升级。
  9. """
  10. import importlib
  11. from typing import Any, Dict, List
  12. import numpy as np
  13. from paddlex.inference.pipelines.table_recognition.result import SingleTableRecognitionResult
  14. from paddlex.inference.pipelines.table_recognition.pipeline_v2 import OCRResult
  15. def _normalize_bbox(box: list) -> list:
  16. """
  17. 将8点坐标或4点坐标统一转换为 [x1, y1, x2, y2]
  18. """
  19. if len(box) == 8:
  20. # 8点坐标:取最小和最大值
  21. xs = [box[0], box[2], box[4], box[6]]
  22. ys = [box[1], box[3], box[5], box[7]]
  23. return [min(xs), min(ys), max(xs), max(ys)]
  24. elif len(box) == 4:
  25. return box[:4]
  26. else:
  27. raise ValueError(f"Unsupported bbox format: {box}")
  28. # --- 1. 核心算法:基于排序和行分组的HTML结构生成 ---
  29. def filter_nested_boxes(boxes: List[list]) -> List[list]:
  30. """
  31. 移除被其他框完全包含的框。
  32. boxes: List[[x1, y1, x2, y2]]
  33. """
  34. if not boxes:
  35. return []
  36. filtered = []
  37. # 按面积从大到小排序,优先保留大框
  38. boxes.sort(key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), reverse=True)
  39. for i, box in enumerate(boxes):
  40. is_nested = False
  41. for j in range(i): # 只需和排在前面的(更大的)框比较
  42. outer_box = boxes[j]
  43. # 判断 box 是否被 outer_box 包含
  44. if outer_box[0] <= box[0] and outer_box[1] <= box[1] and \
  45. outer_box[2] >= box[2] and outer_box[3] >= box[3]:
  46. is_nested = True
  47. break
  48. if not is_nested:
  49. filtered.append(box)
  50. return filtered
  51. def merge_overlapping_cells_in_row(row_cells: List[list], iou_threshold: float = 0.5) -> List[list]:
  52. """
  53. 合并单行内水平方向上高度重叠的单元格。
  54. """
  55. if not row_cells:
  56. return []
  57. # 按x坐标排序
  58. cells = sorted(row_cells, key=lambda c: c[0])
  59. merged_cells = []
  60. i = 0
  61. while i < len(cells):
  62. current_cell = list(cells[i]) # 使用副本
  63. j = i + 1
  64. while j < len(cells):
  65. next_cell = cells[j]
  66. # 计算交集
  67. inter_x1 = max(current_cell[0], next_cell[0])
  68. inter_y1 = max(current_cell[1], next_cell[1])
  69. inter_x2 = min(current_cell[2], next_cell[2])
  70. inter_y2 = min(current_cell[3], next_cell[3])
  71. inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
  72. # 如果交集面积大于其中一个框面积的阈值,则认为是重叠
  73. current_area = (current_cell[2] - current_cell[0]) * (current_cell[3] - current_cell[1])
  74. next_area = (next_cell[2] - next_cell[0]) * (next_cell[3] - next_cell[1])
  75. if inter_area > min(current_area, next_area) * iou_threshold:
  76. # 合并两个框,取外包围框
  77. current_cell[0] = min(current_cell[0], next_cell[0])
  78. current_cell[1] = min(current_cell[1], next_cell[1])
  79. current_cell[2] = max(current_cell[2], next_cell[2])
  80. current_cell[3] = max(current_cell[3], next_cell[3])
  81. j += 1
  82. else:
  83. break # 不再与更远的单元格合并
  84. merged_cells.append(current_cell)
  85. i = j
  86. return merged_cells
  87. def build_robust_html_from_cells(cells_det_results: List[list]) -> str:
  88. """
  89. 通过按行排序、分组、合并和对齐,稳健地将单元格Bbox列表转换为带data-bbox的HTML结构。
  90. """
  91. if not cells_det_results:
  92. return "<table><tbody></tbody></table>"
  93. # ✅ 关键修复:使用副本防止修改原始列表
  94. import copy
  95. cells_copy = copy.deepcopy(cells_det_results)
  96. cells = filter_nested_boxes(cells_copy)
  97. cells.sort(key=lambda c: (c[1], c[0]))
  98. rows = []
  99. if cells:
  100. current_row = [cells[0]]
  101. # ✅ 使用该行的Y范围而不是单个锚点
  102. row_y1 = cells[0][1]
  103. row_y2 = cells[0][3]
  104. for cell in cells[1:]:
  105. # ✅ 计算垂直方向的重叠
  106. overlap_y1 = max(row_y1, cell[1])
  107. overlap_y2 = min(row_y2, cell[3])
  108. overlap_height = max(0, overlap_y2 - overlap_y1)
  109. # 单元格和当前行的平均高度
  110. cell_height = cell[3] - cell[1]
  111. row_height = row_y2 - row_y1
  112. avg_height = (cell_height + row_height) / 2
  113. # ✅ 重叠高度超过平均高度的50%,认为是同一行
  114. if overlap_height > avg_height * 0.5:
  115. current_row.append(cell)
  116. # 更新该行的Y范围(扩展以包含新单元格)
  117. row_y1 = min(row_y1, cell[1])
  118. row_y2 = max(row_y2, cell[3])
  119. else:
  120. rows.append(current_row)
  121. current_row = [cell]
  122. row_y1 = cell[1]
  123. row_y2 = cell[3]
  124. rows.append(current_row)
  125. html = "<table><tbody>"
  126. for row_cells in rows:
  127. # 🎯 核心修正:在生成HTML前,合并行内的重叠单元格
  128. merged_row_cells = merge_overlapping_cells_in_row(row_cells)
  129. html += "<tr>"
  130. for cell in merged_row_cells:
  131. bbox_str = f"[{','.join(map(str, map(int, cell)))}]"
  132. html += f'<td data-bbox="{bbox_str}"></td>'
  133. html += "</tr>"
  134. html += "</tbody></table>"
  135. return html
  136. # --- 2. 内容填充工具 ---
  137. def fill_html_with_ocr_by_bbox(html_skeleton: str, ocr_dt_boxes: list, ocr_texts: list) -> str:
  138. """
  139. 根据带有 data-bbox 的 HTML 骨架和全局 OCR 结果填充表格内容。
  140. """
  141. try:
  142. from bs4 import BeautifulSoup
  143. except ImportError:
  144. print("⚠️ BeautifulSoup not installed. Cannot fill table content. Returning skeleton.")
  145. return html_skeleton
  146. soup = BeautifulSoup(html_skeleton, 'html.parser')
  147. # # ocr_dt_boxes = cells_ocr_res.get("rec_boxes", [])
  148. # ocr_texts = cells_ocr_res.get("rec_texts", [])
  149. # 为快速查找,将OCR结果组织起来
  150. ocr_items = []
  151. for box, text in zip(ocr_dt_boxes, ocr_texts):
  152. center_x = (box[0] + box[2]) / 2
  153. center_y = (box[1] + box[3]) / 2
  154. ocr_items.append({'box': box, 'text': text, 'center': (center_x, center_y)})
  155. for td in soup.find_all('td'):
  156. if not td.has_attr('data-bbox'):
  157. continue
  158. bbox_str = td['data-bbox'].strip('[]')
  159. cell_box = list(map(float, bbox_str.split(',')))
  160. cx1, cy1, cx2, cy2 = cell_box
  161. cell_texts_with_pos = []
  162. # 查找所有中心点在该单元格内的OCR文本
  163. for item in ocr_items:
  164. if cx1 <= item['center'][0] <= cx2 and cy1 <= item['center'][1] <= cy2:
  165. # 记录文本和其y坐标,用于后续排序
  166. cell_texts_with_pos.append((item['text'], item['box'][1]))
  167. if cell_texts_with_pos:
  168. # 按y坐标排序,确保多行文本的顺序正确
  169. cell_texts_with_pos.sort(key=lambda x: x[1])
  170. # 合并文本
  171. td.string = " ".join([text for text, y in cell_texts_with_pos])
  172. return str(soup)
# --- 3. Adapter entry point and patching logic ---
# Reference to the original (unpatched) pipeline method. It stays None until
# apply_table_recognition_adapter() stores the method, and it is reset to None
# by restore_original_function().
_original_predict_single = None
  176. def infer_missing_cells_from_ocr(
  177. detected_cells: List[list],
  178. cells_texts_list: List[str],
  179. overall_ocr_boxes: List[list],
  180. overall_ocr_texts: List[str],
  181. table_box: list
  182. ) -> tuple[List[list], List[str]]:
  183. """
  184. 根据全局OCR结果推断缺失的单元格
  185. Args:
  186. detected_cells: 已检测到的单元格坐标 [[x1,y1,x2,y2], ...]
  187. overall_ocr_boxes: 全局OCR框坐标
  188. overall_ocr_texts: 全局OCR文本
  189. table_box: 表格区域 [x1,y1,x2,y2]
  190. Returns:
  191. 补全后的单元格列表
  192. """
  193. import copy
  194. # 1. 找出未被覆盖的OCR框
  195. uncovered_ocr_boxes = []
  196. uncovered_ocr_texts = []
  197. for ocr_box, ocr_text in zip(overall_ocr_boxes, overall_ocr_texts):
  198. # 计算OCR框中心点
  199. ocr_cx = (ocr_box[0] + ocr_box[2]) / 2
  200. ocr_cy = (ocr_box[1] + ocr_box[3]) / 2
  201. # 检查是否被任何单元格覆盖
  202. is_covered = False
  203. for cell in detected_cells:
  204. if cell[0] <= ocr_cx <= cell[2] and cell[1] <= ocr_cy <= cell[3]:
  205. is_covered = True
  206. break
  207. if not is_covered:
  208. uncovered_ocr_boxes.append(ocr_box)
  209. uncovered_ocr_texts.append(ocr_text)
  210. if not uncovered_ocr_boxes:
  211. return detected_cells, cells_texts_list # 没有漏检
  212. # 2. 按行分组已检测的单元格
  213. cells_sorted = sorted(detected_cells, key=lambda c: (c[1], c[0]))
  214. rows = []
  215. if cells_sorted:
  216. current_row = [cells_sorted[0]]
  217. row_y = (cells_sorted[0][1] + cells_sorted[0][3]) / 2
  218. row_height = cells_sorted[0][3] - cells_sorted[0][1]
  219. for cell in cells_sorted[1:]:
  220. cell_y = (cell[1] + cell[3]) / 2
  221. if abs(cell_y - row_y) < row_height * 0.7:
  222. current_row.append(cell)
  223. else:
  224. rows.append(current_row)
  225. current_row = [cell]
  226. row_y = (cell[1] + cell[3]) / 2
  227. row_height = cell[3] - cell[1]
  228. rows.append(current_row)
  229. # 3. 为每个未覆盖的OCR框推断单元格
  230. inferred_cells = []
  231. inferred_texts = []
  232. for ocr_box, ocr_text in zip(uncovered_ocr_boxes, uncovered_ocr_texts):
  233. ocr_cy = (ocr_box[1] + ocr_box[3]) / 2
  234. # 找到OCR框所在的行
  235. target_row_idx = None
  236. for i, row_cells in enumerate(rows):
  237. row_y1 = min(c[1] for c in row_cells)
  238. row_y2 = max(c[3] for c in row_cells)
  239. if row_y1 <= ocr_cy <= row_y2:
  240. target_row_idx = i
  241. break
  242. if target_row_idx is None:
  243. # 无法确定所属行,跳过
  244. print(f"⚠️ 无法为OCR文本 '{ocr_text}' 确定所属行")
  245. continue
  246. target_row = rows[target_row_idx]
  247. # 4. 推断单元格边界
  248. # 上下边界:使用该行的统一高度
  249. cell_y1 = min(c[1] for c in target_row)
  250. cell_y2 = max(c[3] for c in target_row)
  251. # 左右边界:根据OCR框位置和相邻单元格推断
  252. ocr_cx = (ocr_box[0] + ocr_box[2]) / 2
  253. # 找左边最近的单元格
  254. left_cells = [c for c in target_row if c[2] < ocr_cx]
  255. if left_cells:
  256. cell_x1 = max(c[2] for c in left_cells) # 左边单元格的右边界
  257. else:
  258. cell_x1 = table_box[0] # 表格左边界
  259. # 找右边最近的单元格
  260. right_cells = [c for c in target_row if c[0] > ocr_cx]
  261. if right_cells:
  262. cell_x2 = min(c[0] for c in right_cells) # 右边单元格的左边界
  263. else:
  264. cell_x2 = table_box[2] # 表格右边界
  265. # 创建推断的单元格
  266. inferred_cell = [cell_x1, cell_y1, cell_x2, cell_y2]
  267. inferred_cells.append(inferred_cell)
  268. inferred_texts.append(ocr_text)
  269. print(f"✅ 为OCR文本 '{ocr_text}' 推断单元格: {inferred_cell}")
  270. # 5. 合并检测到的和推断的单元格
  271. all_cells = detected_cells + inferred_cells
  272. all_texts = cells_texts_list + inferred_texts
  273. return all_cells, all_texts
def enhanced_predict_single_table_recognition_res(
    self,
    image_array: np.ndarray,
    overall_ocr_res: OCRResult,
    table_box: list,
    use_e2e_wired_table_rec_model: bool = False,
    use_e2e_wireless_table_rec_model: bool = False,
    use_wired_table_cells_trans_to_html: bool = False,
    use_wireless_table_cells_trans_to_html: bool = False,
    use_ocr_results_with_table_cells: bool = True,
    flag_find_nei_text: bool = True,
) -> SingleTableRecognitionResult:
    """Enhanced replacement predictor using OCR-guided cell completion.

    Intended to be installed as the pipeline's
    ``predict_single_table_recognition_res`` method (see
    ``apply_table_recognition_adapter``), hence the explicit ``self``.

    The enhanced path runs only when at least one of the two
    ``*_cells_trans_to_html`` flags is set AND
    ``use_ocr_results_with_table_cells`` is true. Otherwise, or when no OCR
    pipeline can be obtained, the call is forwarded unchanged to the saved
    original method.

    Args:
        self: The pipeline instance being patched.
        image_array: Table image handed to the classification and cell
            detection models (presumably the cropped table region — confirm).
        overall_ocr_res: Page-level OCR result; read here for
            ``["doc_preprocessor_res"]["output_img"]`` and passed to
            ``get_sub_regions_ocr_res``.
        table_box: Table region; elements 0 and 1 are used as the crop
            origin, elements 0 and 2 as the left/right table borders.
        use_e2e_wired_table_rec_model: Only forwarded to the original method.
        use_e2e_wireless_table_rec_model: Only forwarded to the original method.
        use_wired_table_cells_trans_to_html: Enables the enhanced path.
        use_wireless_table_cells_trans_to_html: Enables the enhanced path.
        use_ocr_results_with_table_cells: Must be true for the enhanced path.
        flag_find_nei_text: Only forwarded to the original method.

    Returns:
        On the enhanced path, a ``SingleTableRecognitionResult`` built from
        ``cell_box_list``, ``table_ocr_pred`` and ``pred_html``, with
        ``neighbor_texts`` set to ``""``. On fallback, whatever the original
        method returns.
    """
    print(">>> [Adapter] enhanced_predict_single_table_recognition_res called")
    # Step 1: obtain the cell detections (same as the original logic).
    # NOTE(review): all of this work also runs before a fallback, where the
    # original method presumably repeats it.
    table_cls_pred = list(self.table_cls_model(image_array))[0]
    table_cls_result = self.extract_results(table_cls_pred, "cls")
    if table_cls_result == "wired_table":
        table_cells_pred = list(self.wired_table_cells_detection_model(image_array, threshold=0.3))[0]
    else:  # wireless_table
        table_cells_pred = list(self.wireless_table_cells_detection_model(image_array, threshold=0.3))[0]
    table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
    # Presumably non-maximum suppression over the detected cells — confirm.
    table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
    # Order by (c[1], c[0]); the rest of this file treats cells as
    # [x1, y1, x2, y2], i.e. top-to-bottom then left-to-right.
    table_cells_result.sort(key=lambda c: (c[1], c[0]))
    # Step 2: coordinate conversion.
    from paddlex.inference.pipelines.table_recognition.table_recognition_post_processing_v2 import (
        convert_to_four_point_coordinates,
        convert_table_structure_pred_bbox,
        get_sub_regions_ocr_res
    )
    # Rebinds `np` locally; the module already imports numpy at the top.
    import numpy as np
    # Convert to four-point coordinates.
    table_cells_result_4pt = convert_to_four_point_coordinates(table_cells_result)
    # Prepare the parameters for the coordinate conversion.
    table_box_array = np.array([table_box])
    crop_start_point = [table_box[0], table_box[1]]
    img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
    # Map the cells into the coordinate system of the original image.
    table_cells_result_orig = convert_table_structure_pred_bbox(
        table_cells_result_4pt, crop_start_point, img_shape
    )
    # A NumPy array result is turned into nested lists (list.sort with a key
    # is used right below).
    if isinstance(table_cells_result_orig, np.ndarray):
        table_cells_result_orig = table_cells_result_orig.tolist()
    table_cells_result_orig.sort(key=lambda c: (c[1], c[0]))
    # Step 3: OCR results restricted to the table region.
    table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box_array)
    # Step 4 (key improvement): OCR-guided cell completion.
    if (use_wired_table_cells_trans_to_html or use_wireless_table_cells_trans_to_html) and use_ocr_results_with_table_cells:
        # Make sure general_ocr_pipeline is initialized.
        if self.general_ocr_pipeline is None:
            if hasattr(self, 'general_ocr_config_bak') and self.general_ocr_config_bak is not None:
                print("🔧 [Adapter] Initializing general_ocr_pipeline from backup config")
                self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)
            else:
                print("⚠️ [Adapter] No OCR pipeline available, falling back to original implementation")
                # NOTE(review): _original_predict_single is None until
                # apply_table_recognition_adapter() has stored the method.
                return _original_predict_single(
                    self, image_array, overall_ocr_res, table_box,
                    use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model,
                    use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html,
                    use_ocr_results_with_table_cells, flag_find_nei_text
                )
        # OCR every cell, using the pre-conversion (crop-relative) cells.
        cells_texts_list = self.gen_ocr_with_table_cells(image_array, table_cells_result)
        # Complete missing cells.
        # NOTE(review): cells_texts_list follows the order of
        # table_cells_result, while detected_cells is table_cells_result_orig,
        # which was sorted separately. This assumes both orders match — confirm.
        completed_cells, cells_texts_list = infer_missing_cells_from_ocr(
            detected_cells=table_cells_result_orig,
            cells_texts_list=cells_texts_list,
            overall_ocr_boxes=table_ocr_pred["rec_boxes"],
            overall_ocr_texts=table_ocr_pred["rec_texts"],
            table_box=table_box
        )
        # Build the HTML skeleton from the converted (original-image) cells.
        html_skeleton = build_robust_html_from_cells(completed_cells)
        # Fill the skeleton: the cell boxes and the per-cell texts are passed
        # in place of global OCR boxes/texts.
        pred_html = fill_html_with_ocr_by_bbox(
            html_skeleton,
            completed_cells,  # cell bounding boxes
            cells_texts_list  # per-cell OCR texts
        )
        single_img_res = {
            "cell_box_list": completed_cells,
            "table_ocr_pred": table_ocr_pred,  # keep the complete OCR information
            "pred_html": pred_html,
        }
        res = SingleTableRecognitionResult(single_img_res)
        res["neighbor_texts"] = ""
        return res
    else:
        print(f"⚠️ Fallback to original implementation: {table_cls_result}")
        return _original_predict_single(
            self, image_array, overall_ocr_res, table_box,
            use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model,
            use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html,
            use_ocr_results_with_table_cells, flag_find_nei_text
        )
  370. def apply_table_recognition_adapter():
  371. """
  372. 应用表格识别适配器。
  373. 我们直接替换 _TableRecognitionPipelineV2 类中的 `predict_single_table_recognition_res` 方法。
  374. """
  375. global _original_predict_single
  376. try:
  377. # 导入目标类
  378. from paddlex.inference.pipelines.table_recognition.pipeline_v2 import _TableRecognitionPipelineV2
  379. # 保存原函数,防止重复应用补丁
  380. if _original_predict_single is None:
  381. _original_predict_single = _TableRecognitionPipelineV2.predict_single_table_recognition_res
  382. # 替换为增强版
  383. _TableRecognitionPipelineV2.predict_single_table_recognition_res = enhanced_predict_single_table_recognition_res
  384. print("✅ Table recognition adapter applied successfully (v3 - corrected).")
  385. return True
  386. except Exception as e:
  387. print(f"❌ Failed to apply table recognition adapter: {e}")
  388. return False
  389. def restore_original_function():
  390. """恢复原始函数"""
  391. global _original_predict_single
  392. try:
  393. from paddlex.inference.pipelines.table_recognition.pipeline_v2 import _TableRecognitionPipelineV2
  394. if _original_predict_single is not None:
  395. _TableRecognitionPipelineV2.predict_single_table_recognition_res = _original_predict_single
  396. _original_predict_single = None # 重置状态
  397. print("✅ Original function restored.")
  398. return True
  399. return False
  400. except Exception as e:
  401. print(f"❌ Failed to restore original function: {e}")
  402. return False