Browse Source

fix input params and save_to_img for layout parsing (#2858)

dyning 10 tháng trước cách đây
mục cha
commit
4f16aa67e1

+ 5 - 1
api_examples/pipelines/test_layout_parsing.py

@@ -54,4 +54,8 @@ output = pipeline.predict(
 
 for res in output:
     print(res)
-    res.save_results("./output")
+    res.print()
+    res.save_to_img("./output")
+    res.save_to_json("./output")
+    res.save_to_xlsx("./output")
+    res.save_to_html("./output")

+ 17 - 6
paddlex/configs/pipelines/layout_parsing.yaml

@@ -11,7 +11,6 @@ SubModules:
     module_name: layout_detection
     model_name: RT-DETR-H_layout_3cls
     model_dir: null
-    batch_size: 1
 
 SubPipelines:
   DocPreprocessor:
@@ -23,27 +22,33 @@ SubPipelines:
         module_name: doc_text_orientation
         model_name: PP-LCNet_x1_0_doc_ori
         model_dir: null
-        batch_size: 1
       DocUnwarping:
         module_name: image_unwarping
         model_name: UVDoc
         model_dir: null
-        batch_size: 1
 
   GeneralOCR:
     pipeline_name: OCR
     text_type: general
+    use_doc_preprocessor: False
+    use_textline_orientation: False
     SubModules:
       TextDetection:
         module_name: text_detection
         model_name: PP-OCRv4_server_det
         model_dir: null
-        batch_size: 1    
+        limit_side_len: 960
+        limit_type: max
+        thresh: 0.3
+        box_thresh: 0.6
+        unclip_ratio: 2.0
+        
       TextRecognition:
         module_name: text_recognition
         model_name: PP-OCRv4_server_rec
         model_dir: null
         batch_size: 1
+        score_thresh: 0
 
   TableRecognition:
     pipeline_name: table_recognition
@@ -55,7 +60,6 @@ SubPipelines:
         module_name: table_structure_recognition
         model_name: SLANet_plus
         model_dir: null
-        batch_size: 1
 
   SealRecognition:
     pipeline_name: seal_recognition
@@ -65,14 +69,21 @@ SubPipelines:
       SealOCR:
         pipeline_name: OCR
         text_type: seal
+        use_doc_preprocessor: False
+        use_textline_orientation: False
         SubModules:
           TextDetection:
             module_name: seal_text_detection
             model_name: PP-OCRv4_server_seal_det
             model_dir: null
-            batch_size: 1    
+            limit_side_len: 736
+            limit_type: min
+            thresh: 0.2
+            box_thresh: 0.6
+            unclip_ratio: 0.5
           TextRecognition:
             module_name: text_recognition
             model_name: PP-OCRv4_server_rec
             model_dir: null
             batch_size: 1
+            score_thresh: 0

+ 146 - 113
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -64,37 +64,6 @@ class LayoutParsingPipeline(BasePipeline):
 
         self.img_reader = ReadImage(format="BGR")
 
-    def set_used_models_flag(self, config: Dict) -> None:
-        """
-        Set the flags for which models to use based on the configuration.
-
-        Args:
-            config (Dict): A dictionary containing configuration settings.
-
-        Returns:
-            None
-        """
-        pipeline_name = config["pipeline_name"]
-
-        self.pipeline_name = pipeline_name
-
-        self.use_doc_preprocessor = False
-        self.use_general_ocr = False
-        self.use_seal_recognition = False
-        self.use_table_recognition = False
-
-        if "use_doc_preprocessor" in config:
-            self.use_doc_preprocessor = config["use_doc_preprocessor"]
-
-        if "use_general_ocr" in config:
-            self.use_general_ocr = config["use_general_ocr"]
-
-        if "use_seal_recognition" in config:
-            self.use_seal_recognition = config["use_seal_recognition"]
-
-        if "use_table_recognition" in config:
-            self.use_table_recognition = config["use_table_recognition"]
-
     def inintial_predictor(self, config: Dict) -> None:
         """Initializes the predictor based on the provided configuration.
 
@@ -105,29 +74,53 @@ class LayoutParsingPipeline(BasePipeline):
             None
         """
 
-        self.set_used_models_flag(config)
-
-        layout_det_config = config["SubModules"]["LayoutDetection"]
-        self.layout_det_model = self.create_model(layout_det_config)
+        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
+        self.use_general_ocr = config.get("use_general_ocr", True)
+        self.use_table_recognition = config.get("use_table_recognition", True)
+        self.use_seal_recognition = config.get("use_seal_recognition", True)
 
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            doc_preprocessor_config = config.get("SubPipelines", {}).get(
+                "DocPreprocessor",
+                {
+                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
+                },
+            )
             self.doc_preprocessor_pipeline = self.create_pipeline(
                 doc_preprocessor_config
             )
 
+        layout_det_config = config.get("SubModules", {}).get(
+            "LayoutDetection",
+            {"model_config_error": "config error for layout_det_model!"},
+        )
+        self.layout_det_model = self.create_model(layout_det_config)
+
         if self.use_general_ocr or self.use_table_recognition:
-            general_ocr_config = config["SubPipelines"]["GeneralOCR"]
+            general_ocr_config = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
+            )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
 
         if self.use_seal_recognition:
-            seal_recognition_config = config["SubPipelines"]["SealRecognition"]
+            seal_recognition_config = config.get("SubPipelines", {}).get(
+                "SealRecognition",
+                {
+                    "pipeline_config_error": "config error for seal_recognition_pipeline!"
+                },
+            )
             self.seal_recognition_pipeline = self.create_pipeline(
                 seal_recognition_config
             )
 
         if self.use_table_recognition:
-            table_recognition_config = config["SubPipelines"]["TableRecognition"]
+            table_recognition_config = config.get("SubPipelines", {}).get(
+                "TableRecognition",
+                {
+                    "pipeline_config_error": "config error for table_recognition_pipeline!"
+                },
+            )
             self.table_recognition_pipeline = self.create_pipeline(
                 table_recognition_config
             )
@@ -154,7 +147,7 @@ class LayoutParsingPipeline(BasePipeline):
         object_boxes = np.array(object_boxes)
         return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
 
-    def check_input_params_valid(self, input_params: Dict) -> bool:
+    def check_model_settings_valid(self, input_params: Dict) -> bool:
         """
         Check if the input parameters are valid based on the initialized models.
 
@@ -191,60 +184,72 @@ class LayoutParsingPipeline(BasePipeline):
 
         return True
 
-    def predict_doc_preprocessor_res(
-        self, image_array: np.ndarray, input_params: dict
-    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+    def get_model_settings(
+        self,
+        use_doc_orientation_classify: Optional[bool],
+        use_doc_unwarping: Optional[bool],
+        use_general_ocr: Optional[bool],
+        use_seal_recognition: Optional[bool],
+        use_table_recognition: Optional[bool],
+    ) -> dict:
         """
-        Preprocess the document image based on input parameters.
+        Get the model settings based on the provided parameters or default values.
 
         Args:
-            image_array (np.ndarray): The input image array.
-            input_params (dict): Dictionary containing preprocessing parameters.
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_general_ocr (Optional[bool]): Whether to use general OCR.
+            use_seal_recognition (Optional[bool]): Whether to use seal recognition.
+            use_table_recognition (Optional[bool]): Whether to use table recognition.
 
         Returns:
-            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
-                                              result dictionary and the processed image array.
+            dict: A dictionary containing the model settings.
         """
-        if input_params["use_doc_preprocessor"]:
-            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
-            use_doc_unwarping = input_params["use_doc_unwarping"]
-            doc_preprocessor_res = next(
-                self.doc_preprocessor_pipeline(
-                    image_array,
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping,
-                )
-            )
-            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        if use_doc_orientation_classify is None and use_doc_unwarping is None:
+            use_doc_preprocessor = self.use_doc_preprocessor
         else:
-            doc_preprocessor_res = {}
-            doc_preprocessor_image = image_array
-        return doc_preprocessor_res, doc_preprocessor_image
+            if use_doc_orientation_classify is True or use_doc_unwarping is True:
+                use_doc_preprocessor = True
+            else:
+                use_doc_preprocessor = False
 
-    def predict_overall_ocr_res(self, image_array: np.ndarray) -> OCRResult:
-        """
-        Predict the overall OCR result for the given image array.
+        if use_general_ocr is None:
+            use_general_ocr = self.use_general_ocr
 
-        Args:
-            image_array (np.ndarray): The input image array to perform OCR on.
+        if use_seal_recognition is None:
+            use_seal_recognition = self.use_seal_recognition
 
-        Returns:
-            OCRResult: The predicted OCR result with updated dt_boxes.
-        """
-        overall_ocr_res = next(self.general_ocr_pipeline(image_array))
-        dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
-        overall_ocr_res["dt_boxes"] = dt_boxes
-        return overall_ocr_res
+        if use_table_recognition is None:
+            use_table_recognition = self.use_table_recognition
+
+        return dict(
+            use_doc_preprocessor=use_doc_preprocessor,
+            use_general_ocr=use_general_ocr,
+            use_seal_recognition=use_seal_recognition,
+            use_table_recognition=use_table_recognition,
+        )
 
     def predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_doc_orientation_classify: bool = False,
-        use_doc_unwarping: bool = False,
-        use_general_ocr: bool = True,
-        use_seal_recognition: bool = True,
-        use_table_recognition: bool = True,
-        **kwargs
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_general_ocr: Optional[bool] = None,
+        use_seal_recognition: Optional[bool] = None,
+        use_table_recognition: Optional[bool] = None,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_rec_score_thresh: Optional[float] = None,
+        seal_det_limit_side_len: Optional[int] = None,
+        seal_det_limit_type: Optional[str] = None,
+        seal_det_thresh: Optional[float] = None,
+        seal_det_box_thresh: Optional[float] = None,
+        seal_det_unclip_ratio: Optional[float] = None,
+        seal_rec_score_thresh: Optional[float] = None,
+        **kwargs,
     ) -> LayoutParsingResult:
         """
         This function predicts the layout parsing result for the given input.
@@ -262,82 +267,110 @@ class LayoutParsingPipeline(BasePipeline):
             LayoutParsingResult: The predicted layout parsing result.
         """
 
-        input_params = {
-            "use_doc_preprocessor": self.use_doc_preprocessor,
-            "use_doc_orientation_classify": use_doc_orientation_classify,
-            "use_doc_unwarping": use_doc_unwarping,
-            "use_general_ocr": use_general_ocr,
-            "use_seal_recognition": use_seal_recognition,
-            "use_table_recognition": use_table_recognition,
-        }
-
-        if use_doc_orientation_classify or use_doc_unwarping:
-            input_params["use_doc_preprocessor"] = True
-        else:
-            input_params["use_doc_preprocessor"] = False
+        model_settings = self.get_model_settings(
+            use_doc_orientation_classify,
+            use_doc_unwarping,
+            use_general_ocr,
+            use_seal_recognition,
+            use_table_recognition,
+        )
 
-        if not self.check_input_params_valid(input_params):
-            yield None
+        if not self.check_model_settings_valid(model_settings):
+            yield {"error": "the input params for model settings are invalid!"}
 
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            if not isinstance(batch_data[0], str):
+                # TODO: add support input_pth for ndarray and pdf
+                input_path = f"{img_id}"
+            else:
+                input_path = batch_data[0]
+
             image_array = self.img_reader(batch_data)[0]
-            img_id += 1
 
-            doc_preprocessor_res, doc_preprocessor_image = (
-                self.predict_doc_preprocessor_res(image_array, input_params)
-            )
+            if model_settings["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+            else:
+                doc_preprocessor_res = {"output_img": image_array}
+
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
             layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
 
-            if input_params["use_general_ocr"] or input_params["use_table_recognition"]:
-                overall_ocr_res = self.predict_overall_ocr_res(doc_preprocessor_image)
+            if (
+                model_settings["use_general_ocr"]
+                or model_settings["use_table_recognition"]
+            ):
+                overall_ocr_res = next(
+                    self.general_ocr_pipeline(
+                        doc_preprocessor_image,
+                        text_det_limit_side_len=text_det_limit_side_len,
+                        text_det_limit_type=text_det_limit_type,
+                        text_det_thresh=text_det_thresh,
+                        text_det_box_thresh=text_det_box_thresh,
+                        text_det_unclip_ratio=text_det_unclip_ratio,
+                        text_rec_score_thresh=text_rec_score_thresh,
+                    )
+                )
             else:
                 overall_ocr_res = {}
 
-            if input_params["use_general_ocr"]:
+            if model_settings["use_general_ocr"]:
                 text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
                     overall_ocr_res, layout_det_res
                 )
             else:
                 text_paragraphs_ocr_res = {}
 
-            if input_params["use_table_recognition"]:
-                table_res_list = next(
+            if model_settings["use_table_recognition"]:
+                table_res_all = next(
                     self.table_recognition_pipeline(
                         doc_preprocessor_image,
-                        use_layout_detection=False,
                         use_doc_orientation_classify=False,
                         use_doc_unwarping=False,
+                        use_layout_detection=False,
+                        use_ocr_model=False,
                         overall_ocr_res=overall_ocr_res,
                         layout_det_res=layout_det_res,
                     )
                 )
-                table_res_list = table_res_list["table_res_list"]
+                table_res_list = table_res_all["table_res_list"]
             else:
                 table_res_list = []
 
-            if input_params["use_seal_recognition"]:
-                seal_res_list = next(
+            if model_settings["use_seal_recognition"]:
+                seal_res_all = next(
                     self.seal_recognition_pipeline(
                         doc_preprocessor_image,
-                        use_layout_detection=False,
                         use_doc_orientation_classify=False,
                         use_doc_unwarping=False,
+                        use_layout_detection=False,
                         layout_det_res=layout_det_res,
+                        seal_det_limit_side_len=seal_det_limit_side_len,
+                        seal_det_limit_type=seal_det_limit_type,
+                        seal_det_thresh=seal_det_thresh,
+                        seal_det_box_thresh=seal_det_box_thresh,
+                        seal_det_unclip_ratio=seal_det_unclip_ratio,
+                        seal_rec_score_thresh=seal_rec_score_thresh,
                     )
                 )
-                seal_res_list = seal_res_list["seal_res_list"]
+                seal_res_list = seal_res_all["seal_res_list"]
             else:
                 seal_res_list = []
 
             single_img_res = {
-                "layout_det_res": layout_det_res,
+                "input_path": input_path,
                 "doc_preprocessor_res": doc_preprocessor_res,
+                "layout_det_res": layout_det_res,
                 "overall_ocr_res": overall_ocr_res,
                 "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
                 "table_res_list": table_res_list,
                 "seal_res_list": seal_res_list,
-                "input_params": input_params,
-                "img_id": img_id,
+                "model_settings": model_settings,
             }
             yield LayoutParsingResult(single_img_res)

+ 149 - 51
paddlex/inference/pipelines_new/layout_parsing/result.py

@@ -13,72 +13,170 @@
 # limitations under the License.
 
 import os
-from pathlib import Path
+from typing import Dict
+import numpy as np
+import copy
+import cv2
+from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, StrMixin, JsonMixin
 
 
-class LayoutParsingResult(dict):
+class LayoutParsingResult(BaseCVResult, HtmlMixin, XlsxMixin):
     """Layout Parsing Result"""
 
     def __init__(self, data) -> None:
         """Initializes a new instance of the class with the specified data."""
         super().__init__(data)
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
 
-    def save_results(self, save_path: str) -> None:
-        """Save the layout parsing results to the specified directory.
+    def _to_img(self) -> Dict[str, np.ndarray]:
+        res_img_dict = {}
+        model_settings = self["model_settings"]
+        if model_settings["use_doc_preprocessor"]:
+            res_img_dict.update(**self["doc_preprocessor_res"].img)
+        res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
+
+        if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
+            res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"]
+
+        if model_settings["use_general_ocr"]:
+            general_ocr_res = copy.deepcopy(self["overall_ocr_res"])
+            general_ocr_res["rec_polys"] = self["text_paragraphs_ocr_res"]["rec_polys"]
+            general_ocr_res["rec_texts"] = self["text_paragraphs_ocr_res"]["rec_texts"]
+            general_ocr_res["rec_scores"] = self["text_paragraphs_ocr_res"][
+                "rec_scores"
+            ]
+            general_ocr_res["rec_boxes"] = self["text_paragraphs_ocr_res"]["rec_boxes"]
+            res_img_dict["text_paragraphs_ocr_res"] = general_ocr_res.img["ocr_res_img"]
+
+        if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
+            table_cell_img = copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
+            for sno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][sno]
+                cell_box_list = table_res["cell_box_list"]
+                for box in cell_box_list:
+                    x1, y1, x2, y2 = [int(pos) for pos in box]
+                    cv2.rectangle(table_cell_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            res_img_dict["table_cell_img"] = table_cell_img
+
+        if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
+            for sno in range(len(self["seal_res_list"])):
+                seal_res = self["seal_res_list"][sno]
+                seal_region_id = seal_res["seal_region_id"]
+                sub_seal_res_dict = seal_res.img
+                key = f"seal_res_region{seal_region_id}"
+                res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
+        return res_img_dict
+
+    def _to_str(self, *args, **kwargs) -> Dict[str, str]:
+        """Converts the instance's attributes to a dictionary and then to a string.
 
         Args:
-            save_path (str): The directory path to save the results.
+            *args: Additional positional arguments passed to the base class method.
+            **kwargs: Additional keyword arguments passed to the base class method.
+
+        Returns:
+            Dict[str, str]: A dictionary with the instance's attributes converted to strings.
         """
+        data = {}
+        data["input_path"] = self["input_path"]
+        model_settings = self["model_settings"]
+        data["model_settings"] = model_settings
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
+        data["layout_det_res"] = self["layout_det_res"].str["res"]
+        if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
+            data["overall_ocr_res"] = self["overall_ocr_res"].str["res"]
+        if model_settings["use_general_ocr"]:
+            general_ocr_res = {}
+            general_ocr_res["rec_polys"] = self["text_paragraphs_ocr_res"]["rec_polys"]
+            general_ocr_res["rec_texts"] = self["text_paragraphs_ocr_res"]["rec_texts"]
+            general_ocr_res["rec_scores"] = self["text_paragraphs_ocr_res"][
+                "rec_scores"
+            ]
+            general_ocr_res["rec_boxes"] = self["text_paragraphs_ocr_res"]["rec_boxes"]
+            data["text_paragraphs_ocr_res"] = general_ocr_res
+        if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
+            data["table_res_list"] = []
+            for sno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][sno]
+                data["table_res_list"].append(table_res.str["res"])
+        if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
+            data["seal_res_list"] = []
+            for sno in range(len(self["seal_res_list"])):
+                seal_res = self["seal_res_list"][sno]
+                data["seal_res_list"].append(seal_res.str["res"])
+        return StrMixin._to_str(data, *args, **kwargs)
 
