Sfoglia il codice sorgente

fix input params and save_to_img for OCR and doc preprocessor (#2826)

dyning 10 mesi fa
parent
commit
87f601c32f

+ 12 - 6
api_examples/pipelines/test_doc_preprocessor.py

@@ -23,21 +23,27 @@ output = pipeline.predict(
 )
 
 # output = pipeline.predict(
-#     "./test_samples/doc_distort_test.jpg",
+#     "./test_samples/img_rot180_demo.jpg",
 #     use_doc_orientation_classify=False,
-#     use_doc_unwarping=True
+#     use_doc_unwarping=True,
 # )
 
 # output = pipeline.predict(
-#     "./test_samples/doc_distort_test.jpg",
+#     "./test_samples/img_rot180_demo.jpg",
 #     use_doc_orientation_classify=True,
-#     use_doc_unwarping=True
+#     use_doc_unwarping=True,
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/img_rot180_demo.jpg",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
 # )
 
 # output = pipeline.predict(
-#     "./test_samples/test_doc_processer.pdf",
+#     "./test_samples/doc_distort_test.jpg",
 #     use_doc_orientation_classify=True,
-#     use_doc_unwarping=False
+#     use_doc_unwarping=True
 # )
 
 for res in output:

+ 32 - 21
api_examples/pipelines/test_ocr.py

@@ -14,60 +14,71 @@
 
 from paddlex import create_pipeline
 
-pipeline = create_pipeline(pipeline="OCR", limit_side_len=320)
+pipeline = create_pipeline(pipeline="OCR")
 
 output = pipeline.predict(
     "./test_samples/general_ocr_002.png",
-    use_doc_orientation_classify=True,
+    use_doc_orientation_classify=False,
     use_doc_unwarping=False,
     use_textline_orientation=False,
-    unclip_ratio=3.0,
-    limit_side_len=1920,
 )
+
 # output = pipeline.predict(
 #     "./test_samples/general_ocr_002.png",
-#     use_doc_orientation_classify=True,
-#     use_doc_unwarping=True,
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
 #     use_textline_orientation=False,
+#     text_rec_score_thresh = 0.5
 # )
+
 # output = pipeline.predict(
 #     "./test_samples/general_ocr_002.png",
-#     use_doc_orientation_classify=True,
+#     use_doc_orientation_classify=False,
 #     use_doc_unwarping=False,
-#     use_textline_orientation=True,
+#     use_textline_orientation=False,
+#     text_det_unclip_ratio=3.0,
+#     text_det_limit_side_len=1920
 # )
+
 # output = pipeline.predict(
 #     "./test_samples/general_ocr_002.png",
 #     use_doc_orientation_classify=True,
-#     use_doc_unwarping=False,
-#     use_textline_orientation=False,
+#     use_doc_unwarping=True,
+#     use_textline_orientation=False
 # )
+
 # output = pipeline.predict(
-#     "./test_samples/general_ocr_002.png",
+#     "./test_samples/general_ocr_003.jpg",
 #     use_doc_orientation_classify=False,
-#     use_doc_unwarping=True,
-#     use_textline_orientation=True,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=False
 # )
+
 # output = pipeline.predict(
-#     "./test_samples/general_ocr_002.png",
+#     "./test_samples/general_ocr_003.jpg",
 #     use_doc_orientation_classify=False,
-#     use_doc_unwarping=True,
-#     use_textline_orientation=False,
+#     use_doc_unwarping=False,
+#     use_textline_orientation=True
 # )
+
 # output = pipeline.predict(
-#     "./test_samples/general_ocr_002.png",
+#     "./test_samples/general_ocr_002_rotate_90.png",
 #     use_doc_orientation_classify=False,
 #     use_doc_unwarping=False,
-#     use_textline_orientation=True,
+#     use_textline_orientation=False
 # )
+
 # output = pipeline.predict(
-#     "./test_samples/general_ocr_002.png",
+#     "./test_samples/general_ocr_002_rotate_90.png",
 #     use_doc_orientation_classify=False,
 #     use_doc_unwarping=False,
-#     use_textline_orientation=False,
+#     use_textline_orientation=True
 # )
+
 # output = pipeline.predict("./test_samples/财报1.pdf")
+
 for res in output:
     print(res)
     res.save_to_img("./output")
-    res.save_to_json("./output/res.json")
+    # TODO: need to check the json format
+    # res.save_to_json("./output/res.json")

+ 3 - 5
paddlex/configs/pipelines/OCR.yaml

@@ -3,8 +3,8 @@ pipeline_name: OCR
 
 text_type: general
 
-use_doc_preprocessor: False
-use_textline_orientation: False
+use_doc_preprocessor: True
+use_textline_orientation: True
 
 SubPipelines:
   DocPreprocessor:
@@ -16,19 +16,16 @@ SubPipelines:
         module_name: doc_text_orientation
         model_name: PP-LCNet_x1_0_doc_ori
         model_dir: null
-        batch_size: 1
       DocUnwarping:
         module_name: image_unwarping
         model_name: UVDoc
         model_dir: null
-        batch_size: 1
 
 SubModules:
   TextDetection:
     module_name: text_detection
     model_name: PP-OCRv4_mobile_det
     model_dir: null
-    batch_size: 1
     limit_side_len: 960
     limit_type: max
     thresh: 0.3
@@ -46,3 +43,4 @@ SubModules:
     model_name: PP-OCRv4_mobile_rec 
     model_dir: null
     batch_size: 1
+    score_thresh: 0

+ 2 - 2
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -1,6 +1,8 @@
 
 pipeline_name: PP-ChatOCRv3-doc
 
+use_layout_parser: False
+
 SubModules:
   LLM_Chat:
     module_name: chat_bot
@@ -9,7 +11,6 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
-
   LLM_Retriever:
     module_name: retriever
     model_name: ernie-3.5
@@ -17,7 +18,6 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
-
   PromptEngneering:
     KIE_CommonText:
       module_name: prompt_engneering

+ 2 - 0
paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml

@@ -1,6 +1,8 @@
 
 pipeline_name: PP-ChatOCRv4-doc
 
+use_layout_parser: False
+
 use_mllm_predict: True
 
 SubModules:

+ 0 - 2
paddlex/configs/pipelines/doc_preprocessor.yaml

@@ -9,9 +9,7 @@ SubModules:
     module_name: doc_text_orientation
     model_name: PP-LCNet_x1_0_doc_ori
     model_dir: null
-    batch_size: 1
   DocUnwarping:
     module_name: image_unwarping
     model_name: UVDoc
     model_dir: null
-    batch_size: 1

+ 0 - 6
paddlex/inference/pipelines_new/base.py

@@ -93,12 +93,6 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
             hpi_params=self.hpi_params,
             **kwargs,
         )
