
add layout parsing result (#2934)

changdazhou 10 months ago
commit 64420d27e0

+ 1 - 2
api_examples/pipelines/test_layout_parsing.py

@@ -14,7 +14,7 @@
 
 from paddlex import create_pipeline
 
-pipeline = create_pipeline(pipeline="layout_parsing")
+pipeline = create_pipeline(pipeline_name="layout_parsing")
 
 output = pipeline.predict(
     "./test_samples/demo_paper.png",
@@ -53,7 +53,6 @@ output = pipeline.predict(
 # )
 
 for res in output:
-    print(res)
     res.print()
     res.save_to_img("./output")
     res.save_to_json("./output")
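
A quick sketch of how the updated example can also be consumed in memory, in addition to the files it saves. It assumes the result object stays dict-like (result.py below reads self["parsing_res_list"] the same way) and that the block keys match the ones built in get_layout_parsing_res() further down:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline_name="layout_parsing")

for res in pipeline.predict("./test_samples/demo_paper.png"):
    # "parsing_res_list" is the field added by this commit; each entry describes one
    # layout region, ordered for reading by sorted_layout_boxes() (see utils.py).
    for block in res["parsing_res_list"]:
        bbox = block["layout_bbox"]    # [x1, y1, x2, y2] of the layout region
        layout = block.get("layout")   # "single" or "double" column tag
        # Exactly one content field is present, depending on the region label:
        # "text", "table" (HTML), "formula", "seal", "image" plus "<label>_text",
        # or "text_without_layout" for OCR lines outside every layout box.
        content = {k: v for k, v in block.items() if k not in ("layout_bbox", "layout", "image")}
        print(bbox, layout, content)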

+ 169 - 7
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Union, List
-import os, sys
+from email.mime import image
+from typing import Any, Dict, Optional, Union, List, Tuple
 import numpy as np
-import cv2
 from ..base import BasePipeline
-from .utils import get_sub_regions_ocr_res
+from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
 from ..components import convert_points_to_boxes
 from .result import LayoutParsingResult
 from ....utils import logging
@@ -91,6 +90,22 @@ class LayoutParsingPipeline(BasePipeline):
             {"model_config_error": "config error for layout_det_model!"},
         )
         self.layout_det_model = self.create_model(layout_det_config)
+        layout_kwargs = {}
+        if (threshold := layout_det_config.get("threshold", None)) is not None:
+            layout_kwargs["threshold"] = threshold
+        if (layout_nms := layout_det_config.get("layout_nms", None)) is not None:
+            layout_kwargs["layout_nms"] = layout_nms
+        if (
+            layout_unclip_ratio := layout_det_config.get("layout_unclip_ratio", None)
+        ) is not None:
+            layout_kwargs["layout_unclip_ratio"] = layout_unclip_ratio
+        if (
+            layout_merge_bboxes_mode := layout_det_config.get(
+                "layout_merge_bboxes_mode", None
+            )
+        ) is not None:
+            layout_kwargs["layout_merge_bboxes_mode"] = layout_merge_bboxes_mode
+        self.layout_det_model = self.create_model(layout_det_config, **layout_kwargs)
 
         if self.use_general_ocr or self.use_table_recognition:
             general_ocr_config = config.get("SubPipelines", {}).get(
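
The block above reads four optional keys off layout_det_config before re-creating the layout detection model with them. As a sketch, the sub-config it expects has roughly this shape when written as a Python dict; the key names come from the .get() calls above, while the model name and values are purely illustrative:

# Illustrative shape of the layout-detection sub-config consumed above; only the
# four optional keys matter here. Any key that is absent (or None) is simply not
# forwarded, so the model keeps its own default for that setting.
layout_det_config = {
    "model_name": "some_layout_det_model",   # placeholder, not taken from this diff
    "threshold": 0.5,                        # per-box score threshold
    "layout_nms": True,                      # apply NMS across layout boxes
    "layout_unclip_ratio": 1.0,              # box expansion ratio
    "layout_merge_bboxes_mode": "union",     # merge-mode string is an assumption
}

# Equivalent, more compact form of the four walrus-operator checks above.
layout_kwargs = {
    key: layout_det_config[key]
    for key in ("threshold", "layout_nms", "layout_unclip_ratio", "layout_merge_bboxes_mode")
    if layout_det_config.get(key) is not None
}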
@@ -152,7 +167,127 @@
             if box_info["label"].lower() in ["formula", "table", "seal"]:
                 object_boxes.append(box_info["coordinate"])
         object_boxes = np.array(object_boxes)
-        return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
+        sub_regions_ocr_res = get_sub_regions_ocr_res(
+            overall_ocr_res, object_boxes, flag_within=False
+        )
+        return sub_regions_ocr_res
+
+    def get_layout_parsing_res(
+        self,
+        image: list,
+        layout_det_res: DetResult,
+        overall_ocr_res: OCRResult,
+        table_res_list: list,
+        seal_res_list: list,
+        formula_res_list: list,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_rec_score_thresh: Optional[float] = None,
+    ) -> list:
+        """
+        Retrieves the layout parsing result based on the layout detection result, OCR result, and other recognition results.
+        Args:
+            image (list): The input image.
+            layout_det_res (DetResult): The detection result containing the layout information of the document.
+            overall_ocr_res (OCRResult): The overall OCR result containing text information.
+            table_res_list (list): A list of table recognition results.
+            seal_res_list (list): A list of seal recognition results.
+            formula_res_list (list): A list of formula recognition results.
+            text_det_limit_side_len (Optional[int], optional): The maximum side length of the text detection region. Defaults to None.
+            text_det_limit_type (Optional[str], optional): The type of limit for the text detection region. Defaults to None.
+            text_det_thresh (Optional[float], optional): The confidence threshold for text detection. Defaults to None.
+            text_det_box_thresh (Optional[float], optional): The confidence threshold for text detection bounding boxes. Defaults to None
+            text_det_unclip_ratio (Optional[float], optional): The unclip ratio for text detection. Defaults to None.
+            text_rec_score_thresh (Optional[float], optional): The score threshold for text recognition. Defaults to None.
+        Returns:
+            list: A list of dictionaries representing the layout parsing result.
+        """
+        layout_parsing_res = []
+        matched_ocr_dict = {}
+        formula_index = 0
+        table_index = 0
+        seal_index = 0
+        image = np.array(image)
+        image_labels = ["image", "figure", "img", "fig"]
+        object_boxes = []
+        for object_box_idx, box_info in enumerate(layout_det_res["boxes"]):
+            single_box_res = {}
+            box = box_info["coordinate"]
+            label = box_info["label"].lower()
+            single_box_res["layout_bbox"] = box
+            object_boxes.append(box)
+            if label == "formula":
+                single_box_res["formula"] = formula_res_list[formula_index][
+                    "rec_formula"
+                ]
+                formula_index += 1
+            elif label == "table":
+                single_box_res["table"] = table_res_list[table_index]["pred_html"]
+                table_index += 1
+            elif label == "seal":
+                single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
+                seal_index += 1
+            else:
+                ocr_res_in_box, matched_idxs = get_sub_regions_ocr_res(
+                    overall_ocr_res, [box], return_match_idx=True
+                )
+                for matched_idx in matched_idxs:
+                    if matched_ocr_dict.get(matched_idx, None) is None:
+                        matched_ocr_dict[matched_idx] = [object_box_idx]
+                    else:
+                        matched_ocr_dict[matched_idx].append(object_box_idx)
+                if label in image_labels:
+                    x1, y1, x2, y2 = [int(i) for i in box]
+                    sub_image = image[y1:y2, x1:x2, :]
+                    single_box_res["image"] = sub_image
+                    single_box_res[f"{label}_text"] = "\n".join(
+                        ocr_res_in_box["rec_texts"]
+                    )
+                else:
+                    single_box_res["text"] = "\n".join(ocr_res_in_box["rec_texts"])
+            if single_box_res:
+                layout_parsing_res.append(single_box_res)
+        for layout_box_ids in matched_ocr_dict.values():
+            # one ocr is matched to multiple layout boxes, split the text into multiple lines
+            if len(layout_box_ids) > 1:
+                for idx in layout_box_ids:
+                    wht_im = np.ones(image.shape, dtype=image.dtype) * 255
+                    box = layout_parsing_res[idx]["layout_bbox"]
+                    x1, y1, x2, y2 = [int(i) for i in box]
+                    wht_im[y1:y2, x1:x2, :] = image[y1:y2, x1:x2, :]
+                    sub_ocr_res = next(
+                        self.general_ocr_pipeline(
+                            wht_im,
+                            text_det_limit_side_len=text_det_limit_side_len,
+                            text_det_limit_type=text_det_limit_type,
+                            text_det_thresh=text_det_thresh,
+                            text_det_box_thresh=text_det_box_thresh,
+                            text_det_unclip_ratio=text_det_unclip_ratio,
+                            text_rec_score_thresh=text_rec_score_thresh,
+                        )
+                    )
+                    layout_parsing_res[idx]["text"] = "\n".join(
+                        sub_ocr_res["rec_texts"]
+                    )
+
+        ocr_without_layout_boxes = get_sub_regions_ocr_res(
+            overall_ocr_res, object_boxes, flag_within=False
+        )
+
+        for ocr_rec_box, ocr_rec_text in zip(
+            ocr_without_layout_boxes["rec_boxes"], ocr_without_layout_boxes["rec_texts"]
+        ):
+            single_box_res = {}
+            single_box_res["layout_bbox"] = ocr_rec_box
+            single_box_res["text_without_layout"] = ocr_rec_text
+            layout_parsing_res.append(single_box_res)
+
+        layout_parsing_res = sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
+
+        return layout_parsing_res
 
     def check_model_settings_valid(self, input_params: Dict) -> bool:
         """
@@ -262,6 +397,10 @@
         seal_det_box_thresh: Optional[float] = None,
         seal_det_unclip_ratio: Optional[float] = None,
         seal_rec_score_thresh: Optional[float] = None,
+        layout_threshold: Optional[Union[float, dict]] = None,
+        layout_nms: Optional[bool] = None,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
     ) -> LayoutParsingResult:
         """
@@ -308,7 +447,15 @@
 
             doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
-            layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+            layout_det_res = next(
+                self.layout_det_model(
+                    doc_preprocessor_image,
+                    threshold=layout_threshold,
+                    layout_nms=layout_nms,
+                    layout_unclip_ratio=layout_unclip_ratio,
+                    layout_merge_bboxes_mode=layout_merge_bboxes_mode,
+                )
+            )
 
             if (
                 model_settings["use_general_ocr"]
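
Together with the new predict() parameters added in the previous hunk, the forwarding above lets callers tune layout detection per call without touching the pipeline config. A hedged usage sketch; the concrete values and the merge-mode string are illustrative, not taken from this diff:

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline_name="layout_parsing")
output = pipeline.predict(
    "./test_samples/demo_paper.png",
    layout_threshold=0.5,               # float, or a per-class dict per the type hint
    layout_nms=True,
    layout_unclip_ratio=(1.0, 1.0),     # float or tuple per the type hint; pair meaning assumed
    layout_merge_bboxes_mode="union",   # assumption: valid mode strings are not listed in this diff
)
for res in output:
    res.save_to_json("./output")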
@@ -382,10 +529,24 @@
                     )
                 )
                 formula_res_list = formula_res_all["formula_res_list"]
-                print(formula_res_list)
             else:
                 formula_res_list = []
 
+            parsing_res_list = self.get_layout_parsing_res(
+                doc_preprocessor_image,
+                layout_det_res=layout_det_res,
+                overall_ocr_res=overall_ocr_res,
+                table_res_list=table_res_list,
+                seal_res_list=seal_res_list,
+                formula_res_list=formula_res_list,
+                text_det_limit_side_len=text_det_limit_side_len,
+                text_det_limit_type=text_det_limit_type,
+                text_det_thresh=text_det_thresh,
+                text_det_box_thresh=text_det_box_thresh,
+                text_det_unclip_ratio=text_det_unclip_ratio,
+                text_rec_score_thresh=text_rec_score_thresh,
+            )
+
             single_img_res = {
                 "input_path": batch_data.input_paths[0],
                 "page_index": batch_data.page_indexes[0],
@@ -396,6 +557,7 @@
                 "table_res_list": table_res_list,
                 "seal_res_list": seal_res_list,
                 "formula_res_list": formula_res_list,
+                "parsing_res_list": parsing_res_list,
                 "model_settings": model_settings,
             }
             yield LayoutParsingResult(single_img_res)
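
Looking back at get_layout_parsing_res() added earlier in this file: when one OCR line is matched to several layout boxes, the text is re-recognised per box on a white canvas rather than on a tight crop, so the OCR sub-pipeline sees the region at its original position and scale. A standalone sketch of that masking step (the helper name is invented for illustration):

import numpy as np

def mask_outside_box(image: np.ndarray, box) -> np.ndarray:
    """Return a copy of `image` with everything outside `box` painted white."""
    x1, y1, x2, y2 = (int(v) for v in box)
    canvas = np.ones_like(image) * 255                  # all-white canvas, same shape/dtype
    canvas[y1:y2, x1:x2, :] = image[y1:y2, x1:x2, :]    # keep only the layout region
    return canvas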

+ 11 - 4
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -15,8 +15,8 @@
 import os
 from typing import Dict
 import numpy as np
+from PIL import Image, ImageDraw
 import copy
-import cv2
 from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, StrMixin, JsonMixin
 
 
@@ -50,14 +50,19 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
             res_img_dict["text_paragraphs_ocr_res"] = general_ocr_res.img["ocr_res_img"]
 
         if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
-            table_cell_img = copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
+            table_cell_img = Image.fromarray(
+                copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
+            )
+            table_draw = ImageDraw.Draw(table_cell_img)
+            rectangle_color = (255, 0, 0)
             for sno in range(len(self["table_res_list"])):
                 table_res = self["table_res_list"][sno]
                 cell_box_list = table_res["cell_box_list"]
                 for box in cell_box_list:
                     x1, y1, x2, y2 = [int(pos) for pos in box]
-                    cv2.rectangle(table_cell_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
-            res_img_dict["table_cell_img"] = table_cell_img
+                    table_draw.rectangle(
+                        [x1, y1, x2, y2], outline=rectangle_color, width=2
+                    )
 
         if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
             for sno in range(len(self["seal_res_list"])):
@@ -93,6 +98,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         data["input_path"] = self["input_path"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
+        data["parsing_res_list"] = self["parsing_res_list"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
         data["layout_det_res"] = self["layout_det_res"].str["res"]
@@ -143,6 +149,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         data["input_path"] = self["input_path"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
+        data["parsing_res_list"] = self["parsing_res_list"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
         data["layout_det_res"] = self["layout_det_res"].json["res"]
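
result.py now draws the table-cell boxes with Pillow instead of OpenCV, which removes the module's cv2 dependency. A minimal sketch of the equivalence between the removed call and its replacement, on a toy image (cv2.rectangle draws in place on the array, while ImageDraw works on a PIL Image, hence the round trip):

import numpy as np
from PIL import Image, ImageDraw

img_arr = np.zeros((100, 100, 3), dtype=np.uint8)
x1, y1, x2, y2 = 10, 10, 60, 60

# removed: cv2.rectangle(img_arr, (x1, y1), (x2, y2), (255, 0, 0), 2)
img = Image.fromarray(img_arr)
draw = ImageDraw.Draw(img)
draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=2)  # same box, 2 px outline
img_arr = np.array(img)  # back to an ndarray if one is needed downstream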

+ 72 - 4
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -17,6 +17,7 @@ __all__ = [
     "get_layout_ordering",
     "recursive_img_array2path",
     "get_show_color",
+    "sorted_layout_boxes",
 ]
 
 import numpy as np
@@ -27,6 +28,7 @@ from pathlib import Path
 from typing import List
 from ..ocr.result import OCRResult
 from ...models_new.object_detection.result import DetResult
+from ..components import convert_points_to_boxes
 
 
 def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
@@ -36,9 +38,8 @@ def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> List:
     Args:
         src_boxes (np.ndarray): A 2D numpy array of source bounding boxes.
         ref_boxes (np.ndarray): A 2D numpy array of reference bounding boxes.
-
     Returns:
-        list: A list of indices of source boxes that overlap with any reference box.
+        match_idx_list (list): A list of indices of source boxes that overlap with reference boxes.
     """
     match_idx_list = []
     src_boxes_num = len(src_boxes)
@@ -57,7 +58,10 @@
 
 
 def get_sub_regions_ocr_res(
-    overall_ocr_res: OCRResult, object_boxes: List, flag_within: bool = True
+    overall_ocr_res: OCRResult,
+    object_boxes: List,
+    flag_within: bool = True,
+    return_match_idx: bool = False,
 ) -> OCRResult:
     """
     Filters OCR results to only include text boxes within specified object boxes based on a flag.
@@ -66,6 +70,7 @@
         overall_ocr_res (OCRResult): The original OCR result containing all text boxes.
         object_boxes (list): A list of bounding boxes for the objects of interest.
         flag_within (bool): If True, only include text boxes within the object boxes. If False, exclude text boxes within the object boxes.
+        return_match_idx (bool): If True, return the list of matching indices.
 
     Returns:
         OCRResult: A filtered OCR result containing only the relevant text boxes.
@@ -103,7 +108,70 @@
             sub_regions_ocr_res["rec_boxes"].append(
                 overall_ocr_res["rec_boxes"][box_no]
             )
-    return sub_regions_ocr_res
+    return (
+        (sub_regions_ocr_res, match_idx_list)
+        if return_match_idx
+        else sub_regions_ocr_res
+    )
+
+
+def sorted_layout_boxes(res, w):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    Args:
+        res: List of dictionaries containing layout information.
+        w: Width of image.
+
+    Returns:
+        List of dictionaries containing sorted layout information.
+    """
+    num_boxes = len(res)
+    if num_boxes == 1:
+        res[0]["layout"] = "single"
+        return res
+
+    # Sort on the y axis first or sort it on the x axis
+    sorted_boxes = sorted(res, key=lambda x: (x["layout_bbox"][1], x["layout_bbox"][0]))
+    _boxes = list(sorted_boxes)
+
+    new_res = []
+    res_left = []
+    res_right = []
+    i = 0
+
+    while True:
+        if i >= num_boxes:
+            break
+        # Check that the bbox is on the left
+        elif (
+            _boxes[i]["layout_bbox"][0] < w / 4
+            and _boxes[i]["layout_bbox"][2] < 3 * w / 5
+        ):
+            _boxes[i]["layout"] = "double"
+            res_left.append(_boxes[i])
+            i += 1
+        elif _boxes[i]["layout_bbox"][0] > 2 * w / 5:
+            _boxes[i]["layout"] = "double"
+            res_right.append(_boxes[i])
+            i += 1
+        else:
+            new_res += res_left
+            new_res += res_right
+            _boxes[i]["layout"] = "single"
+            new_res.append(_boxes[i])
+            res_left = []
+            res_right = []
+            i += 1
+
+    res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
+    res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
+
+    if res_left:
+        new_res += res_left
+    if res_right:
+        new_res += res_right
+
+    return new_res
 
 
 def _calculate_iou(box1, box2):
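
To make the reading order produced by the new sorted_layout_boxes helper concrete, a toy example on a 1000 px wide page with two column blocks and one full-width block (coordinates are invented for illustration):

from paddlex.inference.pipelines_new.layout_parsing.utils import sorted_layout_boxes

blocks = [
    {"layout_bbox": [520, 100, 950, 300]},  # right column, upper half
    {"layout_bbox": [50, 100, 480, 300]},   # left column, upper half
    {"layout_bbox": [50, 350, 950, 450]},   # full-width block below the columns
]

ordered = sorted_layout_boxes(blocks, w=1000)
# Resulting order: left column, right column, then the full-width block.
# The two column entries are tagged {"layout": "double"}, the wide one {"layout": "single"}.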