-        if not os.path.isdir(save_path):
-            return
+    def _to_json(self, *args, **kwargs) -> Dict[str, str]:
+        """
+        Converts the object's data to a JSON dictionary.
 
-        img_id = self["img_id"]
-        layout_det_res = self["layout_det_res"]
-        save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
-        layout_det_res.save_to_img(save_img_path)
+        Args:
+            *args: Positional arguments passed to the JsonMixin._to_json method.
+            **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
 
-        input_params = self["input_params"]
-        if input_params["use_doc_preprocessor"]:
-            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
-            self["doc_preprocessor_res"].save_to_img(save_img_path)
+        Returns:
+            Dict[str, str]: A dictionary containing the object's data in JSON format.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        model_settings = self["model_settings"]
+        data["model_settings"] = model_settings
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
+        data["layout_det_res"] = self["layout_det_res"].json["res"]
+        if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
+            data["overall_ocr_res"] = self["overall_ocr_res"].json["res"]
+        if model_settings["use_general_ocr"]:
+            general_ocr_res = {}
+            general_ocr_res["rec_polys"] = self["text_paragraphs_ocr_res"]["rec_polys"]
+            general_ocr_res["rec_texts"] = self["text_paragraphs_ocr_res"]["rec_texts"]
+            general_ocr_res["rec_scores"] = self["text_paragraphs_ocr_res"][
+                "rec_scores"
+            ]
+            general_ocr_res["rec_boxes"] = self["text_paragraphs_ocr_res"]["rec_boxes"]
+            data["text_paragraphs_ocr_res"] = general_ocr_res
+        if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
+            data["table_res_list"] = []
+            for sno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][sno]
+                data["table_res_list"].append(table_res.json["res"])
+        if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
+            data["seal_res_list"] = []
+            for sno in range(len(self["seal_res_list"])):
+                seal_res = self["seal_res_list"][sno]
+                data["seal_res_list"].append(seal_res.json["res"])
+        return JsonMixin._to_json(data, *args, **kwargs)
 
-        if input_params["use_general_ocr"]:
-            save_img_path = (
-                Path(save_path) / f"text_paragraphs_ocr_result_img{img_id}.jpg"
-            )
-            self["text_paragraphs_ocr_res"].save_to_img(save_img_path)
+    def _to_html(self) -> Dict[str, str]:
+        """Converts the prediction to its corresponding HTML representation.
 