-
-        # [TODO] Support initializing with additional parameters
-        if "batch_size" in config:
-            batch_size = config["batch_size"]
-            model.set_predictor(batch_size=batch_size)
-
         return model
 
     def create_pipeline(self, config: Dict):

+ 6 - 3
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -212,13 +212,16 @@ class ErnieBotRetriever(BaseRetriever):
             str: A concatenated string of all unique contexts found.
         """
         C = []
+        all_C = ""
         for query_text in query_text_list:
             QUESTION = query_text
             time.sleep(sleep_time)
             docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=topk)
             context = [(document.page_content, score) for document, score in docs]
             context = sorted(context, key=lambda x: x[1])
-            C.extend([x[0] for x in context[::-1]])
-        C = list(set(C))
-        all_C = " ".join(C)
+            for text, score in context[::-1]:
+                if score >= -0.1:
+                    if len(all_C) + len(text) > min_characters:
+                        break
+                    all_C += text
         return all_C

+ 48 - 23
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -31,8 +31,8 @@ class DocPreprocessorPipeline(BasePipeline):
     def __init__(
         self,
         config: Dict,
-        device: str = None,
-        pp_option: PaddlePredictorOption = None,
+        device: Optional[str] = None,
+        pp_option: Optional[PaddlePredictorOption] = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -88,19 +88,19 @@ class DocPreprocessorPipeline(BasePipeline):
         ), "rotate_angle must in [0-360), but get {rotate_angle}."
         return rotate(image_array, rotate_angle, reshape=True)
 
-    def check_input_params_valid(self, input_params: Dict) -> bool:
+    def check_model_settings_valid(self, model_settings: Dict) -> bool:
         """
-        Check if the input parameters are valid based on the initialized models.
+        Check if the the input params for model settings are valid based on the initialized models.
 
         Args:
-            input_params (Dict): A dictionary containing input parameters.
+            model_settings (Dict): A dictionary containing model settings.
 
         Returns:
-            bool: True if all required models are initialized according to input parameters, False otherwise.
+            bool: True if all required models are initialized according to the model settings, False otherwise.
         """
 
         if (
-            input_params["use_doc_orientation_classify"]
+            model_settings["use_doc_orientation_classify"]
             and not self.use_doc_orientation_classify
         ):
             logging.error(
@@ -108,7 +108,7 @@ class DocPreprocessorPipeline(BasePipeline):
             )
             return False
 
-        if input_params["use_doc_unwarping"] and not self.use_doc_unwarping:
+        if model_settings["use_doc_unwarping"] and not self.use_doc_unwarping:
             logging.error(
                 "Set use_doc_unwarping, but the model for doc unwarping is not initialized."
             )
@@ -116,12 +116,34 @@ class DocPreprocessorPipeline(BasePipeline):
 
         return True
 
+    def get_model_settings(
+        self, use_doc_orientation_classify, use_doc_unwarping
+    ) -> dict:
+        """
+        Retrieve the model settings dictionary based on input parameters.
+
+        Args:
+            use_doc_orientation_classify (bool, optional): Whether to use document orientation classification.
+            use_doc_unwarping (bool, optional): Whether to use document unwarping.
+
+        Returns:
+            dict: A dictionary containing the model settings.
+        """
+        if use_doc_orientation_classify is None:
+            use_doc_orientation_classify = self.use_doc_orientation_classify
+        if use_doc_unwarping is None:
+            use_doc_unwarping = self.use_doc_unwarping
+        model_settings = {
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
+        return model_settings
+
     def predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_doc_orientation_classify: bool = True,
-        use_doc_unwarping: bool = False,
-        **kwargs
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
     ) -> DocPreprocessorResult:
         """
         Predict the preprocessing result for the input image or images.
@@ -136,18 +158,22 @@ class DocPreprocessorPipeline(BasePipeline):
             DocPreprocessorResult: A generator yielding preprocessing results.
         """
 
-        input_params = {
-            "use_doc_orientation_classify": use_doc_orientation_classify,
-            "use_doc_unwarping": use_doc_unwarping,
-        }
-
-        if not self.check_input_params_valid(input_params):
-            yield {"error": "input params invalid"}
+        model_settings = self.get_model_settings(
+            use_doc_orientation_classify, use_doc_unwarping
+        )
+        if not self.check_model_settings_valid(model_settings):
+            yield {"error": "the input params for model settings are invalid!"}
 
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            if not isinstance(batch_data[0], str):
+                # TODO: add support input_pth for ndarray and pdf
+                input_path = f"{img_id}"
+            else:
+                input_path = batch_data[0]
+
             image_array = self.img_reader(batch_data)[0]
 
-            if input_params["use_doc_orientation_classify"]:
+            if model_settings["use_doc_orientation_classify"]:
                 pred = next(self.doc_ori_classify_model(image_array))
                 angle = int(pred["label_names"][0])
                 rot_img = self.rotate_image(image_array, angle)
@@ -155,18 +181,17 @@ class DocPreprocessorPipeline(BasePipeline):
                 angle = -1
                 rot_img = image_array
 
-            if input_params["use_doc_unwarping"]:
+            if model_settings["use_doc_unwarping"]:
                 output_img = next(self.doc_unwarping_model(rot_img))["doctr_img"]
             else:
                 output_img = rot_img
 
-            img_id += 1
             single_img_res = {
+                "input_path": input_path,
                 "input_image": image_array,
-                "input_params": input_params,
+                "model_settings": model_settings,
                 "angle": angle,
                 "rot_img": rot_img,
                 "output_img": output_img,
-                "img_id": img_id,
             }
             yield DocPreprocessorResult(single_img_res)

+ 32 - 23
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -16,6 +16,7 @@ from typing import Dict
 import math
 import random
 from pathlib import Path
+import copy
 import numpy as np
 import cv2
 import PIL
@@ -27,33 +28,41 @@ from ...common.result import BaseCVResult
 class DocPreprocessorResult(BaseCVResult):
     """doc preprocessor result"""
 
-    def save_to_img(self, save_path: str, *args, **kwargs) -> None:
-        """
-        Save the image to the specified path.
-
-        Args:
-            save_path (str): The path to save the image.
-                If the path does not end with '.jpg' or '.png', it appends '_res_doc_preprocess_<img_id>.jpg'
-                to the path where <img_id> is retrieved from the object's 'img_id' attribute.
-            *args: Variable length argument list.
-            **kwargs: Arbitrary keyword arguments.
-
-        Returns:
-            None
-        """
-        if not str(save_path).lower().endswith((".jpg", ".png")):
-            img_id = self["img_id"]
-            save_path = Path(save_path) / f"res_doc_preprocess_{img_id}.jpg"
-        super().save_to_img(save_path, *args, **kwargs)
-
     def _to_img(self) -> Dict[str, Image.Image]:
         """
         Generate an image combining the original, rotated, and unwarping images.
 
         Returns:
