Bläddra i källkod

add use_textline_orientation to layout parsing (#2968)

changdazhou 9 månader sedan
förälder
incheckning
6f6a7c13b0

+ 2 - 0
paddlex/inference/pipelines/doc_preprocessor/result.py

@@ -79,6 +79,7 @@ class DocPreprocessorResult(BaseCVResult):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["angle"] = self["angle"]
         return JsonMixin._to_str(data, *args, **kwargs)
@@ -96,6 +97,7 @@ class DocPreprocessorResult(BaseCVResult):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         data["angle"] = self["angle"]
         return JsonMixin._to_json(data, *args, **kwargs)

+ 2 - 0
paddlex/inference/pipelines/formula_recognition/result.py

@@ -163,6 +163,7 @@ class FormulaRecognitionResult(BaseCVResult):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
@@ -195,6 +196,7 @@ class FormulaRecognitionResult(BaseCVResult):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = str(self["page_index"])
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]

+ 48 - 21
paddlex/inference/pipelines/layout_parsing/pipeline.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from email.mime import image
-from typing import Any, Dict, Optional, Union, List, Tuple
+from typing import Dict, Optional, Union, List, Tuple
 import numpy as np
 from ..base import BasePipeline
 from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
-from ..components import convert_points_to_boxes
+from ..components import CropByBoxes
 from .result import LayoutParsingResult
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
@@ -56,6 +55,7 @@ class LayoutParsingPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
 
         self.img_reader = ReadImage(format="BGR")
+        self._crop_by_boxes = CropByBoxes()
 
     def inintial_predictor(self, config: Dict) -> None:
         """Initializes the predictor based on the provided configuration.
@@ -88,7 +88,6 @@ class LayoutParsingPipeline(BasePipeline):
             "LayoutDetection",
             {"model_config_error": "config error for layout_det_model!"},
         )
-        self.layout_det_model = self.create_model(layout_det_config)
         layout_kwargs = {}
         if (threshold := layout_det_config.get("threshold", None)) is not None:
             layout_kwargs["threshold"] = threshold
@@ -205,7 +204,9 @@ class LayoutParsingPipeline(BasePipeline):
             list: A list of dictionaries representing the layout parsing result.
         """
         layout_parsing_res = []
+        sub_image_list = []
         matched_ocr_dict = {}
+        sub_image_region_id = 0
         formula_index = 0
         table_index = 0
         seal_index = 0
@@ -218,15 +219,15 @@ class LayoutParsingPipeline(BasePipeline):
             label = box_info["label"].lower()
             single_box_res["layout_bbox"] = box
             object_boxes.append(box)
-            if label == "formula":
+            if label == "formula" and len(formula_res_list) > formula_index:
                 single_box_res["formula"] = formula_res_list[formula_index][
                     "rec_formula"
                 ]
                 formula_index += 1
-            elif label == "table":
+            elif label == "table" and len(table_res_list) > table_index:
                 single_box_res["table"] = table_res_list[table_index]["pred_html"]
                 table_index += 1
-            elif label == "seal":
+            elif label == "seal" and len(seal_res_list) > seal_index:
                 single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
                 seal_index += 1
             else:
@@ -239,9 +240,9 @@ class LayoutParsingPipeline(BasePipeline):
                     else:
                         matched_ocr_dict[matched_idx].append(object_box_idx)
                 if label in image_labels:
-                    x1, y1, x2, y2 = [int(i) for i in box]
-                    sub_image = image[y1:y2, x1:x2, :]
-                    single_box_res["image"] = sub_image
+                    crop_img_info = self._crop_by_boxes(image, [box_info])
+                    crop_img_info = crop_img_info[0]
+                    sub_image_list.append(crop_img_info["img"])
                     single_box_res[f"{label}_text"] = "\n".join(
                         ocr_res_in_box["rec_texts"]
                     )
@@ -286,7 +287,7 @@ class LayoutParsingPipeline(BasePipeline):
 
         layout_parsing_res = sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
 
-        return layout_parsing_res
+        return layout_parsing_res, sub_image_list
 
     def check_model_settings_valid(self, input_params: Dict) -> bool:
         """
@@ -380,10 +381,15 @@ class LayoutParsingPipeline(BasePipeline):
         input: Union[str, List[str], np.ndarray, List[np.ndarray]],
         use_doc_orientation_classify: Optional[bool] = None,
         use_doc_unwarping: Optional[bool] = None,
+        use_textline_orientation: Optional[bool] = None,
         use_general_ocr: Optional[bool] = None,
         use_seal_recognition: Optional[bool] = None,
         use_table_recognition: Optional[bool] = None,
         use_formula_recognition: Optional[bool] = None,
+        layout_threshold: Optional[Union[float, dict]] = None,
+        layout_nms: Optional[bool] = None,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
         text_det_limit_side_len: Optional[int] = None,
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
@@ -396,10 +402,6 @@ class LayoutParsingPipeline(BasePipeline):
         seal_det_box_thresh: Optional[float] = None,
         seal_det_unclip_ratio: Optional[float] = None,
         seal_rec_score_thresh: Optional[float] = None,
-        layout_threshold: Optional[Union[float, dict]] = None,
-        layout_nms: Optional[bool] = None,
-        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-        layout_merge_bboxes_mode: Optional[str] = None,
         **kwargs,
     ) -> LayoutParsingResult:
         """
