Преглед на файлове

fix text_paragraphs_ocr_res in layout parsing results (#3287)

* fix text_paragraphs_ocr_res in layout parsing results

* fix text_paragraphs_ocr_res in layout parsing results

* fix text_paragraphs_ocr_res in layout parsing results
dyning преди 9 месеца
родител
ревизия
2648806e83

+ 9 - 0
api_examples/pipelines/test_layout_parsing.py

@@ -52,6 +52,15 @@ output = pipeline.predict(
 #     use_table_recognition=True,
 # )
 
+# output = pipeline.predict(
+#     "./test_samples/layout_double_column.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_general_ocr=True,
+#     use_seal_recognition=True,
+#     use_table_recognition=True,
+# )
+
 for res in output:
     res.print()
     res.save_to_img("./output")

+ 2 - 2
paddlex/configs/pipelines/OCR.yaml

@@ -35,10 +35,10 @@ SubModules:
     module_name: textline_orientation
     model_name: PP-LCNet_x0_25_textline_ori 
     model_dir: null
-    batch_size: 1    
+    batch_size: 6    
   TextRecognition:
     module_name: text_recognition
     model_name: PP-OCRv4_mobile_rec 
     model_dir: null
-    batch_size: 1
+    batch_size: 6
     score_thresh: 0.0

+ 2 - 2
paddlex/configs/pipelines/layout_parsing.yaml

@@ -5,7 +5,7 @@ use_doc_preprocessor: True
 use_general_ocr: True
 use_seal_recognition: True
 use_table_recognition: True
-use_formula_recognition: True
+use_formula_recognition: False
 
 SubModules:
   LayoutDetection:
@@ -48,7 +48,7 @@ SubPipelines:
         module_name: text_recognition
         model_name: PP-OCRv4_server_rec
         model_dir: null
-        batch_size: 1
+        batch_size: 6
         score_thresh: 0
 
   TableRecognition:

+ 45 - 64
paddlex/inference/pipelines/layout_parsing/pipeline.py

@@ -147,29 +147,6 @@ class LayoutParsingPipeline(BasePipeline):
 
         return
 
-    def get_text_paragraphs_ocr_res(
-        self, overall_ocr_res: OCRResult, layout_det_res: DetResult
-    ) -> OCRResult:
-        """
-        Retrieves the OCR results for text paragraphs, excluding those of formulas, tables, and seals.
-
-        Args:
-            overall_ocr_res (OCRResult): The overall OCR result containing text information.
-            layout_det_res (DetResult): The detection result containing the layout information of the document.
-
-        Returns:
-            OCRResult: The OCR result for text paragraphs after excluding formulas, tables, and seals.
-        """
-        object_boxes = []
-        for box_info in layout_det_res["boxes"]:
-            if box_info["label"].lower() in ["formula", "table", "seal"]:
-                object_boxes.append(box_info["coordinate"])
-        object_boxes = np.array(object_boxes)
-        sub_regions_ocr_res = get_sub_regions_ocr_res(
-            overall_ocr_res, object_boxes, flag_within=False
-        )
-        return sub_regions_ocr_res
-
     def get_layout_parsing_res(
         self,
         image: list,
@@ -204,32 +181,53 @@ class LayoutParsingPipeline(BasePipeline):
             list: A list of dictionaries representing the layout parsing result.
         """
         layout_parsing_res = []
-        sub_image_list = []
         matched_ocr_dict = {}
-        sub_image_region_id = 0
         formula_index = 0
         table_index = 0
         seal_index = 0
         image = np.array(image)
-        image_labels = ["image", "figure", "img", "fig"]
         object_boxes = []
         for object_box_idx, box_info in enumerate(layout_det_res["boxes"]):
             single_box_res = {}
             box = box_info["coordinate"]
             label = box_info["label"].lower()
-            single_box_res["layout_bbox"] = box
+            single_box_res["block_bbox"] = box
+            single_box_res["block_label"] = label
+            single_box_res["block_content"] = ""
             object_boxes.append(box)
-            if label == "formula" and len(formula_res_list) > formula_index:
-                single_box_res["formula"] = formula_res_list[formula_index][
-                    "rec_formula"
-                ]
-                formula_index += 1
-            elif label == "table" and len(table_res_list) > table_index:
-                single_box_res["table"] = table_res_list[table_index]["pred_html"]
-                table_index += 1
-            elif label == "seal" and len(seal_res_list) > seal_index:
-                single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
-                seal_index += 1
+            if label == "formula":
+                if len(formula_res_list) > 0:
+                    assert (
+                        len(formula_res_list) > formula_index
+                    ), f"The number of \
+                        formula regions of layout parsing pipeline \
+                        and formula recognition pipeline are different!"
+                    single_box_res["block_content"] = formula_res_list[formula_index][
+                        "rec_formula"
+                    ]
+                    formula_index += 1
+            elif label == "table":
+                if len(table_res_list) > 0:
+                    assert (
+                        len(table_res_list) > table_index
+                    ), f"The number of \
+                        table regions of layout parsing pipeline \
+                        and table recognition pipeline are different!"
+                    single_box_res["block_content"] = table_res_list[table_index][
+                        "pred_html"
+                    ]
+                    table_index += 1
+            elif label == "seal":
+                if len(seal_res_list) > 0:
+                    assert (
+                        len(seal_res_list) > seal_index
+                    ), f"The number of \
+                        seal regions of layout parsing pipeline \
+                        and seal recognition pipeline are different!"
+                    single_box_res["block_content"] = ", ".join(
+                        seal_res_list[seal_index]["rec_texts"]
+                    )
+                    seal_index += 1
             else:
                 ocr_res_in_box, matched_idxs = get_sub_regions_ocr_res(
                     overall_ocr_res, [box], return_match_idx=True
@@ -239,23 +237,14 @@ class LayoutParsingPipeline(BasePipeline):
                         matched_ocr_dict[matched_idx] = [object_box_idx]
                     else:
                         matched_ocr_dict[matched_idx].append(object_box_idx)
-                if label in image_labels:
-                    crop_img_info = self._crop_by_boxes(image, [box_info])
-                    crop_img_info = crop_img_info[0]
-                    sub_image_list.append(crop_img_info["img"])
-                    single_box_res[f"{label}_text"] = "\n".join(
-                        ocr_res_in_box["rec_texts"]
-                    )
-                else:
-                    single_box_res["text"] = "\n".join(ocr_res_in_box["rec_texts"])
-            if single_box_res:
-                layout_parsing_res.append(single_box_res)
+                single_box_res["block_content"] = "\n".join(ocr_res_in_box["rec_texts"])
+            layout_parsing_res.append(single_box_res)
         for layout_box_ids in matched_ocr_dict.values():
             # one ocr is matched to multiple layout boxes, split the text into multiple lines
             if len(layout_box_ids) > 1:
                 for idx in layout_box_ids:
                     wht_im = np.ones(image.shape, dtype=image.dtype) * 255
-                    box = layout_parsing_res[idx]["layout_bbox"]
+                    box = layout_parsing_res[idx]["block_bbox"]
                     x1, y1, x2, y2 = [int(i) for i in box]
                     wht_im[y1:y2, x1:x2, :] = image[y1:y2, x1:x2, :]
                     sub_ocr_res = next(
@@ -269,7 +258,7 @@ class LayoutParsingPipeline(BasePipeline):
                             text_rec_score_thresh=text_rec_score_thresh,
                         )
                     )
-                    layout_parsing_res[idx]["text"] = "\n".join(
+                    layout_parsing_res[idx]["block_content"] = "\n".join(
                         sub_ocr_res["rec_texts"]
                     )
 
@@ -281,13 +270,14 @@ class LayoutParsingPipeline(BasePipeline):
             ocr_without_layout_boxes["rec_boxes"], ocr_without_layout_boxes["rec_texts"]
         ):
             single_box_res = {}
-            single_box_res["layout_bbox"] = ocr_rec_box
-            single_box_res["text_without_layout"] = ocr_rec_text
+            single_box_res["block_bbox"] = ocr_rec_box
+            single_box_res["block_label"] = "other_text"
+            single_box_res["block_content"] = ocr_rec_text
             layout_parsing_res.append(single_box_res)
 
         layout_parsing_res = sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
 
-        return layout_parsing_res, sub_image_list
+        return layout_parsing_res
 
     def check_model_settings_valid(self, input_params: Dict) -> bool:
         """
@@ -500,13 +490,6 @@ class LayoutParsingPipeline(BasePipeline):
             else:
                 overall_ocr_res = {}
 
-            if model_settings["use_general_ocr"]:
-                text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
-                    overall_ocr_res, layout_det_res
-                )
-            else:
-                text_paragraphs_ocr_res = {}
-
             if model_settings["use_table_recognition"]:
                 table_res_all = next(
                     self.table_recognition_pipeline(
@@ -557,7 +540,7 @@ class LayoutParsingPipeline(BasePipeline):
             else:
                 formula_res_list = []
 
-            parsing_res_list, sub_image_list = self.get_layout_parsing_res(
+            parsing_res_list = self.get_layout_parsing_res(
                 doc_preprocessor_image,
                 layout_det_res=layout_det_res,
                 overall_ocr_res=overall_ocr_res,
@@ -578,12 +561,10 @@ class LayoutParsingPipeline(BasePipeline):
                 "doc_preprocessor_res": doc_preprocessor_res,
                 "layout_det_res": layout_det_res,
                 "overall_ocr_res": overall_ocr_res,
-                "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
                 "table_res_list": table_res_list,
                 "seal_res_list": seal_res_list,
                 "formula_res_list": formula_res_list,
                 "parsing_res_list": parsing_res_list,
                 "model_settings": model_settings,
-                "sub_image_list": sub_image_list,
             }
             yield LayoutParsingResult(single_img_res)

+ 0 - 37
paddlex/inference/pipelines/layout_parsing/result.py

@@ -49,16 +49,6 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
             res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"]
 
-        if model_settings["use_general_ocr"]:
-            general_ocr_res = copy.deepcopy(self["overall_ocr_res"])
-            general_ocr_res["rec_polys"] = self["text_paragraphs_ocr_res"]["rec_polys"]
-            general_ocr_res["rec_texts"] = self["text_paragraphs_ocr_res"]["rec_texts"]
-            general_ocr_res["rec_scores"] = self["text_paragraphs_ocr_res"][
-                "rec_scores"
-            ]
-            general_ocr_res["rec_boxes"] = self["text_paragraphs_ocr_res"]["rec_boxes"]
-            res_img_dict["text_paragraphs_ocr_res"] = general_ocr_res.img["ocr_res_img"]
-
         if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
             table_cell_img = Image.fromarray(
                 copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
@@ -94,15 +84,6 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
                 key = f"formula_res_region{formula_region_id}"
                 res_img_dict[key] = sub_formula_res_dict["res"]
 
-        if len(self["sub_image_list"]) > 0:
-            for sno in range(len(self["sub_image_list"])):
-                sub_region_image = Image.fromarray(
-                    copy.deepcopy(self["sub_image_list"][sno])
-                )
-                sub_region_image_id = sno + 1
-                key = f"sub_region_image{sub_region_image_id}"
-                res_img_dict[key] = sub_region_image
-
         return res_img_dict
 
     def _to_str(self, *args, **kwargs) -> Dict[str, str]:
@@ -126,15 +107,6 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         data["layout_det_res"] = self["layout_det_res"].str["res"]
         if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
             data["overall_ocr_res"] = self["overall_ocr_res"].str["res"]
-        if model_settings["use_general_ocr"]:
-            general_ocr_res = {}
-            general_ocr_res["rec_polys"] = self["text_paragraphs_ocr_res"]["rec_polys"]
-            general_ocr_res["rec_texts"] = self["text_paragraphs_ocr_res"]["rec_texts"]
-            general_ocr_res["rec_scores"] = self["text_paragraphs_ocr_res"][
-                "rec_scores"
-            ]
-            general_ocr_res["rec_boxes"] = self["text_paragraphs_ocr_res"]["rec_boxes"]
-            data["text_paragraphs_ocr_res"] = general_ocr_res
         if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
             data["table_res_list"] = []
             for sno in range(len(self["table_res_list"])):
@@ -178,15 +150,6 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         data["layout_det_res"] = self["layout_det_res"].json["res"]
         if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
             data["overall_ocr_res"] = self["overall_ocr_res"].json["res"]
-        if model_settings["use_general_ocr"]:
-            general_ocr_res = {}
-            general_ocr_res["rec_polys"] = self["text_paragraphs_ocr_res"]["rec_polys"]
-            general_ocr_res["rec_texts"] = self["text_paragraphs_ocr_res"]["rec_texts"]
-            general_ocr_res["rec_scores"] = self["text_paragraphs_ocr_res"][
-                "rec_scores"
-            ]
-            general_ocr_res["rec_boxes"] = self["text_paragraphs_ocr_res"]["rec_boxes"]
-            data["text_paragraphs_ocr_res"] = general_ocr_res
         if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
             data["table_res_list"] = []
             for sno in range(len(self["table_res_list"])):

+ 6 - 10
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -131,11 +131,10 @@ def sorted_layout_boxes(res, w):
     """
     num_boxes = len(res)
     if num_boxes == 1:
-        res[0]["layout"] = "single"
         return res
 
     # Sort on the y axis first or sort it on the x axis
-    sorted_boxes = sorted(res, key=lambda x: (x["layout_bbox"][1], x["layout_bbox"][0]))
+    sorted_boxes = sorted(res, key=lambda x: (x["block_bbox"][1], x["block_bbox"][0]))
     _boxes = list(sorted_boxes)
 
     new_res = []
@@ -148,27 +147,24 @@ def sorted_layout_boxes(res, w):
             break
         # Check that the bbox is on the left
         elif (
-            _boxes[i]["layout_bbox"][0] < w / 4
-            and _boxes[i]["layout_bbox"][2] < 3 * w / 5
+            _boxes[i]["block_bbox"][0] < w / 4
+            and _boxes[i]["block_bbox"][2] < 3 * w / 5
         ):
-            _boxes[i]["layout"] = "double"
             res_left.append(_boxes[i])
             i += 1
-        elif _boxes[i]["layout_bbox"][0] > 2 * w / 5:
-            _boxes[i]["layout"] = "double"
+        elif _boxes[i]["block_bbox"][0] > 2 * w / 5:
             res_right.append(_boxes[i])
             i += 1
         else:
             new_res += res_left
             new_res += res_right
-            _boxes[i]["layout"] = "single"
             new_res.append(_boxes[i])
             res_left = []
             res_right = []
             i += 1
 
-    res_left = sorted(res_left, key=lambda x: (x["layout_bbox"][1]))
-    res_right = sorted(res_right, key=lambda x: (x["layout_bbox"][1]))
+    res_left = sorted(res_left, key=lambda x: (x["block_bbox"][1]))
+    res_right = sorted(res_right, key=lambda x: (x["block_bbox"][1]))
 
     if res_left:
         new_res += res_left