-            Dict[Image.Image]: A new image that displays the rotated, and unwarping images.
+            Dict[Image.Image]: A new image combining the original, rotated, and unwarping images
         """
-        imgs = {"preprocessed_img": Image.fromarray(self["output_img"][:, :, ::-1])}
-        if self["rot_img"] is not None:
-            imgs["rotated_img"] = Image.fromarray(self["rot_img"][:, :, ::-1])
+        image = self["input_image"][:, :, ::-1]
+        rot_img = self["rot_img"][:, :, ::-1]
+        angle = self["angle"]
+        output_img = self["output_img"][:, :, ::-1]
+        use_doc_orientation_classify = self["model_settings"][
+            "use_doc_orientation_classify"
+        ]
+        use_doc_unwarping = self["model_settings"]["use_doc_unwarping"]
+        h1, w1 = image.shape[0:2]
+        h2, w2 = rot_img.shape[0:2]
+        h3, w3 = output_img.shape[0:2]
+        h = max(max(h1, h2), h3)
+        img_show = Image.new("RGB", (w1 + w2 + w3, h + 25), (255, 255, 255))
+        img_show.paste(Image.fromarray(image), (0, 0, w1, h1))
+        img_show.paste(Image.fromarray(rot_img), (w1, 0, w1 + w2, h2))
+        img_show.paste(Image.fromarray(output_img), (w1 + w2, 0, w1 + w2 + w3, h3))
+
+        draw_text = ImageDraw.Draw(img_show)
+        txt_list = ["Original Image", "Rotated Image", "Unwarping Image"]
+        txt_list[1] = f"Rotated Image ({use_doc_orientation_classify}, {angle})"
+        txt_list[2] = f"Unwarping Image ({use_doc_unwarping})"
+        region_w_list = [w1, w2, w3]
+        beg_w_list = [0, w1, w1 + w2]
+        for tno in range(len(txt_list)):
+            txt = txt_list[tno]
+            font = create_font(txt, (region_w_list[tno], 20), PINGFANG_FONT_FILE_PATH)
+            draw_text.text(
+                [10 + beg_w_list[tno], h + 2], txt, fill=(0, 0, 0), font=font
+            )
+        imgs = {"preprocessed_img": img_show}
         return imgs

+ 3 - 1
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -89,7 +89,9 @@ def get_sub_regions_ocr_res(
         OCRResult: A filtered OCR result containing only the relevant text boxes.
     """
     sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
-    sub_regions_ocr_res["input_img"] = overall_ocr_res["input_img"]
+    sub_regions_ocr_res["doc_preprocessor_image"] = overall_ocr_res[
+        "doc_preprocessor_image"
+    ]
     sub_regions_ocr_res["img_id"] = -1
     sub_regions_ocr_res["dt_polys"] = []
     sub_regions_ocr_res["rec_text"] = []

+ 194 - 117
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -34,16 +34,6 @@ class OCRPipeline(BasePipeline):
         self,
         config: Dict,
         device: Optional[str] = None,
-        use_doc_orientation_classify: Optional[bool] = None,
-        use_doc_unwarping: Optional[bool] = None,
-        use_textline_orientation: Optional[bool] = None,
-        limit_side_len: Optional[int] = None,
-        limit_type: Optional[str] = None,
-        thresh: Optional[float] = None,
-        box_thresh: Optional[float] = None,
-        max_candidates: Optional[int] = None,
-        unclip_ratio: Optional[float] = None,
-        use_dilation: Optional[bool] = None,
         pp_option: Optional[PaddlePredictorOption] = None,
         use_hpip: bool = False,
         hpi_params: Optional[Dict[str, Any]] = None,
@@ -52,69 +42,63 @@ class OCRPipeline(BasePipeline):
         Initializes the class with given configurations and options.
 
         Args:
