table_recognition_post_processing_V2 bugfix.md 14 KB

get_html_result(并引入 bisect)

你现在在 get_html_result 里用 matched_list_index/td_index 手动推进行列指针,遇到 table_cells_flag 出现重复(0 长度行)或结构标注与检测边界错位时很容易“对不齐”。虽然此时 matched_index 已经包含 {0:[0], 1:[6], 2:[57]},但行列指针没有严格按全局 td 序号映射到“第几行第几列”,导致没有拿到 57。

把指针推进方式改为“用全局单元格序号 td_global,通过二分在 table_cells_flag 中定位行,再计算列”。这样就不会受 0 长度行、重复 flag 等影响,保证能取到 ocr_texts_res[57]。

请按下面修改 get_html_result(并引入 bisect),其余逻辑保持不变。

# ...existing code...
import math
+import bisect
# ...existing code...

-def get_html_result(
-    all_matched_index: dict, ocr_contents: dict, pred_structures: list, table_cells_flag
-) -> str:
+def get_html_result(
+    all_matched_index: dict, ocr_contents: dict, pred_structures: list, table_cells_flag
+) -> str:
     """
     Generates HTML content based on the matched index, OCR contents, and predicted structures.
     """
-    pred_html = []
-    td_index = 0
-    td_count = 0
-    matched_list_index = 0
+    pred_html = []
+    # 全局单元格序号,从 0 开始,和 table_cells_flag 使用同一坐标系
+    td_global = 0

     head_structure = pred_structures[0:3]
     html = "".join(head_structure)
     table_structure = pred_structures[3:-3]
     for tag in table_structure:
-        if "</td>" in tag:
-            # 预对齐:处理 table_cells_flag 中可能出现的“0 长度行”
-            while (
-                matched_list_index < len(all_matched_index) - 1
-                and td_count >= table_cells_flag[matched_list_index + 1]
-            ):
-                matched_list_index += 1
-                td_index = 0
-
-            matched_index = all_matched_index[matched_list_index]
+        if "</td>" in tag:
+            # 通过全局 td 序号定位当前“行索引”和“列索引”
+            # table_cells_flag 是每行起始单元格的前缀计数(已 append 了末尾)
+            row_idx = max(0, bisect.bisect_right(table_cells_flag, td_global) - 1)
+            col_idx = td_global - table_cells_flag[row_idx]
+            matched_index = all_matched_index[row_idx] if row_idx < len(all_matched_index) else {}

             if "<td></td>" == tag:
-                pred_html.extend("<td>")
-            if td_index in matched_index.keys():
-                if len(matched_index[td_index]) == 0:
+                pred_html.append("<td>")
+            if col_idx in matched_index.keys():
+                if len(matched_index[col_idx]) == 0:
                     continue
                 b_with = False
                 if (
-                    "<b>" in ocr_contents[matched_index[td_index][0]]
-                    and len(matched_index[td_index]) > 1
+                    "<b>" in ocr_contents[matched_index[col_idx][0]]
+                    and len(matched_index[col_idx]) > 1
                 ):
                     b_with = True
-                    pred_html.extend("<b>")
-                for i, td_index_index in enumerate(matched_index[td_index]):
+                    pred_html.append("<b>")
+                for i, td_index_index in enumerate(matched_index[col_idx]):
                     content = ocr_contents[td_index_index]
-                    if len(matched_index[td_index]) > 1:
+                    if len(matched_index[col_idx]) > 1:
                         if len(content) == 0:
                             continue
                         if content[0] == " ":
                             content = content[1:]
                         if "<b>" in content:
                             content = content[3:]
                         if "</b>" in content:
                             content = content[:-4]
                         if len(content) == 0:
                             continue
-                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                        if i != len(matched_index[col_idx]) - 1 and " " != content[-1]:
                             content += " "
-                    pred_html.extend(content)
+                    pred_html.append(content)
                 if b_with:
-                    pred_html.extend("</b>")
+                    pred_html.append("</b>")
             if "<td></td>" == tag:
                 pred_html.append("</td>")
             else:
                 pred_html.append(tag)
-            td_index += 1
-            td_count += 1
-            # 事后对齐:连续 0 长度行时继续前进
-            while (
-                matched_list_index < len(all_matched_index) - 1
-                and td_count >= table_cells_flag[matched_list_index + 1]
-            ):
-                matched_list_index += 1
-                td_index = 0
+            # 推进到下一个全局单元格
+            td_global += 1
         else:
             pred_html.append(tag)
     html += "".join(pred_html)
     end_structure = pred_structures[-3:]
     html += "".join(end_structure)
     return html