-        if input_params["use_general_ocr"] or input_params["use_table_recognition"]:
-            save_img_path = Path(save_path) / f"overall_ocr_result_img{img_id}.jpg"
-            self["overall_ocr_res"].save_to_img(save_img_path)
+        Returns:
+            Dict[str, str]: The str type HTML representation result.
+        """
+        model_settings = self["model_settings"]
+        res_html_dict = {}
+        if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
+            for sno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][sno]
+                table_region_id = table_res["table_region_id"]
+                key = f"table_{table_region_id}"
+                res_html_dict[key] = table_res.html["pred"]
+        return res_html_dict
 
-        if input_params["use_table_recognition"]:
-            for tno in range(len(self["table_res_list"])):
-                table_res = self["table_res_list"][tno]
+    def _to_xlsx(self) -> Dict[str, str]:
+        """Converts the prediction HTML to an XLSX file path.
+
+        Returns:
+            Dict[str, str]: The str type XLSX representation result.
+        """
+        model_settings = self["model_settings"]
+        res_xlsx_dict = {}
+        if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
+            for sno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][sno]
                 table_region_id = table_res["table_region_id"]
-                save_img_path = (
-                    Path(save_path)
-                    / f"table_res_cell_img{img_id}_region{table_region_id}.jpg"
-                )
-                table_res.save_to_img(save_img_path)
-                save_html_path = (
-                    Path(save_path)
-                    / f"table_res_img{img_id}_region{table_region_id}.html"
-                )
-                table_res.save_to_html(save_html_path)
-                save_xlsx_path = (
-                    Path(save_path)
-                    / f"table_res_img{img_id}_region{table_region_id}.xlsx"
-                )
-                table_res.save_to_xlsx(save_xlsx_path)
-
-        if input_params["use_seal_recognition"]:
-            for sno in range(len(self["seal_res_list"])):
-                seal_res = self["seal_res_list"][sno]
-                seal_region_id = seal_res["seal_region_id"]
-                save_img_path = (
-                    Path(save_path) / f"seal_res_img{img_id}_region{seal_region_id}.jpg"
-                )
-                seal_res.save_to_img(save_img_path)
-        return
+                key = f"table_{table_region_id}"
+                res_xlsx_dict[key] = table_res.xlsx["pred"]
+        return res_xlsx_dict

+ 4 - 1
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -214,7 +214,10 @@ class OCRPipeline(BasePipeline):
         if use_doc_orientation_classify is None and use_doc_unwarping is None:
             use_doc_preprocessor = self.use_doc_preprocessor
         else:
-            use_doc_preprocessor = True
+            if use_doc_orientation_classify is True or use_doc_unwarping is True:
+                use_doc_preprocessor = True
+            else:
+                use_doc_preprocessor = False
 
         if use_textline_orientation is None:
             use_textline_orientation = self.use_textline_orientation

+ 4 - 1
paddlex/inference/pipelines_new/seal_recognition/pipeline.py

@@ -141,7 +141,10 @@ class SealRecognitionPipeline(BasePipeline):
         if use_doc_orientation_classify is None and use_doc_unwarping is None:
             use_doc_preprocessor = self.use_doc_preprocessor
         else:
-            use_doc_preprocessor = True
+            if use_doc_orientation_classify is True or use_doc_unwarping is True:
+                use_doc_preprocessor = True
+            else:
+                use_doc_preprocessor = False
 
         if use_layout_detection is None:
             use_layout_detection = self.use_layout_detection

+ 4 - 1
paddlex/inference/pipelines_new/table_recognition/pipeline.py

@@ -120,7 +120,10 @@ class TableRecognitionPipeline(BasePipeline):
         if use_doc_orientation_classify is None and use_doc_unwarping is None:
             use_doc_preprocessor = self.use_doc_preprocessor
         else:
-            use_doc_preprocessor = True
+            if use_doc_orientation_classify is True or use_doc_unwarping is True:
+                use_doc_preprocessor = True
+            else:
+                use_doc_preprocessor = False
 
         if use_layout_detection is None:
             use_layout_detection = self.use_layout_detection