-            config (Dict): Configuration dictionary containing model and other parameters.
-            device (Union[str, None]): The device to run the prediction on.
-            use_textline_orientation (Union[bool, None]): Whether to use textline orientation.
-            use_doc_orientation_classify (Union[bool, None]): Whether to use document orientation classification.
-            use_doc_unwarping (Union[bool, None]): Whether to use document unwarping.
-            limit_side_len (Union[int, None]): Limit of side length.
-            limit_type (Union[str, None]): Type of limit.
-            thresh (Union[float, None]): Threshold value.
-            box_thresh (Union[float, None]): Box threshold value.
-            max_candidates (Union[int, None]): Maximum number of candidates.
-            unclip_ratio (Union[float, None]): Unclip ratio.
-            use_dilation (Union[bool, None]): Whether to use dilation.
-            pp_option (Union[PaddlePredictorOption, None]): Options for PaddlePaddle predictor.
-            use_hpip (Union[bool, None]): Whether to use high-performance inference.
-            hpi_params (Union[Dict[str, Any], None]): HPIP specific parameters.
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
         """
         super().__init__(
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        self.use_textline_orientation = (
-            use_textline_orientation
-            if use_textline_orientation is not None
-            else config.get("use_textline_orientation", False)
-        )
-        self.use_doc_preprocessor = self.get_preprocessor_value(
-            use_doc_orientation_classify, use_doc_unwarping, config, False
-        )
-
-        text_det_default_params = {
-            "limit_side_len": 960,
-            "limit_type": "max",
-            "thresh": 0.3,
-            "box_thresh": 0.6,
-            "max_candidates": 1000,
-            "unclip_ratio": 2.0,
-            "use_dilation": False,
-        }
-
-        text_det_config = config["SubModules"]["TextDetection"]
-        for key, default_params in text_det_default_params.items():
-            text_det_config[key] = locals().get(
-                key, text_det_config.get(key, default_params)
+        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
             )
-        self.text_det_model = self.create_model(text_det_config)
-
-        text_rec_config = config["SubModules"]["TextRecognition"]
-        self.text_rec_model = self.create_model(text_rec_config)
 
+        self.use_textline_orientation = config.get("use_textline_orientation", True)
         if self.use_textline_orientation:
             textline_orientation_config = config["SubModules"]["TextLineOrientation"]
+            # TODO: add batch_size
+            # batch_size = textline_orientation_config.get("batch_size", 1)
+            # self.textline_orientation_model = self.create_model(
+            #     textline_orientation_config, batch_size=batch_size
+            # )
             self.textline_orientation_model = self.create_model(
                 textline_orientation_config
             )
 
-        if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
-            self.doc_preprocessor_pipeline = self.create_pipeline(
-                doc_preprocessor_config
-            )
+        text_det_config = config["SubModules"]["TextDetection"]
+        self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
+        self.text_det_limit_type = text_det_config.get("limit_type", "max")
+        self.text_det_thresh = text_det_config.get("thresh", 0.3)
+        self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
+        self.text_det_max_candidates = text_det_config.get("max_candidates", 1000)
+        self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
+        self.text_det_use_dilation = text_det_config.get("use_dilation", False)
+        self.text_det_model = self.create_model(
+            text_det_config,
+            limit_side_len=self.text_det_limit_side_len,
+            limit_type=self.text_det_limit_type,
+            thresh=self.text_det_thresh,
+            box_thresh=self.text_det_box_thresh,
+            max_candidates=self.text_det_max_candidates,
+            unclip_ratio=self.text_det_unclip_ratio,
+            use_dilation=self.text_det_use_dilation,
+        )
 
-        self.text_type = config["text_type"]
+        text_rec_config = config["SubModules"]["TextRecognition"]
+        # TODO: add batch_size
+        # batch_size = text_rec_config.get("batch_size", 1)
+        # self.text_rec_model = self.create_model(text_rec_config,
+        #     batch_size=batch_size)
+        self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
+        self.text_rec_model = self.create_model(text_rec_config)
 
+        self.text_type = config["text_type"]
         if self.text_type == "general":
             self._sort_boxes = SortQuadBoxes()
             self._crop_by_polys = CropByPolys(det_box_type="quad")
@@ -127,16 +111,6 @@ class OCRPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
-    @staticmethod
-    def get_preprocessor_value(orientation, unwarping, config, default):
-        if orientation is None and unwarping is None:
-            return config.get("use_doc_preprocessor", default)
-        else:
-            if orientation is False and unwarping is False:
-                return False
-            else:
-                return True
-
     def rotate_image(
         self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
     ) -> List[np.ndarray]:
@@ -202,22 +176,24 @@ class OCRPipeline(BasePipeline):
         return True
 
     def predict_doc_preprocessor_res(
-        self, image_array: np.ndarray, input_params: dict
+        self, image_array: np.ndarray, model_settings: dict
     ) -> tuple[DocPreprocessorResult, np.ndarray]:
         """
         Preprocess the document image based on input parameters.
 
         Args:
             image_array (np.ndarray): The input image array.