要点

  • 用 bisect_right 在 table_cells_flag 中定位当前 td_global 属于哪一行,再用 col_idx = td_global - table_cells_flag[row_idx] 得到列索引。不会再受 0 长度行、重复 flag 的影响。
  • 将原先对字符串用 extend 的地方改为 append,避免把字符串拆成字符列表(功能不受影响,但更稳)。
  • 这样在你给出的状态 row_idx=2、matched_index={0:[0],1:[6],2:[57]}、col_idx=2 时,会稳定取到 ocr_texts_res[57]“上期发生额”。

另外建议

  • 若单元格内有折行,请在调用管线时把 use_table_cells_split_ocr 关掉,使用整格 OCR,避免漏行问题:
    • outputs = pipeline.predict(..., use_ocr_results_with_table_cells=True, use_table_cells_split_ocr=False)
  • 如果是有线表格且单元格检测很准,可保留 use_wired_table_cells_trans_to_html=True;否则可先置为 False 看对齐是否更稳。

对齐检测列数与结构列数,若不一致则回退到基于cells的结构骨架

build_structure_from_cells bugfix

你现在看到 row_idx=3、matched_index={0:[0],1:[6],2:[57]},但 col_idx 只有 1,没有走到 2,根因是“行内单元格数量不一致”:

  • matched_index 是按“单元格检测”得到的分组,这一行有 3 个格(第 2 个里是 57→‘上期发生额’)。
  • 而 pred_structures(表格结构预测)在该行只产出了 2 个 ,get_html_result 是沿着 pred_structures 逐个 写内容的,所以第三个格没有“落位”,自然取不到 57。

    解决思路

    • 当“结构预测的列数”与“单元格检测的列数”不一致时,改用“按单元格检测结果生成 HTML 骨架”,再把 matched_index 内容填进去。这样每个检测到的格都有对应的 槽位。
    • 提交补丁

      • 在 table_recognition_post_processing_v2.py 中增加一个用 cells 构建表格骨架的函数,并在不一致时回退到该骨架。

        # ...existing code...
        import math
        import bisect
        import numpy as np
        # ...existing code...
        
        +def build_structure_from_cells(table_cells_flag: list) -> list:
        +    """
        +    用单元格检测的行起始标记(前缀和)构造一个简单的表格结构tokens:
        +    head(3项) + [<tr>, <td></td>*n, </tr>]*R + end(3项)
        +    """
        +    head = ["<html>", "<body>", "<table>"]
        +    body = []
        +    for r in range(len(table_cells_flag) - 1):
        +        body.append("<tr>")
        +        cols = table_cells_flag[r + 1] - table_cells_flag[r]
        +        for _ in range(cols):
        +            body.append("<td></td>")
        +        body.append("</tr>")
        +    end = ["</table>", "</body>", "</html>"]
        +    return head + body + end
        
        # ...existing get_html_result stays as you pasted (用 bisect + td_global 的版本)...
        
        def get_table_recognition_res(
        table_box: list,
        table_structure_result: list,
        table_cells_result: list,
        overall_ocr_res: OCRResult,
        table_ocr_pred: dict,
        cells_texts_list: list,
        use_table_cells_ocr_results: bool,
        use_table_cells_split_ocr: bool,
        ) -> SingleTableRecognitionResult:
        # ...existing code above...
        table_cells_result, table_cells_flag = sort_table_cells_boxes(table_cells_result)
        row_start_index = find_row_start_index(table_structure_result)
        table_cells_flag = map_and_get_max(table_cells_flag, row_start_index)
        table_cells_flag.append(len(table_cells_result))
        row_start_index.append(len(table_cells_result))
        matched_index = match_table_and_ocr(
            table_cells_result, ocr_dt_boxes, table_cells_flag, table_cells_flag
        )
        -    pred_html = get_html_result(
        -        matched_index, ocr_texts_res, table_structure_result, row_start_index
        -    )
        +    # 对齐检测列数与结构列数,若不一致则回退到基于cells的结构骨架
        +    use_cells_skeleton = False
        +    for i in range(len(table_cells_flag) - 1):
        +        cols_cells = table_cells_flag[i + 1] - table_cells_flag[i]
        +        cols_struct = row_start_index[i + 1] - row_start_index[i]
        +        if cols_cells != cols_struct:
        +            use_cells_skeleton = True
        +            break
        +
        +    if use_cells_skeleton:
        +        skeleton = build_structure_from_cells(table_cells_flag)
        +        pred_html = get_html_result(
        +            matched_index, ocr_texts_res, skeleton, table_cells_flag
        +        )
        +    else:
        +        pred_html = get_html_result(
        +            matched_index, ocr_texts_res, table_structure_result, row_start_index
        +        )
        # ...existing code below...
        

      要点

      • 当结构预测的 数量比单元格检测少时(你当前的情况),用 cells 构造的骨架就会生成足够的 ,此时 col_idx 会走到 2,自然能取到 ocr_texts_res[57]‘上期发生额’。
      • get_html_result 里使用 bisect 在行起始数组里快速定位行索引 row_idx 和列索引 col_idx,避免 0 长度行或重复标记引起的错位。
      • 运行建议

        • 若页面有折行文本,建议在调用管线时关闭单元格内拆分重识别,避免丢字:
          • outputs = pipeline.predict(..., use_ocr_results_with_table_cells=True, use_table_cells_split_ocr=False)
        • 有线表格且单元格检测很准时,可开启 use_wired_table_cells_trans_to_html=True;若仍有不一致,可先设为 False 观察。

        这样改后,即便结构预测与单元格检测不完全一致,也能保证每个检测到的单元格都能被填充到 HTML 中,不会再出现“第三格的 57 没有位置写入”的问题。

        下面给出在本地生成补丁、推送到远端并提交 PR 的完整流程(Mac 终端命令)。任选其一方式:生成 .patch 文件交付,或直接走 GitHub PR。

        一、在本地创建修复分支并提交代码

        # 1) 在仓库根目录
        cd /Users/zhch158/workspace/repository.git/PaddleX
        
        # 2) 取最新主分支
        git checkout main
        git pull
        
        # 3) 新建修复分支
        git checkout -b fix/table_rec_v2-bisect-skeleton
        
        # 4) 编辑并保存你的修改
        #   - paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py
        #   - zhch/tabble_recognition_post_processing_V2 bugfix.md(可选:作为说明)
        # 保存后查看改动
        git status
        git diff
        
        # 5) 添加到暂存区并提交
        git add paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py \
                "zhch/tabble_recognition_post_processing_V2 bugfix.md"
        
        git commit -m "table_recognition_v2: fix cell-to-HTML alignment
        - use bisect on table_cells_flag to locate row/col
        - fallback to cells-based skeleton when cols mismatch
        - keep multi-line text by avoiding split OCR loss"
        
        # 6) 本地验证(可选)
        python -m pip install -e .
        python zhch/test_table_recognition_v2.py
        

        二、生成补丁文件(用于邮件/离线提交)

        • 简单 diff 补丁:

          git diff main...fix/table_rec_v2-bisect-skeleton > table_rec_v2_fix.patch
          
        • 邮件式补丁(带提交信息,适合上游用 git am 合并):

          git format-patch -1 HEAD -o ./patches
          # 生成的 ./patches/0001-table_recognition_v2-xxx.patch 即可提交
          
        • 接收方应用补丁:

          # 简单补丁
          git apply table_rec_v2_fix.patch
          
          # 邮件式补丁
          git am ./patches/0001-table_recognition_v2-xxx.patch
          

        三、推送分支并提交 GitHub PR(推荐)

        # 1) 确认 origin 指向你的 Fork;若没有 Fork,请先在 GitHub 页面 fork 官方仓库
        git remote -v
        
        # 2) 推送分支到你的远端
        git push -u origin fix/table_rec_v2-bisect-skeleton
        
        # 3) 打开 GitHub,针对该分支创建 Pull Request 到上游的 main 分支
        #    PR 描述建议包含:
        #    - 问题现象与复现步骤(含示例图片/JSON、td_global/row_idx/col_idx日志)
        #    - 解决方案概要(bisect 定位行列 + cells skeleton 回退)
        #    - 关键代码位置 diff 链接
        #    - 前后对比截图(HTML/Markdown中“上期发生额”能正确落位)
        #    - 风险评估与开关说明(use_table_cells_split_ocr 建议)
        

        四、PR 内容建议

        • 标题:table_recognition_v2: fix get_html_result row/col mismatch with bisect; fallback to cells skeleton
        • 关联合并的说明文档:zhch/tabble_recognition_post_processing_V2 bugfix.md
        • 附加材料:
          • 最小复现图与对应 table_cells_flag、matched_index、td_global 日志
          • 修复前后 HTML 片段对比(包含“上期发生额”示例)
          • 若加了新参数或回退逻辑,写明默认行为不变

        五、若仅在你本仓使用(不对外 PR)

        • 直接把 patch 存在 zhch/patches 目录,并在 README 中记录应用方法;
        • 或将分支合并回你仓库 main:

          git checkout main
          git merge --no-ff fix/table_rec_v2-bisect-skeleton -m "Merge: table_recognition_v2 bugfix"
          git push
          

        如需我帮你把当前修改生成规范的 patch/PR 模板描述,或补上对应的代码块文件,请告知你最终决定的提交方式(.patch 还是 PR)。我也可以按你的提交信息直接生成 format-patch 文件内容。