Browse Source

add use_textline_orientation to layout parsing (#2968)

changdazhou 9 months ago
parent
commit
6f6a7c13b0

+ 2 - 0
paddlex/inference/pipelines/doc_preprocessor/result.py

@@ -79,6 +79,7 @@ class DocPreprocessorResult(BaseCVResult):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         data["angle"] = self["angle"]
         data["angle"] = self["angle"]
         return JsonMixin._to_str(data, *args, **kwargs)
         return JsonMixin._to_str(data, *args, **kwargs)
@@ -96,6 +97,7 @@ class DocPreprocessorResult(BaseCVResult):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         data["angle"] = self["angle"]
         data["angle"] = self["angle"]
         return JsonMixin._to_json(data, *args, **kwargs)
         return JsonMixin._to_json(data, *args, **kwargs)

+ 2 - 0
paddlex/inference/pipelines/formula_recognition/result.py

@@ -163,6 +163,7 @@ class FormulaRecognitionResult(BaseCVResult):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
@@ -195,6 +196,7 @@ class FormulaRecognitionResult(BaseCVResult):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = str(self["page_index"])
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]

+ 48 - 21
paddlex/inference/pipelines/layout_parsing/pipeline.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-from email.mime import image
-from typing import Any, Dict, Optional, Union, List, Tuple
+from typing import Dict, Optional, Union, List, Tuple
 import numpy as np
 import numpy as np
 from ..base import BasePipeline
 from ..base import BasePipeline
 from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
 from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
-from ..components import convert_points_to_boxes
+from ..components import CropByBoxes
 from .result import LayoutParsingResult
 from .result import LayoutParsingResult
 from ....utils import logging
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
 from ...utils.pp_option import PaddlePredictorOption
@@ -56,6 +55,7 @@ class LayoutParsingPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.batch_sampler = ImageBatchSampler(batch_size=1)
 
 
         self.img_reader = ReadImage(format="BGR")
         self.img_reader = ReadImage(format="BGR")
+        self._crop_by_boxes = CropByBoxes()
 
 
     def inintial_predictor(self, config: Dict) -> None:
     def inintial_predictor(self, config: Dict) -> None:
         """Initializes the predictor based on the provided configuration.
         """Initializes the predictor based on the provided configuration.
@@ -88,7 +88,6 @@ class LayoutParsingPipeline(BasePipeline):
             "LayoutDetection",
             "LayoutDetection",
             {"model_config_error": "config error for layout_det_model!"},
             {"model_config_error": "config error for layout_det_model!"},
         )
         )
-        self.layout_det_model = self.create_model(layout_det_config)
         layout_kwargs = {}
         layout_kwargs = {}
         if (threshold := layout_det_config.get("threshold", None)) is not None:
         if (threshold := layout_det_config.get("threshold", None)) is not None:
             layout_kwargs["threshold"] = threshold
             layout_kwargs["threshold"] = threshold
@@ -205,7 +204,9 @@ class LayoutParsingPipeline(BasePipeline):
             list: A list of dictionaries representing the layout parsing result.
             list: A list of dictionaries representing the layout parsing result.
         """
         """
         layout_parsing_res = []
         layout_parsing_res = []
+        sub_image_list = []
         matched_ocr_dict = {}
         matched_ocr_dict = {}
+        sub_image_region_id = 0
         formula_index = 0
         formula_index = 0
         table_index = 0
         table_index = 0
         seal_index = 0
         seal_index = 0
@@ -218,15 +219,15 @@ class LayoutParsingPipeline(BasePipeline):
             label = box_info["label"].lower()
             label = box_info["label"].lower()
             single_box_res["layout_bbox"] = box
             single_box_res["layout_bbox"] = box
             object_boxes.append(box)
             object_boxes.append(box)
-            if label == "formula":
+            if label == "formula" and len(formula_res_list) > formula_index:
                 single_box_res["formula"] = formula_res_list[formula_index][
                 single_box_res["formula"] = formula_res_list[formula_index][
                     "rec_formula"
                     "rec_formula"
                 ]
                 ]
                 formula_index += 1
                 formula_index += 1
-            elif label == "table":
+            elif label == "table" and len(table_res_list) > table_index:
                 single_box_res["table"] = table_res_list[table_index]["pred_html"]
                 single_box_res["table"] = table_res_list[table_index]["pred_html"]
                 table_index += 1
                 table_index += 1
-            elif label == "seal":
+            elif label == "seal" and len(seal_res_list) > seal_index:
                 single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
                 single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
                 seal_index += 1
                 seal_index += 1
             else:
             else:
@@ -239,9 +240,9 @@ class LayoutParsingPipeline(BasePipeline):
                     else:
                     else:
                         matched_ocr_dict[matched_idx].append(object_box_idx)
                         matched_ocr_dict[matched_idx].append(object_box_idx)
                 if label in image_labels:
                 if label in image_labels:
-                    x1, y1, x2, y2 = [int(i) for i in box]
-                    sub_image = image[y1:y2, x1:x2, :]
-                    single_box_res["image"] = sub_image
+                    crop_img_info = self._crop_by_boxes(image, [box_info])
+                    crop_img_info = crop_img_info[0]
+                    sub_image_list.append(crop_img_info["img"])
                     single_box_res[f"{label}_text"] = "\n".join(
                     single_box_res[f"{label}_text"] = "\n".join(
                         ocr_res_in_box["rec_texts"]
                         ocr_res_in_box["rec_texts"]
                     )
                     )
@@ -286,7 +287,7 @@ class LayoutParsingPipeline(BasePipeline):
 
 
         layout_parsing_res = sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
         layout_parsing_res = sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
 
 
-        return layout_parsing_res
+        return layout_parsing_res, sub_image_list
 
 
     def check_model_settings_valid(self, input_params: Dict) -> bool:
     def check_model_settings_valid(self, input_params: Dict) -> bool:
         """
         """
@@ -380,10 +381,15 @@ class LayoutParsingPipeline(BasePipeline):
         input: Union[str, List[str], np.ndarray, List[np.ndarray]],
         input: Union[str, List[str], np.ndarray, List[np.ndarray]],
         use_doc_orientation_classify: Optional[bool] = None,
         use_doc_orientation_classify: Optional[bool] = None,
         use_doc_unwarping: Optional[bool] = None,
         use_doc_unwarping: Optional[bool] = None,
+        use_textline_orientation: Optional[bool] = None,
         use_general_ocr: Optional[bool] = None,
         use_general_ocr: Optional[bool] = None,
         use_seal_recognition: Optional[bool] = None,
         use_seal_recognition: Optional[bool] = None,
         use_table_recognition: Optional[bool] = None,
         use_table_recognition: Optional[bool] = None,
         use_formula_recognition: Optional[bool] = None,
         use_formula_recognition: Optional[bool] = None,
+        layout_threshold: Optional[Union[float, dict]] = None,
+        layout_nms: Optional[bool] = None,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         text_det_limit_side_len: Optional[int] = None,
         text_det_limit_side_len: Optional[int] = None,
         text_det_limit_type: Optional[str] = None,
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
         text_det_thresh: Optional[float] = None,
@@ -396,10 +402,6 @@ class LayoutParsingPipeline(BasePipeline):
         seal_det_box_thresh: Optional[float] = None,
         seal_det_box_thresh: Optional[float] = None,
         seal_det_unclip_ratio: Optional[float] = None,
         seal_det_unclip_ratio: Optional[float] = None,
         seal_rec_score_thresh: Optional[float] = None,
         seal_rec_score_thresh: Optional[float] = None,
-        layout_threshold: Optional[Union[float, dict]] = None,
-        layout_nms: Optional[bool] = None,
-        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
         **kwargs,
     ) -> LayoutParsingResult:
     ) -> LayoutParsingResult:
         """
         """
@@ -407,11 +409,34 @@ class LayoutParsingPipeline(BasePipeline):
 
 
         Args:
         Args:
             input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or pdf(s) to be processed.
             input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or pdf(s) to be processed.
-            use_doc_orientation_classify (bool): Whether to use document orientation classification.
-            use_doc_unwarping (bool): Whether to use document unwarping.
-            use_general_ocr (bool): Whether to use general OCR.
-            use_seal_recognition (bool): Whether to use seal recognition.
-            use_table_recognition (bool): Whether to use table recognition.
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
+            use_general_ocr (Optional[bool]): Whether to use general OCR.
+            use_seal_recognition (Optional[bool]): Whether to use seal recognition.
+            use_table_recognition (Optional[bool]): Whether to use table recognition.
+            use_formula_recognition (Optional[bool]): Whether to use formula recognition.
+            layout_threshold (Optional[Union[float, dict]]): The threshold value to filter out low-confidence predictions. Defaults to None.
+            layout_nms (Optional[bool]): Whether to use layout-aware NMS. Defaults to None.
+            layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]]): The ratio for unclipping the bounding box.
+                Defaults to None.
+                If it's a single number, it is used for both width and height.
+                If it's a tuple of two numbers, they are used for width and height respectively.
+                If it's None, then no unclipping will be performed.
+            layout_merge_bboxes_mode (Optional[str]): The mode for merging bounding boxes. Defaults to None.
+            text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
+            text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
+            text_det_thresh (Optional[float]): Threshold for text detection.
+            text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
+            text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
+            text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
+            seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
+            seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
+            seal_det_thresh (Optional[float]): Threshold for seal detection.
+            seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
+            seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
+            seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
+
             **kwargs: Additional keyword arguments.
             **kwargs: Additional keyword arguments.
 
 
         Returns:
         Returns:
@@ -463,6 +488,7 @@ class LayoutParsingPipeline(BasePipeline):
                 overall_ocr_res = next(
                 overall_ocr_res = next(
                     self.general_ocr_pipeline(
                     self.general_ocr_pipeline(
                         doc_preprocessor_image,
                         doc_preprocessor_image,
+                        use_textline_orientation=use_textline_orientation,
                         text_det_limit_side_len=text_det_limit_side_len,
                         text_det_limit_side_len=text_det_limit_side_len,
                         text_det_limit_type=text_det_limit_type,
                         text_det_limit_type=text_det_limit_type,
                         text_det_thresh=text_det_thresh,
                         text_det_thresh=text_det_thresh,
@@ -531,7 +557,7 @@ class LayoutParsingPipeline(BasePipeline):
             else:
             else:
                 formula_res_list = []
                 formula_res_list = []
 
 
-            parsing_res_list = self.get_layout_parsing_res(
+            parsing_res_list, sub_image_list = self.get_layout_parsing_res(
                 doc_preprocessor_image,
                 doc_preprocessor_image,
                 layout_det_res=layout_det_res,
                 layout_det_res=layout_det_res,
                 overall_ocr_res=overall_ocr_res,
                 overall_ocr_res=overall_ocr_res,
@@ -558,5 +584,6 @@ class LayoutParsingPipeline(BasePipeline):
                 "formula_res_list": formula_res_list,
                 "formula_res_list": formula_res_list,
                 "parsing_res_list": parsing_res_list,
                 "parsing_res_list": parsing_res_list,
                 "model_settings": model_settings,
                 "model_settings": model_settings,
+                "sub_image_list": sub_image_list,
             }
             }
             yield LayoutParsingResult(single_img_res)
             yield LayoutParsingResult(single_img_res)

+ 14 - 2
paddlex/inference/pipelines/layout_parsing/result.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import os
 from typing import Dict
 from typing import Dict
 import numpy as np
 import numpy as np
 from PIL import Image, ImageDraw
 from PIL import Image, ImageDraw
 import copy
 import copy
-from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, StrMixin, JsonMixin
+from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, JsonMixin
 
 
 
 
 class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
 class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
@@ -63,6 +62,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
                     table_draw.rectangle(
                     table_draw.rectangle(
                         [x1, y1, x2, y2], outline=rectangle_color, width=2
                         [x1, y1, x2, y2], outline=rectangle_color, width=2
                     )
                     )
+            res_img_dict["table_cell_img"] = table_cell_img
 
 
         if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
         if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
             for sno in range(len(self["seal_res_list"])):
             for sno in range(len(self["seal_res_list"])):
@@ -82,6 +82,16 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
                 sub_formula_res_dict = formula_res.img
                 sub_formula_res_dict = formula_res.img
                 key = f"formula_res_region{formula_region_id}"
                 key = f"formula_res_region{formula_region_id}"
                 res_img_dict[key] = sub_formula_res_dict["res"]
                 res_img_dict[key] = sub_formula_res_dict["res"]
+
+        if len(self["sub_image_list"]) > 0:
+            for sno in range(len(self["sub_image_list"])):
+                sub_region_image = Image.fromarray(
+                    copy.deepcopy(self["sub_image_list"][sno])
+                )
+                sub_region_image_id = sno + 1
+                key = f"sub_region_image{sub_region_image_id}"
+                res_img_dict[key] = sub_region_image
+
         return res_img_dict
         return res_img_dict
 
 
     def _to_str(self, *args, **kwargs) -> Dict[str, str]:
     def _to_str(self, *args, **kwargs) -> Dict[str, str]:
@@ -96,6 +106,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         data["model_settings"] = model_settings
         data["parsing_res_list"] = self["parsing_res_list"]
         data["parsing_res_list"] = self["parsing_res_list"]
@@ -147,6 +158,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         data["model_settings"] = model_settings
         data["parsing_res_list"] = self["parsing_res_list"]
         data["parsing_res_list"] = self["parsing_res_list"]

+ 2 - 0
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -108,6 +108,8 @@ def get_sub_regions_ocr_res(
             sub_regions_ocr_res["rec_boxes"].append(
             sub_regions_ocr_res["rec_boxes"].append(
                 overall_ocr_res["rec_boxes"][box_no]
                 overall_ocr_res["rec_boxes"][box_no]
             )
             )
+    for key in ["rec_polys", "rec_scores", "rec_boxes"]:
+        sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
     return (
     return (
         (sub_regions_ocr_res, match_idx_list)
         (sub_regions_ocr_res, match_idx_list)
         if return_match_idx
         if return_match_idx

+ 1 - 0
paddlex/inference/pipelines/seal_recognition/result.py

@@ -74,6 +74,7 @@ class SealRecognitionResult(BaseCVResult):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]

+ 2 - 0
paddlex/inference/pipelines/table_recognition/result.py

@@ -138,6 +138,7 @@ class TableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
@@ -163,6 +164,7 @@ class TableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         """
         data = {}
         data = {}
         data["input_path"] = self["input_path"]
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]