浏览代码

feat: 修复 get_html_result 行列对齐问题,使用 bisect 定位行列,增加基于单元格的结构骨架回退机制

zhch158_admin 1 月之前
父节点
当前提交
ff03ac7b08
共有 1 个文件被更改,包括 310 次插入0 次删除
  1. 310 0
      zhch/tabble_recognition_post_processing_V2 bugfix.md

+ 310 - 0
zhch/tabble_recognition_post_processing_V2 bugfix.md

@@ -0,0 +1,310 @@
+## get_html_result(并引入 bisect)
+你现在在 get_html_result 里用 matched_list_index/td_index 手动推进行列指针,遇到 table_cells_flag 出现重复(0 长度行)或结构标注与检测边界错位时很容易“对不齐”。虽然此时 matched_index 已经包含 {0:[0], 1:[6], 2:[57]},但行列指针没有严格按全局 td 序号映射到“第几行第几列”,导致没有拿到 57。
+
+把指针推进方式改为“用全局单元格序号 td_global,通过二分在 table_cells_flag 中定位行,再计算列”。这样就不会受 0 长度行、重复 flag 等影响,保证能取到 ocr_texts_res[57]。
+
+请按下面修改 get_html_result(并引入 bisect),其余逻辑保持不变。
+
+````python
+# ...existing code...
+import math
++import bisect
+# ...existing code...
+
+-def get_html_result(
+-    all_matched_index: dict, ocr_contents: dict, pred_structures: list, table_cells_flag
+-) -> str:
++def get_html_result(
++    all_matched_index: dict, ocr_contents: dict, pred_structures: list, table_cells_flag
++) -> str:
+     """
+     Generates HTML content based on the matched index, OCR contents, and predicted structures.
+     """
+-    pred_html = []
+-    td_index = 0
+-    td_count = 0
+-    matched_list_index = 0
++    pred_html = []
++    # 全局单元格序号,从 0 开始,和 table_cells_flag 使用同一坐标系
++    td_global = 0
+
+     head_structure = pred_structures[0:3]
+     html = "".join(head_structure)
+     table_structure = pred_structures[3:-3]
+     for tag in table_structure:
+-        if "</td>" in tag:
+-            # 预对齐:处理 table_cells_flag 中可能出现的“0 长度行”
+-            while (
+-                matched_list_index < len(all_matched_index) - 1
+-                and td_count >= table_cells_flag[matched_list_index + 1]
+-            ):
+-                matched_list_index += 1
+-                td_index = 0
+-
+-            matched_index = all_matched_index[matched_list_index]
++        if "</td>" in tag:
++            # 通过全局 td 序号定位当前“行索引”和“列索引”
++            # table_cells_flag 是每行起始单元格的前缀计数(已 append 了末尾)
++            row_idx = max(0, bisect.bisect_right(table_cells_flag, td_global) - 1)
++            col_idx = td_global - table_cells_flag[row_idx]
++            matched_index = all_matched_index[row_idx] if row_idx < len(all_matched_index) else {}
+
+             if "<td></td>" == tag:
+-                pred_html.extend("<td>")
+-            if td_index in matched_index.keys():
+-                if len(matched_index[td_index]) == 0:
++                pred_html.append("<td>")
++            if col_idx in matched_index.keys():
++                if len(matched_index[col_idx]) == 0:
+                     continue
+                 b_with = False
+                 if (
+-                    "<b>" in ocr_contents[matched_index[td_index][0]]
+-                    and len(matched_index[td_index]) > 1
++                    "<b>" in ocr_contents[matched_index[col_idx][0]]
++                    and len(matched_index[col_idx]) > 1
+                 ):
+                     b_with = True
+-                    pred_html.extend("<b>")
+-                for i, td_index_index in enumerate(matched_index[td_index]):
++                    pred_html.append("<b>")
++                for i, td_index_index in enumerate(matched_index[col_idx]):
+                     content = ocr_contents[td_index_index]
+-                    if len(matched_index[td_index]) > 1:
++                    if len(matched_index[col_idx]) > 1:
+                         if len(content) == 0:
+                             continue
+                         if content[0] == " ":
+                             content = content[1:]
+                         if "<b>" in content:
+                             content = content[3:]
+                         if "</b>" in content:
+                             content = content[:-4]
+                         if len(content) == 0:
+                             continue
+-                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
++                        if i != len(matched_index[col_idx]) - 1 and " " != content[-1]:
+                             content += " "
+-                    pred_html.extend(content)
++                    pred_html.append(content)
+                 if b_with:
+-                    pred_html.extend("</b>")
++                    pred_html.append("</b>")
+             if "<td></td>" == tag:
+                 pred_html.append("</td>")
+             else:
+                 pred_html.append(tag)
+-            td_index += 1
+-            td_count += 1
+-            # 事后对齐:连续 0 长度行时继续前进
+-            while (
+-                matched_list_index < len(all_matched_index) - 1
+-                and td_count >= table_cells_flag[matched_list_index + 1]
+-            ):
+-                matched_list_index += 1
+-                td_index = 0
++            # 推进到下一个全局单元格
++            td_global += 1
+         else:
+             pred_html.append(tag)
+     html += "".join(pred_html)
+     end_structure = pred_structures[-3:]
+     html += "".join(end_structure)
+     return html
+````
+
+要点
+- 用 bisect_right 在 table_cells_flag 中定位当前 td_global 属于哪一行,再用 col_idx = td_global - table_cells_flag[row_idx] 得到列索引。不会再受 0 长度行、重复 flag 的影响。
+- 将原先对字符串用 extend 的地方改为 append,避免把字符串拆成字符列表(功能不受影响,但更稳)。
+- 这样在你给出的状态 row_idx=2、matched_index={0:[0],1:[6],2:[57]}、col_idx=2 时,会稳定取到 ocr_texts_res[57]“上期发生额”。
+
+另外建议
+- 若单元格内有折行,请在调用管线时把 use_table_cells_split_ocr 关掉,使用整格 OCR,避免漏行问题:
+  - outputs = pipeline.predict(..., use_ocr_results_with_table_cells=True, use_table_cells_split_ocr=False)
+- 如果是有线表格且单元格检测很准,可保留 use_wired_table_cells_trans_to_html=True;否则可先置为 False 看对齐是否更稳。
+
+
+## 对齐检测列数与结构列数,若不一致则回退到基于cells的结构骨架
+**build_structure_from_cells bugfix**
+
+你现在看到 row_idx=3、matched_index={0:[0],1:[6],2:[57]},但 col_idx 只有 1,没有走到 2,根因是“行内单元格数量不一致”:
+- matched_index 是按“单元格检测”得到的分组,这一行有 3 个格(第 2 个里是 57→‘上期发生额’)。
+- 而 pred_structures(表格结构预测)在该行只产出了 2 个 <td>,get_html_result 是沿着 pred_structures 逐个 <td> 写内容的,所以第三个格没有“落位”,自然取不到 57。
+
+解决思路
+- 当“结构预测的列数”与“单元格检测的列数”不一致时,改用“按单元格检测结果生成 HTML 骨架”,再把 matched_index 内容填进去。这样每个检测到的格都有对应的 <td> 槽位。
+
+提交补丁
+- 在 table_recognition_post_processing_v2.py 中增加一个用 cells 构建表格骨架的函数,并在不一致时回退到该骨架。
+
+````python
+# ...existing code...
+import math
+import bisect
+import numpy as np
+# ...existing code...
+
++def build_structure_from_cells(table_cells_flag: list) -> list:
++    """
++    用单元格检测的行起始标记(前缀和)构造一个简单的表格结构tokens:
++    head(3项) + [<tr>, <td></td>*n, </tr>]*R + end(3项)
++    """
++    head = ["<html>", "<body>", "<table>"]
++    body = []
++    for r in range(len(table_cells_flag) - 1):
++        body.append("<tr>")
++        cols = table_cells_flag[r + 1] - table_cells_flag[r]
++        for _ in range(cols):
++            body.append("<td></td>")
++        body.append("</tr>")
++    end = ["</table>", "</body>", "</html>"]
++    return head + body + end
+
+# ...existing get_html_result stays as you pasted (用 bisect + td_global 的版本)...
+
+def get_table_recognition_res(
+    table_box: list,
+    table_structure_result: list,
+    table_cells_result: list,
+    overall_ocr_res: OCRResult,
+    table_ocr_pred: dict,
+    cells_texts_list: list,
+    use_table_cells_ocr_results: bool,
+    use_table_cells_split_ocr: bool,
+) -> SingleTableRecognitionResult:
+    # ...existing code above...
+    table_cells_result, table_cells_flag = sort_table_cells_boxes(table_cells_result)
+    row_start_index = find_row_start_index(table_structure_result)
+    table_cells_flag = map_and_get_max(table_cells_flag, row_start_index)
+    table_cells_flag.append(len(table_cells_result))
+    row_start_index.append(len(table_cells_result))
+    matched_index = match_table_and_ocr(
+        table_cells_result, ocr_dt_boxes, table_cells_flag, table_cells_flag
+    )
+-    pred_html = get_html_result(
+-        matched_index, ocr_texts_res, table_structure_result, row_start_index
+-    )
++    # 对齐检测列数与结构列数,若不一致则回退到基于cells的结构骨架
++    use_cells_skeleton = False
++    for i in range(len(table_cells_flag) - 1):
++        cols_cells = table_cells_flag[i + 1] - table_cells_flag[i]
++        cols_struct = row_start_index[i + 1] - row_start_index[i]
++        if cols_cells != cols_struct:
++            use_cells_skeleton = True
++            break
++
++    if use_cells_skeleton:
++        skeleton = build_structure_from_cells(table_cells_flag)
++        pred_html = get_html_result(
++            matched_index, ocr_texts_res, skeleton, table_cells_flag
++        )
++    else:
++        pred_html = get_html_result(
++            matched_index, ocr_texts_res, table_structure_result, row_start_index
++        )
+    # ...existing code below...
+````
+
+要点
+- 当结构预测的 <td> 数量比单元格检测少时(你当前的情况),用 cells 构造的骨架就会生成足够的 <td>,此时 col_idx 会走到 2,自然能取到 ocr_texts_res[57]‘上期发生额’。
+- get_html_result 里使用 bisect 在行起始数组里快速定位行索引 row_idx 和列索引 col_idx,避免 0 长度行或重复标记引起的错位。
+
+运行建议
+- 若页面有折行文本,建议在调用管线时关闭单元格内拆分重识别,避免丢字:
+  - outputs = pipeline.predict(..., use_ocr_results_with_table_cells=True, use_table_cells_split_ocr=False)
+- 有线表格且单元格检测很准时,可开启 use_wired_table_cells_trans_to_html=True;若仍有不一致,可先设为 False 观察。
+
+这样改后,即便结构预测与单元格检测不完全一致,也能保证每个检测到的单元格都能被填充到 HTML 中,不会再出现“第三格的 57 没有位置写入”的问题。
+
+
+## 下面给出在本地生成补丁、推送到远端并提交 PR 的完整流程(Mac 终端命令)。任选其一方式:生成 .patch 文件交付,或直接走 GitHub PR。
+
+一、在本地创建修复分支并提交代码
+````bash
+# 1) 在仓库根目录
+cd /Users/zhch158/workspace/repository.git/PaddleX
+
+# 2) 取最新主分支
+git checkout main
+git pull
+
+# 3) 新建修复分支
+git checkout -b fix/table_rec_v2-bisect-skeleton
+
+# 4) 编辑并保存你的修改
+#   - paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py
+#   - zhch/tabble_recognition_post_processing_V2 bugfix.md(可选:作为说明)
+# 保存后查看改动
+git status
+git diff
+
+# 5) 添加到暂存区并提交
+git add paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py \
+        "zhch/tabble_recognition_post_processing_V2 bugfix.md"
+
+git commit -m "table_recognition_v2: fix cell-to-HTML alignment
+- use bisect on table_cells_flag to locate row/col
+- fallback to cells-based skeleton when cols mismatch
+- keep multi-line text by avoiding split OCR loss"
+
+# 6) 本地验证(可选)
+python -m pip install -e .
+python zhch/test_table_recognition_v2.py
+````
+
+二、生成补丁文件(用于邮件/离线提交)
+- 简单 diff 补丁:
+````bash
+git diff main...fix/table_rec_v2-bisect-skeleton > table_rec_v2_fix.patch
+````
+
+- 邮件式补丁(带提交信息,适合上游用 git am 合并):
+````bash
+git format-patch -1 HEAD -o ./patches
+# 生成的 ./patches/0001-table_recognition_v2-xxx.patch 即可提交
+````
+
+- 接收方应用补丁:
+````bash
+# 简单补丁
+git apply table_rec_v2_fix.patch
+
+# 邮件式补丁
+git am ./patches/0001-table_recognition_v2-xxx.patch
+````
+
+三、推送分支并提交 GitHub PR(推荐)
+````bash
+# 1) 确认 origin 指向你的 Fork;若没有 Fork,请先在 GitHub 页面 fork 官方仓库
+git remote -v
+
+# 2) 推送分支到你的远端
+git push -u origin fix/table_rec_v2-bisect-skeleton
+
+# 3) 打开 GitHub,针对该分支创建 Pull Request 到上游的 main 分支
+#    PR 描述建议包含:
+#    - 问题现象与复现步骤(含示例图片/JSON、td_global/row_idx/col_idx日志)
+#    - 解决方案概要(bisect 定位行列 + cells skeleton 回退)
+#    - 关键代码位置 diff 链接
+#    - 前后对比截图(HTML/Markdown中“上期发生额”能正确落位)
+#    - 风险评估与开关说明(use_table_cells_split_ocr 建议)
+````
+
+四、PR 内容建议
+- 标题:table_recognition_v2: fix get_html_result row/col mismatch with bisect; fallback to cells skeleton
+- 关联合并的说明文档:zhch/tabble_recognition_post_processing_V2 bugfix.md
+- 附加材料:
+  - 最小复现图与对应 table_cells_flag、matched_index、td_global 日志
+  - 修复前后 HTML 片段对比(包含“上期发生额”示例)
+  - 若加了新参数或回退逻辑,写明默认行为不变
+
+五、若仅在你本仓使用(不对外 PR)
+- 直接把 patch 存在 zhch/patches 目录,并在 README 中记录应用方法;
+- 或将分支合并回你仓库 main:
+````bash
+git checkout main
+git merge --no-ff fix/table_rec_v2-bisect-skeleton -m "Merge: table_recognition_v2 bugfix"
+git push
+````
+
+如需我帮你把当前修改生成规范的 patch/PR 模板描述,或补上对应的代码块文件,请告知你最终决定的提交方式(.patch 还是 PR)。我也可以按你的提交信息直接生成 format-patch 文件内容。