-            input_params (dict): Dictionary containing preprocessing parameters.
+            model_settings (dict): Dictionary containing preprocessing parameters.
 
         Returns:
             tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
                                               result dictionary and the processed image array.
         """
-        if input_params["use_doc_preprocessor"]:
-            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
-            use_doc_unwarping = input_params["use_doc_unwarping"]
+        if model_settings["use_doc_preprocessor"]:
+            use_doc_orientation_classify = model_settings[
+                "use_doc_orientation_classify"
+            ]
+            use_doc_unwarping = model_settings["use_doc_unwarping"]
             doc_preprocessor_res = next(
                 self.doc_preprocessor_pipeline(
                     image_array,
@@ -229,61 +205,161 @@ class OCRPipeline(BasePipeline):
             doc_preprocessor_res = {"output_img": image_array}
         return doc_preprocessor_res
 
+    def get_model_settings(
+        self,
+        use_doc_orientation_classify: Optional[bool],
+        use_doc_unwarping: Optional[bool],
+        use_textline_orientation: Optional[bool],
+    ) -> dict:
+        """
+        Get the model settings based on the provided parameters or default values.
+
+        Args:
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_textline_orientation (Optional[bool]): Whether to use textline orientation.
+
+        Returns:
+            dict: A dictionary containing the model settings.
+        """
+        if use_doc_orientation_classify is None:
+            use_doc_orientation_classify = self.use_doc_orientation_classify
+        if use_doc_unwarping is None:
+            use_doc_unwarping = self.use_doc_unwarping
+        if use_textline_orientation is None:
+            use_textline_orientation = self.use_textline_orientation
+        return dict(
+            use_doc_orientation_classify=use_doc_orientation_classify,
+            use_doc_unwarping=use_doc_unwarping,
+            use_textline_orientation=use_textline_orientation,
+        )
+
+    def get_text_det_params(
+        self,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_max_candidates: Optional[int] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_det_use_dilation: Optional[bool] = None,
+    ) -> dict:
+        """
+        Get text detection parameters.
+
+        If a parameter is None, its default value from the instance will be used.
+
+        Args:
+            text_det_limit_side_len (Optional[int]): The maximum side length of the text box.
+            text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
+            text_det_thresh (Optional[float]): The threshold for text detection.
+            text_det_box_thresh (Optional[float]): The threshold for the bounding box.
+            text_det_max_candidates (Optional[int]): The maximum number of candidate text boxes.
+            text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
+            text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
+
+        Returns:
+            dict: A dictionary containing the text detection parameters.
+        """
+        if text_det_limit_side_len is None:
+            text_det_limit_side_len = self.text_det_limit_side_len
+        if text_det_limit_type is None:
+            text_det_limit_type = self.text_det_limit_type
+        if text_det_thresh is None:
+            text_det_thresh = self.text_det_thresh
+        if text_det_box_thresh is None:
+            text_det_box_thresh = self.text_det_box_thresh
+        if text_det_max_candidates is None:
+            text_det_max_candidates = self.text_det_max_candidates
+        if text_det_unclip_ratio is None:
+            text_det_unclip_ratio = self.text_det_unclip_ratio
+        if text_det_use_dilation is None:
+            text_det_use_dilation = self.text_det_use_dilation
+        return dict(
+            limit_side_len=text_det_limit_side_len,
+            limit_type=text_det_limit_type,
+            thresh=text_det_thresh,
+            box_thresh=text_det_box_thresh,
+            max_candidates=text_det_max_candidates,
+            unclip_ratio=text_det_unclip_ratio,
+            use_dilation=text_det_use_dilation,
+        )
+
     def predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_doc_orientation_classify: bool = False,
-        use_doc_unwarping: bool = False,
-        use_textline_orientation: bool = False,
-        limit_side_len: int = 960,
-        limit_type: str = "max",
-        thresh: float = 0.3,
-        box_thresh: float = 0.6,
-        max_candidates: int = 1000,
-        unclip_ratio: float = 2.0,
-        use_dilation: bool = False,
-        **kwargs,
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_textline_orientation: Optional[bool] = None,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_max_candidates: Optional[int] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_det_use_dilation: Optional[bool] = None,
+        text_rec_score_thresh: Optional[float] = None,
     ) -> OCRResult:
-        """Predicts OCR results for the given input.
+        """
+        Predict OCR results based on input images or arrays with optional preprocessing steps.
 
         Args:
-            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images or pdf(s).
-            **kwargs: Additional keyword arguments that can be passed to the function.
-
+            input (str | list[str] | np.ndarray | list[np.ndarray]): Input image of pdf path(s) or numpy array(s).
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
+            text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
+            text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
+            text_det_thresh (Optional[float]): Threshold for text detection.
+            text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
+            text_det_max_candidates (Optional[int]): Maximum number of text detection candidates.
+            text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
+            text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
+            text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
         Returns:
-            OCRResult: An iterable of OCRResult objects, each containing the predicted text and other relevant information.
+            OCRResult: Generator yielding OCR results for each input image.
         """
 
-        model_settings = {
-            "use_doc_orientation_classify": use_doc_orientation_classify,
-            "use_doc_unwarping": use_doc_unwarping,
-            "use_textline_orientation": use_textline_orientation,
-        }
-        if use_doc_orientation_classify or use_doc_unwarping:
+        model_settings = self.get_model_settings(
+            use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
+        )
+        if (
+            model_settings["use_doc_orientation_classify"]
+            or model_settings["use_doc_unwarping"]
+        ):
             model_settings["use_doc_preprocessor"] = True
         else:
             model_settings["use_doc_preprocessor"] = False
 
         if not self.check_model_settings_valid(model_settings):