@@ -407,11 +409,34 @@ class LayoutParsingPipeline(BasePipeline):
 
         Args:
             input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or pdf(s) to be processed.
-            use_doc_orientation_classify (bool): Whether to use document orientation classification.
-            use_doc_unwarping (bool): Whether to use document unwarping.
-            use_general_ocr (bool): Whether to use general OCR.
-            use_seal_recognition (bool): Whether to use seal recognition.
-            use_table_recognition (bool): Whether to use table recognition.
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
+            use_general_ocr (Optional[bool]): Whether to use general OCR.
+            use_seal_recognition (Optional[bool]): Whether to use seal recognition.
+            use_table_recognition (Optional[bool]): Whether to use table recognition.
+            use_formula_recognition (Optional[bool]): Whether to use formula recognition.
+            layout_threshold (Optional[Union[float, dict]]): The threshold value to filter out low-confidence predictions. Default is None.
+            layout_nms (Optional[bool]): Whether to use layout-aware NMS. Defaults to None.
+            layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
+                Defaults to None.
+                If it's a single number, then it is used for both width and height.
+                If it's a tuple of two numbers, then they are used separately for width and height respectively.
+                If it's None, then no unclipping will be performed.
+            layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
+            text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
+            text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
+            text_det_thresh (Optional[float]): Threshold for text detection.
+            text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
+            text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
+            text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
+            seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
+            seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
+            seal_det_thresh (Optional[float]): Threshold for seal detection.
+            seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
+            seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
+            seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
+
             **kwargs: Additional keyword arguments.
 
         Returns:
@@ -463,6 +488,7 @@ class LayoutParsingPipeline(BasePipeline):
                 overall_ocr_res = next(
                     self.general_ocr_pipeline(
                         doc_preprocessor_image,
+                        use_textline_orientation=use_textline_orientation,
                         text_det_limit_side_len=text_det_limit_side_len,
                         text_det_limit_type=text_det_limit_type,
                         text_det_thresh=text_det_thresh,
@@ -531,7 +557,7 @@ class LayoutParsingPipeline(BasePipeline):
             else:
                 formula_res_list = []
 
-            parsing_res_list = self.get_layout_parsing_res(
+            parsing_res_list, sub_image_list = self.get_layout_parsing_res(
                 doc_preprocessor_image,
                 layout_det_res=layout_det_res,
                 overall_ocr_res=overall_ocr_res,
@@ -558,5 +584,6 @@ class LayoutParsingPipeline(BasePipeline):
                 "formula_res_list": formula_res_list,
                 "parsing_res_list": parsing_res_list,
                 "model_settings": model_settings,
+                "sub_image_list": sub_image_list,
             }
             yield LayoutParsingResult(single_img_res)

+ 14 - 2
paddlex/inference/pipelines/layout_parsing/result.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 from typing import Dict
 import numpy as np
 from PIL import Image, ImageDraw
 import copy
-from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, StrMixin, JsonMixin
+from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, JsonMixin
 
 
 class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
@@ -63,6 +62,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
                     table_draw.rectangle(
                         [x1, y1, x2, y2], outline=rectangle_color, width=2
                     )
+            res_img_dict["table_cell_img"] = table_cell_img
 
         if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
             for sno in range(len(self["seal_res_list"])):
@@ -82,6 +82,16 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
                 sub_formula_res_dict = formula_res.img
                 key = f"formula_res_region{formula_region_id}"
                 res_img_dict[key] = sub_formula_res_dict["res"]
+
+        if len(self["sub_image_list"]) > 0:
+            for sno in range(len(self["sub_image_list"])):
+                sub_region_image = Image.fromarray(
+                    copy.deepcopy(self["sub_image_list"][sno])
+                )
+                sub_region_image_id = sno + 1
+                key = f"sub_region_image{sub_region_image_id}"
+                res_img_dict[key] = sub_region_image
+
         return res_img_dict
 
     def _to_str(self, *args, **kwargs) -> Dict[str, str]:
@@ -96,6 +106,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         data["parsing_res_list"] = self["parsing_res_list"]
@@ -147,6 +158,7 @@ class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         model_settings = self["model_settings"]
         data["model_settings"] = model_settings
         data["parsing_res_list"] = self["parsing_res_list"]

+ 2 - 0
paddlex/inference/pipelines/layout_parsing/utils.py

@@ -108,6 +108,8 @@ def get_sub_regions_ocr_res(
             sub_regions_ocr_res["rec_boxes"].append(
                 overall_ocr_res["rec_boxes"][box_no]
             )
+    for key in ["rec_polys", "rec_scores", "rec_boxes"]:
+        sub_regions_ocr_res[key] = np.array(sub_regions_ocr_res[key])
     return (
         (sub_regions_ocr_res, match_idx_list)
         if return_match_idx

+ 1 - 0
paddlex/inference/pipelines/seal_recognition/result.py

@@ -74,6 +74,7 @@ class SealRecognitionResult(BaseCVResult):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]

+ 2 - 0
paddlex/inference/pipelines/table_recognition/result.py

@@ -138,6 +138,7 @@ class TableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
@@ -163,6 +164,7 @@ class TableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
         """
         data = {}
         data["input_path"] = self["input_path"]
+        data["page_index"] = self["page_index"]
         data["model_settings"] = self["model_settings"]
         if self["model_settings"]["use_doc_preprocessor"]:
             data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]