瀏覽代碼

[Fix bug] fix bugs for table-rec-pipe_v2 (#2961)

* fix bugs

* refine codes

* fix bugs
Liu Jiaxuan 9 月之前
父節點
當前提交
e136336a72

+ 8 - 9
paddlex/inference/pipelines/layout_parsing/result_v2.py

@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import copy
 from pathlib import Path
+from PIL import Image, ImageDraw
 from typing import Dict
 
 import cv2
@@ -69,20 +70,18 @@ class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
             res_img_dict["text_paragraphs_ocr_res"] = general_ocr_res.img["ocr_res_img"]
 
         if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
-            table_cell_img = copy.deepcopy(
-                self["doc_preprocessor_res"]["output_img"],
+            table_cell_img = Image.fromarray(
+                copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
             )
+            table_draw = ImageDraw.Draw(table_cell_img)
+            rectangle_color = (255, 0, 0)
             for sno in range(len(self["table_res_list"])):
                 table_res = self["table_res_list"][sno]
                 cell_box_list = table_res["cell_box_list"]
                 for box in cell_box_list:
-                    x1, y1, x2, y2 = (int(pos) for pos in box)
-                    cv2.rectangle(
-                        table_cell_img,
-                        (x1, y1),
-                        (x2, y2),
-                        (255, 0, 0),
-                        2,
+                    x1, y1, x2, y2 = [int(pos) for pos in box]
+                    table_draw.rectangle(
+                        [x1, y1, x2, y2], outline=rectangle_color, width=2
                     )
             res_img_dict["table_cell_img"] = table_cell_img
 

+ 47 - 3
paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py

@@ -32,8 +32,10 @@ def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
     """
     if not box_list:
         return box_list
-    offset = np.array([x, y] * 2)
+    offset = np.array([x, y] * 4)
     box_list = np.array(box_list)
+    if box_list.shape[-1] == 2:
+        offset = offset.reshape(4, 2)
     ori_box_list = offset + box_list
     return ori_box_list
 
@@ -264,6 +266,46 @@ def sort_table_cells_boxes(boxes):
     return sorted_boxes
 
 
+def convert_to_four_point_coordinates(boxes):
+    """
+    Convert bounding boxes from [x1, y1, x2, y2] format to 
+    [x1, y1, x2, y1, x2, y2, x1, y2] format.
+
+    Parameters:
+    - boxes: A list of bounding boxes, each defined as a list of integers 
+             in the format [x1, y1, x2, y2].
+
+    Returns:
+    - A list of bounding boxes, each converted to the format 
+      [x1, y1, x2, y1, x2, y2, x1, y2].
+    """
+    # Initialize an empty list to store the converted bounding boxes
+    converted_boxes = []
+
+    # Loop over each box in the input list
+    for box in boxes:
+        x1, y1, x2, y2 = box
+        
+        # Define the four corner points
+        top_left = (x1, y1)
+        top_right = (x2, y1)
+        bottom_right = (x2, y2)
+        bottom_left = (x1, y2)
+
+        # Create a new list for the converted box
+        converted_box = [
+            top_left[0], top_left[1],  # Top-left corner
+            top_right[0], top_right[1],  # Top-right corner
+            bottom_right[0], bottom_right[1],  # Bottom-right corner
+            bottom_left[0], bottom_left[1]   # Bottom-left corner
+        ]
+
+        # Append the converted box to the list
+        converted_boxes.append(converted_box)
+
+    return converted_boxes
+
+
 def get_table_recognition_res(
     table_box: list,
     table_structure_result: list,
@@ -282,13 +324,15 @@ def get_table_recognition_res(
     Returns:
         SingleTableRecognitionResult: An object containing the single table recognition result.
     """
+    table_cells_result =convert_to_four_point_coordinates(table_cells_result)
+
     table_box = np.array([table_box])
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
     crop_start_point = [table_box[0][0], table_box[0][1]]
     img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
 
-    ori_table_cells = convert_table_structure_pred_bbox(
+    table_cells_result = convert_table_structure_pred_bbox(
         table_cells_result, crop_start_point, img_shape
     )
 
@@ -302,7 +346,7 @@ def get_table_recognition_res(
     pred_html = get_html_result(matched_index, ocr_texts_res, table_structure_result)
 
     single_img_res = {
-        "cell_box_list": ori_table_cells,
+        "cell_box_list": table_cells_result,
         "table_ocr_pred": table_ocr_pred,
         "pred_html": pred_html,
     }