-            yield None
-
-        text_det_params = {
-            "limit_side_len": limit_side_len,
-            "limit_type": limit_type,
-            "thresh": thresh,
-            "box_thresh": box_thresh,
-            "max_candidates": max_candidates,
-            "unclip_ratio": unclip_ratio,
-            "use_dilation": use_dilation,
-        }
+            yield {"error": "the input params for model settings are invalid!"}
+
+        text_det_params = self.get_text_det_params(
+            text_det_limit_side_len,
+            text_det_limit_type,
+            text_det_thresh,
+            text_det_box_thresh,
+            text_det_max_candidates,
+            text_det_unclip_ratio,
+            text_det_use_dilation,
+        )
+
+        if text_rec_score_thresh is None:
+            text_rec_score_thresh = self.text_rec_score_thresh
 
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            if not isinstance(batch_data[0], str):
+                # TODO: add support input_pth for ndarray and pdf
+                input_path = f"{img_id}"
+            else:
+                input_path = batch_data[0]
+
             image_array = self.img_reader(batch_data)[0]
-            img_id += 1
 
             doc_preprocessor_res = self.predict_doc_preprocessor_res(
                 image_array, model_settings
             )
+
             doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
             det_res = next(
@@ -296,19 +372,17 @@ class OCRPipeline(BasePipeline):
             dt_polys = self._sort_boxes(dt_polys)
 
             single_img_res = {
-                "input_path": input,
-                # TODO: `doc_preprocessor_image` parameter does not need to be retained here, it requires further confirmation.
-                "doc_preprocessor_image": doc_preprocessor_image,
+                "input_path": batch_data[0],
                 "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,
-                "img_id": img_id,
-                "input_params": model_settings,
+                "model_settings": model_settings,
                 "text_det_params": text_det_params,
                 "text_type": self.text_type,
             }
 
-            single_img_res["rec_text"] = []
-            single_img_res["rec_score"] = []
+            single_img_res["rec_texts"] = []
+            single_img_res["rec_scores"] = []
+            single_img_res["rec_boxes"] = []
             if len(dt_polys) > 0:
                 all_subs_of_img = list(
                     self._crop_by_polys(doc_preprocessor_image, dt_polys)
@@ -324,8 +398,11 @@ class OCRPipeline(BasePipeline):
                     single_img_res["textline_orientation_angle"] = angles
                     all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
 
+                rno = -1
                 for rec_res in self.text_rec_model(all_subs_of_img):
-                    single_img_res["rec_text"].append(rec_res["rec_text"])
-                    single_img_res["rec_score"].append(rec_res["rec_score"])
-
+                    rno += 1
+                    if rec_res["rec_score"] >= text_rec_score_thresh:
+                        single_img_res["rec_texts"].append(rec_res["rec_text"])
+                        single_img_res["rec_scores"].append(rec_res["rec_score"])
+                        single_img_res["rec_boxes"].append(dt_polys[rno])
             yield OCRResult(single_img_res)

+ 11 - 71
paddlex/inference/pipelines_new/ocr/result.py

@@ -28,58 +28,6 @@ from ...common.result import BaseCVResult
 class OCRResult(BaseCVResult):
     """OCR result"""
 
-    def save_to_json(
-        self,
-        save_path: str,
-        indent: int = 4,
-        ensure_ascii: bool = False,
-        save_ndarray: bool = False,
-        *args,
-        **kwargs,
-    ) -> None:
-        """Save the JSON representation of the object to a file.
-
-        Args:
-            save_path (str): The path to save the JSON file. If the save path does not end with '.json', it appends the base name and suffix of the input path.
-            indent (int): The number of spaces to indent for pretty printing. Default is 4.
-            ensure_ascii (bool): If False, non-ASCII characters will be included in the output. Default is False.
-            save_ndarray (bool): If True, save the numpy arrays in the result. Default is False.
-            *args: Additional positional arguments to pass to the underlying writer.
-            **kwargs: Additional keyword arguments to pass to the underlying writer.
-        """
-        img_id = self["img_id"]
-
-        # TODO : Support determining the output name based on the input name.
-        os.makedirs(save_path, exist_ok=True)
-        save_path = os.path.join(save_path, "res.json")
-
-        base_name, ext = os.path.splitext(save_path)
-        save_path = f"{base_name}_{img_id}{ext}"
-
-        def remove_ndarray(d):
-            """
-            Remove all keys from the dictionary whose values are numpy arrays.
-            """
-            keys_to_delete = []
-            for key, value in d.items():
-                if isinstance(value, dict):
-                    remove_ndarray(value)
-                    if all(isinstance(v, np.ndarray) for v in value.values()):
-                        keys_to_delete.append(key)
-                elif isinstance(value, np.ndarray):
-                    keys_to_delete.append(key)
-            for key in keys_to_delete:
-                del d[key]
-
-        if not save_ndarray:
-            self_copy = copy.deepcopy(self)
-            remove_ndarray(self_copy)
-            super(type(self_copy), self_copy).save_to_json(
-                save_path, indent, ensure_ascii, *args, **kwargs
-            )
-        else:
-            super().save_to_json(save_path, indent, ensure_ascii, *args, **kwargs)
-
     def get_minarea_rect(self, points: np.ndarray) -> np.ndarray:
         """
         Get the minimum area rectangle for the given points using OpenCV.
@@ -121,26 +69,17 @@ class OCRResult(BaseCVResult):
         Returns:
             PIL.Image: An image with detection boxes, texts, and scores blended on it.
         """
-
-        # TODO(gaotingquan): mv to postprocess
-        drop_score = 0.5
-
-        boxes = self["dt_polys"]
-        txts = self["rec_text"]
-        scores = self["rec_score"]
-        image = self["doc_preprocessor_image"]
+        boxes = self["rec_boxes"]
+        txts = self["rec_texts"]
+        image = self["doc_preprocessor_res"]["output_img"]
         h, w = image.shape[0:2]
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         img_left = Image.fromarray(image_rgb)
         img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
         random.seed(0)
         draw_left = ImageDraw.Draw(img_left)
-        if txts is None or len(txts) != len(boxes):
-            txts = [None] * len(boxes)
         for idx, (box, txt) in enumerate(zip(boxes, txts)):
             try:
-                if scores is not None and scores[idx] < drop_score:
-                    continue
                 color = (
                     random.randint(0, 255),
                     random.randint(0, 255),
@@ -169,13 +108,14 @@ class OCRResult(BaseCVResult):
         img_show.paste(img_left, (0, 0, w, h))
         img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
 
-        input_params = self["input_params"]
-        img_id = self["img_id"]
-
-        return {
-            **self["doc_preprocessor_res"].img,
-            f"res_ocr_{img_id}": img_show,
-        }
+        model_settings = self["model_settings"]
+        if model_settings["use_doc_preprocessor"]:
+            return {
+                **self["doc_preprocessor_res"].img,
+                f"ocr_res_img": img_show,
+            }
+        else:
+            return {f"ocr_res_img": img_show}
 
 
 # Adds a function comment according to Google Style Guide

+ 11 - 2
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py

@@ -74,8 +74,13 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             None
         """
 
-        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
-        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+        self.use_layout_parser = True
+        if "use_layout_parser" in config:
+            self.use_layout_parser = config["use_layout_parser"]
+
+        if self.use_layout_parser:
+            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+            self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
         from .. import create_chat_bot
 
@@ -171,6 +176,10 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             dict: A dictionary containing the layout parsing result and visual information.
         """
 
+        if self.use_layout_parser == False:
+            logging.error("The models for layout parser are not initialized.")
+            yield None
+
         for layout_parsing_result in self.layout_parsing_pipeline.predict(
             input,
             use_doc_orientation_classify=use_doc_orientation_classify,

+ 11 - 3
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py

@@ -76,8 +76,13 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             None
         """
 
-        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
-        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
+        self.use_layout_parser = True
+        if "use_layout_parser" in config:
+            self.use_layout_parser = config["use_layout_parser"]
+
+        if self.use_layout_parser:
+            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
+            self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
 
         from .. import create_chat_bot
 
@@ -184,6 +189,9 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         Returns:
             dict: A dictionary containing the layout parsing result and visual information.
         """
+        if self.use_layout_parser == False:
+            logging.error("The models for layout parser are not initialized.")
+            yield None
 
         for layout_parsing_result in self.layout_parsing_pipeline.predict(
             input,
@@ -484,7 +492,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             vector = vector_info["vector"]
             if not vector_info["flag_too_short_text"]:
                 related_text = self.retriever.similarity_retrieval(
-                    question_key_list, vector, topk=5, min_characters=min_characters
+                    question_key_list, vector, topk=50, min_characters=min_characters
                 )
             else:
                 if len(vector) > 0:

+ 1 - 1
paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing.py

@@ -225,7 +225,7 @@ def get_table_recognition_res(
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
     crop_start_point = [table_box[0][0], table_box[0][1]]
-    img_shape = overall_ocr_res["input_img"].shape[0:2]
+    img_shape = overall_ocr_res["doc_preprocessor_image"].shape[0:2]
 
     convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
 

+ 2 - 2
requirements.txt

@@ -31,8 +31,8 @@ importlib_resources>=6.4
 qianfan==0.0.3
 langchain==0.1.5
 langchain-community==0.0.17
-erniebot == 0.5.0
-erniebot-agent == 0.5.0
+erniebot == 0.5.9
+erniebot-agent == 0.5.2
 unstructured
 networkx
 faiss-cpu