Bläddra i källkod

feat: 优化HTML生成逻辑,新增构建表格结构的函数以支持不一致的列数

zhch158_admin 1 månad sedan
förälder
incheckning
8e62656be4

+ 58 - 27
paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-
+import bisect
 import numpy as np
 
 from ..components import convert_points_to_boxes
@@ -203,30 +203,35 @@ def get_html_result(
         str: Generated HTML content as a string.
     """
     pred_html = []
-    td_index = 0
-    td_count = 0
-    matched_list_index = 0
+    # 全局单元格序号,从 0 开始,和 table_cells_flag 使用同一坐标系
+    td_global = 0
+
     head_structure = pred_structures[0:3]
     html = "".join(head_structure)
     table_structure = pred_structures[3:-3]
     for tag in table_structure:
-        matched_index = all_matched_index[matched_list_index]
         if "</td>" in tag:
+            # 通过全局 td 序号定位当前“行索引”和“列索引”
+            # table_cells_flag 是每行起始单元格的前缀计数(已 append 了末尾)
+            row_idx = max(0, bisect.bisect_right(table_cells_flag, td_global) - 1)
+            col_idx = td_global - table_cells_flag[row_idx]
+            matched_index = all_matched_index[row_idx] if row_idx < len(all_matched_index) else {}
+
             if "<td></td>" == tag:
-                pred_html.extend("<td>")
-            if td_index in matched_index.keys():
-                if len(matched_index[td_index]) == 0:
+                pred_html.append("<td>")
+            if col_idx in matched_index.keys():
+                if len(matched_index[col_idx]) == 0:
                     continue
                 b_with = False
                 if (
-                    "<b>" in ocr_contents[matched_index[td_index][0]]
-                    and len(matched_index[td_index]) > 1
+                    "<b>" in ocr_contents[matched_index[col_idx][0]]
+                    and len(matched_index[col_idx]) > 1
                 ):
                     b_with = True
-                    pred_html.extend("<b>")
-                for i, td_index_index in enumerate(matched_index[td_index]):
+                    pred_html.append("<b>")
+                for i, td_index_index in enumerate(matched_index[col_idx]):
                     content = ocr_contents[td_index_index]
-                    if len(matched_index[td_index]) > 1:
+                    if len(matched_index[col_idx]) > 1:
                         if len(content) == 0:
                             continue
                         if content[0] == " ":
@@ -237,23 +242,17 @@ def get_html_result(
                             content = content[:-4]
                         if len(content) == 0:
                             continue
-                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                        if i != len(matched_index[col_idx]) - 1 and " " != content[-1]:
                             content += " "
-                    pred_html.extend(content)
+                    pred_html.append(content)
                 if b_with:
-                    pred_html.extend("</b>")
+                    pred_html.append("</b>")
             if "<td></td>" == tag:
                 pred_html.append("</td>")
             else:
                 pred_html.append(tag)
-            td_index += 1
-            td_count += 1
-            if (
-                td_count >= table_cells_flag[matched_list_index + 1]
-                and matched_list_index < len(all_matched_index) - 1
-            ):
-                matched_list_index += 1
-                td_index = 0
+            # 推进到下一个全局单元格
+            td_global += 1
         else:
             pred_html.append(tag)
     html += "".join(pred_html)
@@ -408,6 +407,23 @@ def map_and_get_max(table_cells_flag, row_start_index):
     return max_values
 
 
+def build_structure_from_cells(table_cells_flag: list) -> list:
+    """
+    Build simple table-structure tokens from the cell-detection row-start markers (prefix sums):
+    head (3 items) + [<tr>, <td></td> * n, </tr>] * R + end (3 items)
+    """
+    head = ["<html>", "<body>", "<table>"]
+    body = []
+    for r in range(len(table_cells_flag) - 1):
+        body.append("<tr>")
+        cols = table_cells_flag[r + 1] - table_cells_flag[r]
+        for _ in range(cols):
+            body.append("<td></td>")
+        body.append("</tr>")
+    end = ["</table>", "</body>", "</html>"]
+    return head + body + end
+
+
 def get_table_recognition_res(
     table_box: list,
     table_structure_result: list,
@@ -475,9 +491,24 @@ def get_table_recognition_res(
     matched_index = match_table_and_ocr(
         table_cells_result, ocr_dt_boxes, table_cells_flag, table_cells_flag
     )
-    pred_html = get_html_result(
-        matched_index, ocr_texts_res, table_structure_result, row_start_index
-    )
+    # 对齐检测列数与结构列数,若不一致则回退到基于cells的结构骨架
+    use_cells_skeleton = False
+    for i in range(len(table_cells_flag) - 1):
+        cols_cells = table_cells_flag[i + 1] - table_cells_flag[i]
+        cols_struct = row_start_index[i + 1] - row_start_index[i]
+        if cols_cells != cols_struct:
+            use_cells_skeleton = True
+            break
+
+    if use_cells_skeleton:
+        skeleton = build_structure_from_cells(table_cells_flag)
+        pred_html = get_html_result(
+            matched_index, ocr_texts_res, skeleton, table_cells_flag
+        )
+    else:
+        pred_html = get_html_result(
+            matched_index, ocr_texts_res, table_structure_result, row_start_index
+        )
 
     single_img_res = {
         "cell_box_list": table_cells_result,