fix input params and save_to_img (#2847)

* add the new architecture of pipelines

* add explanatory note

* fix some module names

* add pipelines for single modules, seal recognition, and table recognition

* support taking PDF and original image inputs

* add PP-ChatOCRv4 and support PDF

* modify layout parsing and pp-chatocr to support different versions

* fix input params and save_to_img

* fix input params and save_to_img for doc preprocessor, OCR, seal recognition, table recognition
dyning committed 10 months ago
commit 44043bd950
24 files changed, 814 additions and 419 deletions
  1. api_examples/pipelines/test_doc_preprocessor.py (+2, -0)
  2. api_examples/pipelines/test_ocr.py (+5, -2)
  3. api_examples/pipelines/test_seal_recognition.py (+31, -4)
  4. api_examples/pipelines/test_table_recognition.py (+26, -4)
  5. paddlex/configs/pipelines/OCR.yaml (+0, -2)
  6. paddlex/configs/pipelines/seal_recognition.yaml (+9, -5)
  7. paddlex/configs/pipelines/table_recognition.yaml (+11, -5)
  8. paddlex/inference/common/result/mixin.py (+46, -8)
  9. paddlex/inference/pipelines_new/base.py (+5, -0)
  10. paddlex/inference/pipelines_new/components/__init__.py (+1, -0)
  11. paddlex/inference/pipelines_new/components/common/__init__.py (+1, -0)
  12. paddlex/inference/pipelines_new/components/common/convert_points_and_boxes.py (+46, -0)
  13. paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py (+13, -11)
  14. paddlex/inference/pipelines_new/doc_preprocessor/result.py (+35, -2)
  15. paddlex/inference/pipelines_new/formula_recognition/pipeline.py (+3, -1)
  16. paddlex/inference/pipelines_new/layout_parsing/pipeline.py (+2, -1)
  17. paddlex/inference/pipelines_new/layout_parsing/utils.py (+18, -44)
  18. paddlex/inference/pipelines_new/ocr/pipeline.py (+66, -89)
  19. paddlex/inference/pipelines_new/ocr/result.py (+63, -10)
  20. paddlex/inference/pipelines_new/seal_recognition/pipeline.py (+103, -83)
  21. paddlex/inference/pipelines_new/seal_recognition/result.py (+59, -28)
  22. paddlex/inference/pipelines_new/table_recognition/pipeline.py (+128, -59)
  23. paddlex/inference/pipelines_new/table_recognition/result.py (+133, -56)
  24. paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing.py (+8, -5)

+ 2 - 0
api_examples/pipelines/test_doc_preprocessor.py

@@ -48,4 +48,6 @@ output = pipeline.predict(
 
 for res in output:
     print(res)
+    res.print()
     res.save_to_img("./output")
+    res.save_to_json("./output")

+ 5 - 2
api_examples/pipelines/test_ocr.py

@@ -75,10 +75,13 @@ output = pipeline.predict(
 #     use_textline_orientation=True
 # )
 
+# output = pipeline.predict(
+#     "./test_samples/general_ocr_002.png")
+
 # output = pipeline.predict("./test_samples/财报1.pdf")
 
 for res in output:
     print(res)
+    res.print()
     res.save_to_img("./output")
-    # TODO: need to check the json format
-    # res.save_to_json("./output/res.json")
+    res.save_to_json("./output")

+ 31 - 4
api_examples/pipelines/test_seal_recognition.py

@@ -15,13 +15,40 @@
 from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="seal_recognition")
-output = pipeline.predict("./test_samples/seal_text_det.png")
+output = pipeline.predict(
+    "./test_samples/seal_text_det.png",
+    use_doc_orientation_classify=False,
+    use_doc_unwarping=False,
+)
 
-# output = pipeline.predict("./test_samples/seal_text_det.png",
-#     use_layout_detection=False)
+# output = pipeline.predict(
+#     "./test_samples/seal_text_det.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     text_rec_score_thresh = 0.9
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/seal_text_det.png",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/seal_text_det.png",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_layout_detection=False
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/seal_text_det.png"
+# )
 
 # output = pipeline.predict("./test_samples/财报1.pdf")
 
 for res in output:
     print(res)
-    res.save_results("./output")
+    res.print()
+    res.save_to_img("./output")
+    res.save_to_json("./output")

+ 26 - 4
api_examples/pipelines/test_table_recognition.py

@@ -16,12 +16,34 @@ from paddlex import create_pipeline
 
 pipeline = create_pipeline(pipeline="table_recognition")
 
-output = pipeline("./test_samples/table_recognition.jpg")
+output = pipeline.predict(
+    "./test_samples/table_recognition.jpg",
+    use_doc_orientation_classify=False,
+    use_doc_unwarping=False,
+)
 
-# output = pipeline("./test_samples/table_recognition.jpg",
-#     use_layout_detection=False)
+# output = pipeline.predict(
+#     "./test_samples/table_recognition.jpg",
+#     use_doc_orientation_classify=True,
+#     use_doc_unwarping=True
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/table_recognition.jpg",
+#     use_doc_orientation_classify=False,
+#     use_doc_unwarping=False,
+#     use_layout_detection=False
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/table_recognition.jpg"
+# )
 
 # output = pipeline("./test_samples/财报1.pdf")
 for res in output:
     print(res)
-    res.save_results("./output/")
+    res.print()
+    res.save_to_img("./output")
+    res.save_to_json("./output")
+    res.save_to_xlsx("./output")
+    res.save_to_html("./output")
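
These save calls rely on the HtmlMixin/XlsxMixin rework further down. A short sketch of the two target modes, assuming the mixin's `f"{stem}_{key}"` naming pattern (the example key and file names are illustrative):

```python
for res in output:
    # Directory target: one file per result key, named "<input stem>_<key>.html".
    res.save_to_html("./output")
    # Explicit file target: only the first entry is written, and a warning
    # is logged if the result holds more than one table.
    res.save_to_xlsx("./output/first_table.xlsx")
```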

+ 0 - 2
paddlex/configs/pipelines/OCR.yaml

@@ -30,9 +30,7 @@ SubModules:
     limit_type: max
     thresh: 0.3
     box_thresh: 0.6
-    max_candidates: 1000
     unclip_ratio: 2.0
-    use_dilation: False
   TextLineOrientation:
     module_name: textline_orientation
     model_name: PP-LCNet_x0_25_textline_ori 

+ 9 - 5
paddlex/configs/pipelines/seal_recognition.yaml

@@ -1,15 +1,14 @@
 
 pipeline_name: seal_recognition
 
-use_layout_detection: True
 use_doc_preprocessor: True
+use_layout_detection: True
 
 SubModules:
   LayoutDetection:
     module_name: layout_detection
     model_name: RT-DETR-H_layout_3cls
     model_dir: null
-    batch_size: 1
 
 SubPipelines:
   DocPreprocessor:
@@ -21,23 +20,28 @@ SubPipelines:
         module_name: doc_text_orientation
         model_name: PP-LCNet_x1_0_doc_ori
         model_dir: null
-        batch_size: 1
       DocUnwarping:
         module_name: image_unwarping
         model_name: UVDoc
         model_dir: null
-        batch_size: 1
   SealOCR:
     pipeline_name: OCR
     text_type: seal
+    use_doc_preprocessor: False
+    use_textline_orientation: False
     SubModules:
       TextDetection:
         module_name: seal_text_detection
         model_name: PP-OCRv4_server_seal_det
         model_dir: null
-        batch_size: 1    
+        limit_side_len: 736
+        limit_type: min
+        thresh: 0.2
+        box_thresh: 0.6
+        unclip_ratio: 0.5
       TextRecognition:
         module_name: text_recognition
         model_name: PP-OCRv4_server_rec
         model_dir: null
         batch_size: 1
+        score_thresh: 0
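
The seal text-detection thresholds now live in the pipeline config, and the reworked seal pipeline (further down) also exposes them as per-call `seal_det_*` / `seal_rec_*` overrides. A minimal sketch of that mapping, reusing the sample path from the test script:

```python
from paddlex import create_pipeline

# Minimal sketch: the YAML values above are the defaults; the seal_det_* /
# seal_rec_* keyword arguments below override them for a single call.
pipeline = create_pipeline(pipeline="seal_recognition")
output = pipeline.predict(
    "./test_samples/seal_text_det.png",
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    seal_det_limit_side_len=736,  # YAML: limit_side_len
    seal_det_thresh=0.2,          # YAML: thresh
    seal_det_unclip_ratio=0.5,    # YAML: unclip_ratio
    seal_rec_score_thresh=0.9,    # YAML: score_thresh (default 0)
)
```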

+ 11 - 5
paddlex/configs/pipelines/table_recognition.yaml

@@ -10,13 +10,11 @@ SubModules:
     module_name: layout_detection
     model_name: RT-DETR-H_layout_3cls
     model_dir: null
-    batch_size: 1
 
   TableStructureRecognition:
     module_name: table_structure_recognition
     model_name: SLANet_plus
     model_dir: null
-    batch_size: 1
 
 SubPipelines:
   DocPreprocessor:
@@ -28,23 +26,31 @@ SubPipelines:
         module_name: doc_text_orientation
         model_name: PP-LCNet_x1_0_doc_ori
         model_dir: null
-        batch_size: 1
+
       DocUnwarping:
         module_name: image_unwarping
         model_name: UVDoc
         model_dir: null
-        batch_size: 1
+
   GeneralOCR:
     pipeline_name: OCR
     text_type: general
+    use_doc_preprocessor: False
+    use_textline_orientation: False
     SubModules:
       TextDetection:
         module_name: text_detection
         model_name: PP-OCRv4_server_det
         model_dir: null
-        batch_size: 1    
+        limit_side_len: 960
+        limit_type: max
+        thresh: 0.3
+        box_thresh: 0.6
+        unclip_ratio: 2.0
+        
       TextRecognition:
         module_name: text_recognition
         model_name: PP-OCRv4_server_rec
         model_dir: null
         batch_size: 1
+        score_thresh: 0

+ 46 - 8
paddlex/inference/common/result/mixin.py

@@ -398,11 +398,28 @@ class HtmlMixin:
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
         """
-        if not str(save_path).endswith(".html"):
-            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
+
+        def _is_html_file(file_path):
+            mime_type, _ = mimetypes.guess_type(file_path)
+            return mime_type is not None and mime_type == "text/html"
+
+        if not _is_html_file(save_path):
+            fp = Path(self["input_path"])
+            stem = fp.stem
+            base_save_path = Path(save_path)
+            for key in self.html:
+                save_path = base_save_path / f"{stem}_{key}.html"
+                self._html_writer.write(
+                    save_path.as_posix(), self.html[key], *args, **kwargs
+                )
         else:
-            save_path = Path(save_path)
-        self._html_writer.write(save_path.as_posix(), self.html["res"], *args, **kwargs)
+            if len(self.html) > 1:
+                logging.warning(
+                    f"The result contains multiple HTML files to be saved, but `save_path` was specified as `{save_path}`!"
+                )
+            self._html_writer.write(
+                save_path, self.html[list(self.html.keys())[0]], *args, **kwargs
+            )
 
 
 class XlsxMixin:
@@ -445,11 +462,32 @@ class XlsxMixin:
             *args: Additional positional arguments to pass to the XLSX writer.
             **kwargs: Additional keyword arguments to pass to the XLSX writer.
         """
-        if not str(save_path).endswith(".xlsx"):
-            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
+
+        def _is_xlsx_file(file_path):
+            mime_type, _ = mimetypes.guess_type(file_path)
+            return (
+                mime_type is not None
+                and mime_type
+                == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+            )
+
+        if not _is_xlsx_file(save_path):
+            fp = Path(self["input_path"])
+            stem = fp.stem
+            base_save_path = Path(save_path)
+            for key in self.xlsx:
+                save_path = base_save_path / f"{stem}_{key}.xlsx"
+                self._xlsx_writer.write(
+                    save_path.as_posix(), self.xlsx[key], *args, **kwargs
+                )
         else:
-            save_path = Path(save_path)
-        self._xlsx_writer.write(save_path.as_posix(), self.xlsx, *args, **kwargs)
+            if len(self.xlsx) > 1:
+                logging.warning(
+                    f"The result contains multiple XLSX files to be saved, but `save_path` was specified as `{save_path}`!"
+                )
+            self._xlsx_writer.write(
+                save_path, self.xlsx[list(self.xlsx.keys())[0]], *args, **kwargs
+            )
 
 
 class VideoMixin:
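
The directory-vs-file decision now goes through `mimetypes.guess_type` rather than a string-suffix check. A small standalone sketch of that helper's behavior, using only the standard library:

```python
import mimetypes

def _is_html_file(file_path: str) -> bool:
    # ".html"/".htm" suffixes map to "text/html"; a directory-style path
    # with no recognized suffix yields (None, None) and falls through to
    # the one-file-per-key branch.
    mime_type, _ = mimetypes.guess_type(file_path)
    return mime_type == "text/html"

print(_is_html_file("output/table.html"))  # True  -> single-file branch
print(_is_html_file("./output"))           # False -> fan-out branch
```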

+ 5 - 0
paddlex/inference/pipelines_new/base.py

@@ -78,6 +78,8 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         Returns:
             BasePredictor: An instance of the model.
         """
+        if "model_config_error" in config:
+            raise ValueError(config["model_config_error"])
 
         model_dir = config["model_dir"]
         if model_dir == None:
@@ -105,6 +107,9 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         Returns:
             BasePipeline: An instance of the created pipeline.
         """
+        if "pipeline_config_error" in config:
+            raise ValueError(config["pipeline_config_error"])
+
         from . import create_pipeline
 
         pipeline_name = config["pipeline_name"]
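
With this change, a missing sub-config no longer surfaces as a bare `KeyError` from deep inside `create_model` / `create_pipeline`; callers substitute a placeholder dict carrying a `model_config_error` / `pipeline_config_error` message, and the factory raises a descriptive `ValueError`. A simplified, self-contained sketch of the pattern:

```python
# Simplified sketch of the error-propagation pattern added here: callers
# substitute a placeholder dict for a missing sub-config, and the factory
# raises a descriptive ValueError instead of a deep KeyError.
def create_model(config: dict):
    if "model_config_error" in config:
        raise ValueError(config["model_config_error"])
    ...  # build the predictor from config["model_name"], config["model_dir"], etc.

config = {"SubModules": {}}  # e.g. a YAML with the DocUnwarping entry missing
doc_unwarping_config = config.get("SubModules", {}).get(
    "DocUnwarping",
    {"model_config_error": "config error for doc_unwarping_model!"},
)
create_model(doc_unwarping_config)
# ValueError: config error for doc_unwarping_model!
```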

+ 1 - 0
paddlex/inference/pipelines_new/components/__init__.py

@@ -15,6 +15,7 @@
 from .common import CVResult, BaseResult
 from .common import SortQuadBoxes, SortPolyBoxes
 from .common import CropByPolys, CropByBoxes
+from .common import convert_points_to_boxes
 from .utils.mixin import HtmlMixin, XlsxMixin
 from .chat_server.base import BaseChat
 from .retriever.base import BaseRetriever

+ 1 - 0
paddlex/inference/pipelines_new/components/common/__init__.py

@@ -15,3 +15,4 @@
 from .base_result import CVResult, BaseResult
 from .sort_boxes import SortQuadBoxes, SortPolyBoxes
 from .crop_image_regions import CropByPolys, CropByBoxes
+from .convert_points_and_boxes import convert_points_to_boxes

+ 46 - 0
paddlex/inference/pipelines_new/components/common/convert_points_and_boxes.py

@@ -0,0 +1,46 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["convert_points_to_boxes"]
+
+import numpy as np
+import copy
+
+
+def convert_points_to_boxes(dt_polys: list) -> np.ndarray:
+    """
+    Converts a list of polygons to a numpy array of bounding boxes.
+
+    Args:
+        dt_polys (list): A list of polygons, where each polygon is represented
+                        as a list of (x, y) points.
+
+    Returns:
+        np.ndarray: A numpy array of bounding boxes, where each box is represented
+                    as [left, top, right, bottom].
+                    If the input list is empty, returns an empty numpy array.
+    """
+
+    if len(dt_polys) > 0:
+        dt_polys_tmp = dt_polys.copy()
+        dt_polys_tmp = np.array(dt_polys_tmp)
+        boxes_left = np.min(dt_polys_tmp[:, :, 0], axis=1)
+        boxes_right = np.max(dt_polys_tmp[:, :, 0], axis=1)
+        boxes_top = np.min(dt_polys_tmp[:, :, 1], axis=1)
+        boxes_bottom = np.max(dt_polys_tmp[:, :, 1], axis=1)
+        dt_boxes = np.array([boxes_left, boxes_top, boxes_right, boxes_bottom])
+        dt_boxes = dt_boxes.T
+    else:
+        dt_boxes = np.array([])
+    return dt_boxes
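
A quick check of the relocated helper on a single quad, importing it via the `__init__` exports added above (values are illustrative):

```python
import numpy as np
from paddlex.inference.pipelines_new.components import convert_points_to_boxes

# One 4-point polygon -> one [left, top, right, bottom] box.
polys = [np.array([[10, 20], [90, 18], [92, 60], [12, 62]])]
print(convert_points_to_boxes(polys))  # [[10 18 92 62]]
```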

+ 13 - 11
paddlex/inference/pipelines_new/doc_preprocessor/pipeline.py

@@ -50,20 +50,22 @@ class DocPreprocessorPipeline(BasePipeline):
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        self.use_doc_orientation_classify = True
-        if "use_doc_orientation_classify" in config:
-            self.use_doc_orientation_classify = config["use_doc_orientation_classify"]
-
-        self.use_doc_unwarping = True
-        if "use_doc_unwarping" in config:
-            self.use_doc_unwarping = config["use_doc_unwarping"]
-
+        self.use_doc_orientation_classify = config.get(
+            "use_doc_orientation_classify", True
+        )
         if self.use_doc_orientation_classify:
-            doc_ori_classify_config = config["SubModules"]["DocOrientationClassify"]
+            doc_ori_classify_config = config.get("SubModules", {}).get(
+                "DocOrientationClassify",
+                {"model_config_error": "config error for doc_ori_classify_model!"},
+            )
             self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
 
+        self.use_doc_unwarping = config.get("use_doc_unwarping", True)
         if self.use_doc_unwarping:
-            doc_unwarping_config = config["SubModules"]["DocUnwarping"]
+            doc_unwarping_config = config.get("SubModules", {}).get(
+                "DocUnwarping",
+                {"model_config_error": "config error for doc_unwarping_model!"},
+            )
             self.doc_unwarping_model = self.create_model(doc_unwarping_config)
 
         self.batch_sampler = ImageBatchSampler(batch_size=1)
@@ -188,7 +190,7 @@ class DocPreprocessorPipeline(BasePipeline):
 
             single_img_res = {
                 "input_path": input_path,
-                "input_image": image_array,
+                "input_img": image_array,
                 "model_settings": model_settings,
                 "angle": angle,
                 "rot_img": rot_img,

+ 35 - 2
paddlex/inference/pipelines_new/doc_preprocessor/result.py

@@ -22,7 +22,7 @@ import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
-from ...common.result import BaseCVResult
+from ...common.result import BaseCVResult, StrMixin, JsonMixin
 
 
 class DocPreprocessorResult(BaseCVResult):
@@ -35,7 +35,7 @@ class DocPreprocessorResult(BaseCVResult):
         Returns:
             Dict[Image.Image]: A new image combining the original, rotated, and unwarping images
         """
-        image = self["input_image"][:, :, ::-1]
+        image = self["input_img"][:, :, ::-1]
         rot_img = self["rot_img"][:, :, ::-1]
         angle = self["angle"]
         output_img = self["output_img"][:, :, ::-1]
@@ -66,3 +66,36 @@ class DocPreprocessorResult(BaseCVResult):
             )
         imgs = {"preprocessed_img": img_show}
         return imgs
+
+    def _to_str(self, *args, **kwargs) -> Dict[str, str]:
+        """Converts the instance's attributes to a dictionary and then to a string.
+
+        Args:
+            *args: Additional positional arguments passed to the base class method.
+            **kwargs: Additional keyword arguments passed to the base class method.
+
+        Returns:
+            Dict[str, str]: A dictionary with the instance's attributes converted to strings.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        data["angle"] = self["angle"]
+        return StrMixin._to_str(data, *args, **kwargs)
+
+    def _to_json(self, *args, **kwargs) -> Dict[str, str]:
+        """
+        Converts the object's data to a JSON dictionary.
+
+        Args:
+            *args: Positional arguments passed to the JsonMixin._to_json method.
+            **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the object's data in JSON format.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        data["angle"] = self["angle"]
+        return JsonMixin._to_json(data, *args, **kwargs)
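
With these `_to_str` / `_to_json` overrides, `DocPreprocessorResult` serializes only lightweight fields instead of the raw image arrays, so `res.print()` and `res.save_to_json()` stay compact. An illustrative sketch of the resulting payload (the values are made up; only these keys are emitted, while `input_img`, `rot_img`, and `output_img` are left out):

```python
# Illustrative payload written by res.save_to_json("./output") after this change:
{
    "input_path": "doc_test.jpg",
    "model_settings": {
        "use_doc_orientation_classify": True,
        "use_doc_unwarping": True,
    },
    "angle": 180,
}
```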

+ 3 - 1
paddlex/inference/pipelines_new/formula_recognition/pipeline.py

@@ -18,7 +18,9 @@ import numpy as np
 import cv2
 from ..base import BasePipeline
 from ..components import CropByBoxes
-from ..layout_parsing.utils import convert_points_to_boxes
+
+# from ..layout_parsing.utils import convert_points_to_boxes
+from ..components import convert_points_to_boxes
 
 from .result import FormulaRecognitionResult
 from ...models_new.formula_recognition.result import (

+ 2 - 1
paddlex/inference/pipelines_new/layout_parsing/pipeline.py

@@ -17,7 +17,8 @@ import os, sys
 import numpy as np
 import cv2
 from ..base import BasePipeline
-from .utils import convert_points_to_boxes, get_sub_regions_ocr_res
+from .utils import get_sub_regions_ocr_res
+from ..components import convert_points_to_boxes
 from .result import LayoutParsingResult
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption

+ 18 - 44
paddlex/inference/pipelines_new/layout_parsing/utils.py

@@ -12,41 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["convert_points_to_boxes", "get_sub_regions_ocr_res"]
+__all__ = ["get_sub_regions_ocr_res"]
 
 import numpy as np
 import copy
 from ..ocr.result import OCRResult
 
 
-def convert_points_to_boxes(dt_polys: list) -> np.ndarray:
-    """
-    Converts a list of polygons to a numpy array of bounding boxes.
-
-    Args:
-        dt_polys (list): A list of polygons, where each polygon is represented
-                        as a list of (x, y) points.
-
-    Returns:
-        np.ndarray: A numpy array of bounding boxes, where each box is represented
-                    as [left, top, right, bottom].
-                    If the input list is empty, returns an empty numpy array.
-    """
-
-    if len(dt_polys) > 0:
-        dt_polys_tmp = dt_polys.copy()
-        dt_polys_tmp = np.array(dt_polys_tmp)
-        boxes_left = np.min(dt_polys_tmp[:, :, 0], axis=1)
-        boxes_right = np.max(dt_polys_tmp[:, :, 0], axis=1)
-        boxes_top = np.min(dt_polys_tmp[:, :, 1], axis=1)
-        boxes_bottom = np.max(dt_polys_tmp[:, :, 1], axis=1)
-        dt_boxes = np.array([boxes_left, boxes_top, boxes_right, boxes_bottom])
-        dt_boxes = dt_boxes.T
-    else:
-        dt_boxes = np.array([])
-    return dt_boxes
-
-
 def get_overlap_boxes_idx(src_boxes: np.ndarray, ref_boxes: np.ndarray) -> list:
     """
     Get the indices of source boxes that overlap with reference boxes based on a specified threshold.
@@ -88,17 +60,13 @@ def get_sub_regions_ocr_res(
     Returns:
         OCRResult: A filtered OCR result containing only the relevant text boxes.
     """
-    sub_regions_ocr_res = copy.deepcopy(overall_ocr_res)
-    sub_regions_ocr_res["doc_preprocessor_image"] = overall_ocr_res[
-        "doc_preprocessor_image"
-    ]
-    sub_regions_ocr_res["img_id"] = -1
-    sub_regions_ocr_res["dt_polys"] = []
-    sub_regions_ocr_res["rec_text"] = []
-    sub_regions_ocr_res["rec_score"] = []
-    sub_regions_ocr_res["dt_boxes"] = []
+    sub_regions_ocr_res = {}
+    sub_regions_ocr_res["rec_polys"] = []
+    sub_regions_ocr_res["rec_texts"] = []
+    sub_regions_ocr_res["rec_scores"] = []
+    sub_regions_ocr_res["rec_boxes"] = []
 
-    overall_text_boxes = overall_ocr_res["dt_boxes"]
+    overall_text_boxes = overall_ocr_res["rec_boxes"]
     match_idx_list = get_overlap_boxes_idx(overall_text_boxes, object_boxes)
     match_idx_list = list(set(match_idx_list))
     for box_no in range(len(overall_text_boxes)):
@@ -113,10 +81,16 @@ def get_sub_regions_ocr_res(
             else:
                 flag_match = False
         if flag_match:
-            sub_regions_ocr_res["dt_polys"].append(overall_ocr_res["dt_polys"][box_no])
-            sub_regions_ocr_res["rec_text"].append(overall_ocr_res["rec_text"][box_no])
-            sub_regions_ocr_res["rec_score"].append(
-                overall_ocr_res["rec_score"][box_no]
+            sub_regions_ocr_res["rec_polys"].append(
+                overall_ocr_res["rec_polys"][box_no]
+            )
+            sub_regions_ocr_res["rec_texts"].append(
+                overall_ocr_res["rec_texts"][box_no]
+            )
+            sub_regions_ocr_res["rec_scores"].append(
+                overall_ocr_res["rec_scores"][box_no]
+            )
+            sub_regions_ocr_res["rec_boxes"].append(
+                overall_ocr_res["rec_boxes"][box_no]
             )
-            sub_regions_ocr_res["dt_boxes"].append(overall_ocr_res["dt_boxes"][box_no])
     return sub_regions_ocr_res
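
`get_sub_regions_ocr_res` now builds a plain dict keyed by the renamed `rec_*` fields instead of deep-copying the full OCR result. A minimal call sketch; the exact signature (including any flag controlling within-vs-overlap matching) is only partially visible in this hunk, so treat the two-argument call as an assumption:

```python
import numpy as np
from paddlex.inference.pipelines_new.layout_parsing.utils import get_sub_regions_ocr_res

# Illustrative input using the renamed rec_* keys.
overall_ocr_res = {
    "rec_polys": [np.array([[10, 10], [60, 10], [60, 30], [10, 30]])],
    "rec_texts": ["hello"],
    "rec_scores": [0.98],
    "rec_boxes": np.array([[10, 10, 60, 30]]),
}
object_boxes = np.array([[0, 0, 100, 100]])  # one layout region covering the text box

sub = get_sub_regions_ocr_res(overall_ocr_res, object_boxes)
print(sub["rec_texts"])  # expected: ['hello']
```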

+ 66 - 89
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -19,7 +19,12 @@ from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
-from ..components import CropByPolys, SortQuadBoxes, SortPolyBoxes
+from ..components import (
+    CropByPolys,
+    SortQuadBoxes,
+    SortPolyBoxes,
+    convert_points_to_boxes,
+)
 from .result import OCRResult
 from ..doc_preprocessor.result import DocPreprocessorResult
 from ....utils import logging
@@ -54,14 +59,22 @@ class OCRPipeline(BasePipeline):
 
         self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            doc_preprocessor_config = config.get("SubPipelines", {}).get(
+                "DocPreprocessor",
+                {
+                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
+                },
+            )
             self.doc_preprocessor_pipeline = self.create_pipeline(
                 doc_preprocessor_config
             )
 
         self.use_textline_orientation = config.get("use_textline_orientation", True)
         if self.use_textline_orientation:
-            textline_orientation_config = config["SubModules"]["TextLineOrientation"]
+            textline_orientation_config = config.get("SubModules", {}).get(
+                "TextLineOrientation",
+                {"model_config_error": "config error for textline_orientation_model!"},
+            )
             # TODO: add batch_size
             # batch_size = textline_orientation_config.get("batch_size", 1)
             # self.textline_orientation_model = self.create_model(
@@ -71,26 +84,42 @@ class OCRPipeline(BasePipeline):
                 textline_orientation_config
             )
 
-        text_det_config = config["SubModules"]["TextDetection"]
-        self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
-        self.text_det_limit_type = text_det_config.get("limit_type", "max")
-        self.text_det_thresh = text_det_config.get("thresh", 0.3)
-        self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
-        self.text_det_max_candidates = text_det_config.get("max_candidates", 1000)
-        self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
-        self.text_det_use_dilation = text_det_config.get("use_dilation", False)
+        text_det_config = config.get("SubModules", {}).get(
+            "TextDetection", {"model_config_error": "config error for text_det_model!"}
+        )
+        self.text_type = config["text_type"]
+        if self.text_type == "general":
+            self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
+            self.text_det_limit_type = text_det_config.get("limit_type", "max")
+            self.text_det_thresh = text_det_config.get("thresh", 0.3)
+            self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
+            self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
+            self._sort_boxes = SortQuadBoxes()
+            self._crop_by_polys = CropByPolys(det_box_type="quad")
+        elif self.text_type == "seal":
+            self.text_det_limit_side_len = text_det_config.get("limit_side_len", 736)
+            self.text_det_limit_type = text_det_config.get("limit_type", "min")
+            self.text_det_thresh = text_det_config.get("thresh", 0.2)
+            self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
+            self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 0.5)
+            self._sort_boxes = SortPolyBoxes()
+            self._crop_by_polys = CropByPolys(det_box_type="poly")
+        else:
+            raise ValueError("Unsupported text type {}".format(self.text_type))
+
         self.text_det_model = self.create_model(
             text_det_config,
             limit_side_len=self.text_det_limit_side_len,
             limit_type=self.text_det_limit_type,
             thresh=self.text_det_thresh,
             box_thresh=self.text_det_box_thresh,
-            max_candidates=self.text_det_max_candidates,
             unclip_ratio=self.text_det_unclip_ratio,
-            use_dilation=self.text_det_use_dilation,
         )
 
-        text_rec_config = config["SubModules"]["TextRecognition"]
+        text_rec_config = config.get("SubModules", {}).get(
+            "TextRecognition",
+            {"model_config_error": "config error for text_rec_model!"},
+        )
         # TODO: add batch_size
         # batch_size = text_rec_config.get("batch_size", 1)
         # self.text_rec_model = self.create_model(text_rec_config,
@@ -98,16 +127,6 @@ class OCRPipeline(BasePipeline):
         self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
         self.text_rec_model = self.create_model(text_rec_config)
 
-        self.text_type = config["text_type"]
-        if self.text_type == "general":
-            self._sort_boxes = SortQuadBoxes()
-            self._crop_by_polys = CropByPolys(det_box_type="quad")
-        elif self.text_type == "seal":
-            self._sort_boxes = SortPolyBoxes()
-            self._crop_by_polys = CropByPolys(det_box_type="poly")
-        else:
-            raise ValueError("Unsupported text type {}".format(self.text_type))
-
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
@@ -175,36 +194,6 @@ class OCRPipeline(BasePipeline):
 
         return True
 
-    def predict_doc_preprocessor_res(
-        self, image_array: np.ndarray, model_settings: dict
-    ) -> tuple[DocPreprocessorResult, np.ndarray]:
-        """
-        Preprocess the document image based on input parameters.
-
-        Args:
-            image_array (np.ndarray): The input image array.
-            model_settings (dict): Dictionary containing preprocessing parameters.
-
-        Returns:
-            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
-                                              result dictionary and the processed image array.
-        """
-        if model_settings["use_doc_preprocessor"]:
-            use_doc_orientation_classify = model_settings[
-                "use_doc_orientation_classify"
-            ]
-            use_doc_unwarping = model_settings["use_doc_unwarping"]
-            doc_preprocessor_res = next(
-                self.doc_preprocessor_pipeline(
-                    image_array,
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping,
-                )
-            )
-        else:
-            doc_preprocessor_res = {"output_img": image_array}
-        return doc_preprocessor_res
-
     def get_model_settings(
         self,
         use_doc_orientation_classify: Optional[bool],
@@ -222,15 +211,15 @@ class OCRPipeline(BasePipeline):
         Returns:
             dict: A dictionary containing the model settings.
         """
-        if use_doc_orientation_classify is None:
-            use_doc_orientation_classify = self.use_doc_orientation_classify
-        if use_doc_unwarping is None:
-            use_doc_unwarping = self.use_doc_unwarping
+        if use_doc_orientation_classify is None and use_doc_unwarping is None:
+            use_doc_preprocessor = self.use_doc_preprocessor
+        else:
+            use_doc_preprocessor = True
+
         if use_textline_orientation is None:
             use_textline_orientation = self.use_textline_orientation
         return dict(
-            use_doc_orientation_classify=use_doc_orientation_classify,
-            use_doc_unwarping=use_doc_unwarping,
+            use_doc_preprocessor=use_doc_preprocessor,
             use_textline_orientation=use_textline_orientation,
         )
 
@@ -240,9 +229,7 @@ class OCRPipeline(BasePipeline):
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
         text_det_box_thresh: Optional[float] = None,
-        text_det_max_candidates: Optional[int] = None,
         text_det_unclip_ratio: Optional[float] = None,
-        text_det_use_dilation: Optional[bool] = None,
     ) -> dict:
         """
         Get text detection parameters.
@@ -254,9 +241,7 @@ class OCRPipeline(BasePipeline):
             text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
             text_det_thresh (Optional[float]): The threshold for text detection.
             text_det_box_thresh (Optional[float]): The threshold for the bounding box.
-            text_det_max_candidates (Optional[int]): The maximum number of candidate text boxes.
             text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
-            text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
 
         Returns:
             dict: A dictionary containing the text detection parameters.
@@ -269,20 +254,14 @@ class OCRPipeline(BasePipeline):
             text_det_thresh = self.text_det_thresh
         if text_det_box_thresh is None:
             text_det_box_thresh = self.text_det_box_thresh
-        if text_det_max_candidates is None:
-            text_det_max_candidates = self.text_det_max_candidates
         if text_det_unclip_ratio is None:
             text_det_unclip_ratio = self.text_det_unclip_ratio
-        if text_det_use_dilation is None:
-            text_det_use_dilation = self.text_det_use_dilation
         return dict(
             limit_side_len=text_det_limit_side_len,
             limit_type=text_det_limit_type,
             thresh=text_det_thresh,
             box_thresh=text_det_box_thresh,
-            max_candidates=text_det_max_candidates,
             unclip_ratio=text_det_unclip_ratio,
-            use_dilation=text_det_use_dilation,
         )
 
     def predict(
@@ -295,9 +274,7 @@ class OCRPipeline(BasePipeline):
         text_det_limit_type: Optional[str] = None,
         text_det_thresh: Optional[float] = None,
         text_det_box_thresh: Optional[float] = None,
-        text_det_max_candidates: Optional[int] = None,
         text_det_unclip_ratio: Optional[float] = None,
-        text_det_use_dilation: Optional[bool] = None,
         text_rec_score_thresh: Optional[float] = None,
     ) -> OCRResult:
         """
@@ -312,9 +289,7 @@ class OCRPipeline(BasePipeline):
             text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
             text_det_thresh (Optional[float]): Threshold for text detection.
             text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
-            text_det_max_candidates (Optional[int]): Maximum number of text detection candidates.
             text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
-            text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
             text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
         Returns:
             OCRResult: Generator yielding OCR results for each input image.
@@ -323,13 +298,6 @@ class OCRPipeline(BasePipeline):
         model_settings = self.get_model_settings(
             use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
         )
-        if (
-            model_settings["use_doc_orientation_classify"]
-            or model_settings["use_doc_unwarping"]
-        ):
-            model_settings["use_doc_preprocessor"] = True
-        else:
-            model_settings["use_doc_preprocessor"] = False
 
         if not self.check_model_settings_valid(model_settings):
             yield {"error": "the input params for model settings are invalid!"}
@@ -339,9 +307,7 @@ class OCRPipeline(BasePipeline):
             text_det_limit_type,
             text_det_thresh,
             text_det_box_thresh,
-            text_det_max_candidates,
             text_det_unclip_ratio,
-            text_det_use_dilation,
         )
 
         if text_rec_score_thresh is None:
@@ -356,9 +322,16 @@ class OCRPipeline(BasePipeline):
 
             image_array = self.img_reader(batch_data)[0]
 
-            doc_preprocessor_res = self.predict_doc_preprocessor_res(
-                image_array, model_settings
-            )
+            if model_settings["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+            else:
+                doc_preprocessor_res = {"output_img": image_array}
 
             doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
@@ -372,17 +345,18 @@ class OCRPipeline(BasePipeline):
             dt_polys = self._sort_boxes(dt_polys)
 
             single_img_res = {
-                "input_path": batch_data[0],
+                "input_path": input_path,
                 "doc_preprocessor_res": doc_preprocessor_res,
                 "dt_polys": dt_polys,
                 "model_settings": model_settings,
                 "text_det_params": text_det_params,
                 "text_type": self.text_type,
+                "text_rec_score_thresh": text_rec_score_thresh,
             }
 
             single_img_res["rec_texts"] = []
             single_img_res["rec_scores"] = []
-            single_img_res["rec_boxes"] = []
+            single_img_res["rec_polys"] = []
             if len(dt_polys) > 0:
                 all_subs_of_img = list(
                     self._crop_by_polys(doc_preprocessor_image, dt_polys)
@@ -404,5 +378,8 @@ class OCRPipeline(BasePipeline):
                     if rec_res["rec_score"] >= text_rec_score_thresh:
                         single_img_res["rec_texts"].append(rec_res["rec_text"])
                         single_img_res["rec_scores"].append(rec_res["rec_score"])
-                        single_img_res["rec_boxes"].append(dt_polys[rno])
+                        single_img_res["rec_polys"].append(dt_polys[rno])
+
+            rec_boxes = convert_points_to_boxes(single_img_res["rec_polys"])
+            single_img_res["rec_boxes"] = rec_boxes
             yield OCRResult(single_img_res)
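
The slimmed-down `predict()` drops the `max_candidates` / `use_dilation` knobs and adds a per-call `text_rec_score_thresh`; results now carry both `rec_polys` (the sorted detection polygons) and `rec_boxes` (axis-aligned boxes derived via `convert_points_to_boxes`). A hedged usage sketch (the sample path is a placeholder):

```python
from paddlex import create_pipeline

# Sketch of the reworked predict() after this commit.
pipeline = create_pipeline(pipeline="OCR")
output = pipeline.predict(
    "./test_samples/general_ocr_002.png",
    use_textline_orientation=False,
    text_det_box_thresh=0.6,
    text_rec_score_thresh=0.5,  # drop recognitions scoring below 0.5
)
for res in output:
    # rec_polys: sorted detection polygons; rec_boxes: the derived
    # axis-aligned [left, top, right, bottom] boxes.
    for text, score, box in zip(res["rec_texts"], res["rec_scores"], res["rec_boxes"]):
        print(text, score, box)
```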

+ 63 - 10
paddlex/inference/pipelines_new/ocr/result.py

@@ -14,6 +14,7 @@
 
 import os
 from pathlib import Path
+from typing import Dict
 import copy
 import math
 import random
@@ -22,7 +23,7 @@ import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
-from ...common.result import BaseCVResult
+from ...common.result import BaseCVResult, StrMixin, JsonMixin
 
 
 class OCRResult(BaseCVResult):
@@ -62,14 +63,14 @@ class OCRResult(BaseCVResult):
 
         return box
 
-    def _to_img(self) -> PIL.Image:
+    def _to_img(self) -> Dict[str, Image.Image]:
         """
         Converts the internal data to a PIL Image with detection and recognition results.
 
         Returns:
-            PIL.Image: An image with detection boxes, texts, and scores blended on it.
+            Dict[str, Image.Image]: A dictionary containing the rendered 'ocr_res_img' and, when the doc preprocessor is enabled, its preprocessing images.
         """
-        boxes = self["rec_boxes"]
+        boxes = self["rec_polys"]
         txts = self["rec_texts"]
         image = self["doc_preprocessor_res"]["output_img"]
         h, w = image.shape[0:2]
@@ -109,13 +110,65 @@ class OCRResult(BaseCVResult):
         img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
 
         model_settings = self["model_settings"]
+        res_img_dict = {f"ocr_res_img": img_show}
         if model_settings["use_doc_preprocessor"]:
-            return {
-                **self["doc_preprocessor_res"].img,
-                f"ocr_res_img": img_show,
-            }
-        else:
-            return {f"ocr_res_img": img_show}
+            res_img_dict.update(**self["doc_preprocessor_res"].img)
+        return res_img_dict
+
+    def _to_str(self, *args, **kwargs) -> Dict[str, str]:
+        """Converts the instance's attributes to a dictionary and then to a string.
+
+        Args:
+            *args: Additional positional arguments passed to the base class method.
+            **kwargs: Additional keyword arguments passed to the base class method.
+
+        Returns:
+            Dict[str, str]: A dictionary with the instance's attributes converted to strings.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
+        data["dt_polys"] = self["dt_polys"]
+        data["text_det_params"] = self["text_det_params"]
+        data["text_type"] = self["text_type"]
+        if self["model_settings"]["use_textline_orientation"]:
+            data["textline_orientation_angle"] = self["textline_orientation_angle"]
+        data["text_rec_score_thresh"] = self["text_rec_score_thresh"]
+        data["rec_texts"] = self["rec_texts"]
+        data["rec_scores"] = self["rec_scores"]
+        data["rec_polys"] = self["rec_polys"]
+        data["rec_boxes"] = self["rec_boxes"]
+        return StrMixin._to_str(data, *args, **kwargs)
+
+    def _to_json(self, *args, **kwargs) -> Dict[str, str]:
+        """
+        Converts the object's data to a JSON dictionary.
+
+        Args:
+            *args: Positional arguments passed to the JsonMixin._to_json method.
+            **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the object's data in JSON format.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
+        data["dt_polys"] = self["dt_polys"]
+        data["text_det_params"] = self["text_det_params"]
+        data["text_type"] = self["text_type"]
+        if self["model_settings"]["use_textline_orientation"]:
+            data["textline_orientation_angle"] = self["textline_orientation_angle"]
+        data["text_rec_score_thresh"] = self["text_rec_score_thresh"]
+        data["rec_texts"] = self["rec_texts"]
+        data["rec_scores"] = self["rec_scores"]
+        data["rec_polys"] = self["rec_polys"]
+        data["rec_boxes"] = self["rec_boxes"]
+        return JsonMixin._to_json(data, *args, **kwargs)
 
 
 # Adds a function comment according to Google Style Guide

+ 103 - 83
paddlex/inference/pipelines_new/seal_recognition/pipeline.py

@@ -56,25 +56,29 @@ class SealRecognitionPipeline(BasePipeline):
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        self.use_doc_preprocessor = False
-        if "use_doc_preprocessor" in config:
-            self.use_doc_preprocessor = config["use_doc_preprocessor"]
-
+        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            doc_preprocessor_config = config.get("SubPipelines", {}).get(
+                "DocPreprocessor",
+                {
+                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
+                },
+            )
             self.doc_preprocessor_pipeline = self.create_pipeline(
                 doc_preprocessor_config
             )
 
-        self.use_layout_detection = True
-        if "use_layout_detection" in config:
-            self.use_layout_detection = config["use_layout_detection"]
-
+        self.use_layout_detection = config.get("use_layout_detection", True)
         if self.use_layout_detection:
-            layout_det_config = config["SubModules"]["LayoutDetection"]
+            layout_det_config = config.get("SubModules", {}).get(
+                "LayoutDetection",
+                {"model_config_error": "config error for layout_det_model!"},
+            )
             self.layout_det_model = self.create_model(layout_det_config)
 
-        seal_ocr_config = config["SubPipelines"]["SealOCR"]
+        seal_ocr_config = config.get("SubPipelines", {}).get(
+            "SealOCR", {"pipeline_config_error": "config error for seal_ocr_pipeline!"}
+        )
         self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)
 
         self._crop_by_boxes = CropByBoxes()
@@ -83,27 +87,27 @@ class SealRecognitionPipeline(BasePipeline):
 
         self.img_reader = ReadImage(format="BGR")
 
-    def check_input_params_valid(
-        self, input_params: Dict, layout_det_res: DetResult
+    def check_model_settings_valid(
+        self, model_settings: Dict, layout_det_res: DetResult
     ) -> bool:
         """
         Check if the input parameters are valid based on the initialized models.
 
         Args:
-            input_params (Dict): A dictionary containing input parameters.
+            model_settings (Dict): A dictionary containing input parameters.
             layout_det_res (DetResult): Layout detection result.
 
         Returns:
             bool: True if all required models are initialized according to input parameters, False otherwise.
         """
 
-        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
             logging.error(
                 "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
             )
             return False
 
-        if input_params["use_layout_detection"]:
+        if model_settings["use_layout_detection"]:
             if layout_det_res is not None:
                 logging.error(
                     "The layout detection model has already been initialized, please set use_layout_detection=False"
@@ -117,112 +121,128 @@ class SealRecognitionPipeline(BasePipeline):
                 return False
         return True
 
-    def predict_doc_preprocessor_res(
-        self, image_array: np.ndarray, input_params: dict
-    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+    def get_model_settings(
+        self,
+        use_doc_orientation_classify: Optional[bool],
+        use_doc_unwarping: Optional[bool],
+        use_layout_detection: Optional[bool],
+    ) -> dict:
         """
-        Preprocess the document image based on input parameters.
+        Get the model settings based on the provided parameters or default values.
 
         Args:
-            image_array (np.ndarray): The input image array.
-            input_params (dict): Dictionary containing preprocessing parameters.
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_layout_detection (Optional[bool]): Whether to use layout detection.
 
         Returns:
-            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
-                                              result dictionary and the processed image array.
+            dict: A dictionary containing the model settings.
         """
-        if input_params["use_doc_preprocessor"]:
-            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
-            use_doc_unwarping = input_params["use_doc_unwarping"]
-            doc_preprocessor_res = next(
-                self.doc_preprocessor_pipeline(
-                    image_array,
-                    use_doc_orientation_classify=use_doc_orientation_classify,
-                    use_doc_unwarping=use_doc_unwarping,
-                )
-            )
-            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        if use_doc_orientation_classify is None and use_doc_unwarping is None:
+            use_doc_preprocessor = self.use_doc_preprocessor
         else:
-            doc_preprocessor_res = {}
-            doc_preprocessor_image = image_array
-        return doc_preprocessor_res, doc_preprocessor_image
+            use_doc_preprocessor = True
+
+        if use_layout_detection is None:
+            use_layout_detection = self.use_layout_detection
+        return dict(
+            use_doc_preprocessor=use_doc_preprocessor,
+            use_layout_detection=use_layout_detection,
+        )
 
     def predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_layout_detection: bool = True,
-        use_doc_orientation_classify: bool = False,
-        use_doc_unwarping: bool = False,
-        layout_det_res: DetResult = None,
-        **kwargs
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_layout_detection: Optional[bool] = None,
+        layout_det_res: Optional[DetResult] = None,
+        seal_det_limit_side_len: Optional[int] = None,
+        seal_det_limit_type: Optional[str] = None,
+        seal_det_thresh: Optional[float] = None,
+        seal_det_box_thresh: Optional[float] = None,
+        seal_det_unclip_ratio: Optional[float] = None,
+        seal_rec_score_thresh: Optional[float] = None,
+        **kwargs,
     ) -> SealRecognitionResult:
-        """
-        This function predicts the seal recognition result for the given input.
 
-        Args:
-            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) of pdf(s) to be processed.
-            use_layout_detection (bool): Whether to use layout detection.
-            use_doc_orientation_classify (bool): Whether to use document orientation classification.
-            use_doc_unwarping (bool): Whether to use document unwarping.
-            layout_det_res (DetResult): The layout detection result.
-                It will be used if it is not None and use_layout_detection is False.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            SealRecognitionResult: The predicted seal recognition result.
-        """
-
-        input_params = {
-            "use_layout_detection": use_layout_detection,
-            "use_doc_preprocessor": self.use_doc_preprocessor,
-            "use_doc_orientation_classify": use_doc_orientation_classify,
-            "use_doc_unwarping": use_doc_unwarping,
-        }
-
-        if use_doc_orientation_classify or use_doc_unwarping:
-            input_params["use_doc_preprocessor"] = True
-        else:
-            input_params["use_doc_preprocessor"] = False
+        model_settings = self.get_model_settings(
+            use_doc_orientation_classify, use_doc_unwarping, use_layout_detection
+        )
 
-        if not self.check_input_params_valid(input_params, layout_det_res):
-            yield None
+        if not self.check_model_settings_valid(model_settings, layout_det_res):
+            yield {"error": "the input params for model settings are invalid!"}
 
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            if not isinstance(batch_data[0], str):
+                # TODO: add support for input_path when the input is an ndarray or PDF
+                input_path = f"{img_id}"
+            else:
+                input_path = batch_data[0]
+
             image_array = self.img_reader(batch_data)[0]
-            img_id += 1
 
-            doc_preprocessor_res, doc_preprocessor_image = (
-                self.predict_doc_preprocessor_res(image_array, input_params)
-            )
+            if model_settings["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+            else:
+                doc_preprocessor_res = {"output_img": image_array}
+
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
 
             seal_res_list = []
             seal_region_id = 1
-            if not input_params["use_layout_detection"] and layout_det_res is None:
+            if not model_settings["use_layout_detection"] and layout_det_res is None:
                 layout_det_res = {}
-                seal_ocr_res = next(self.seal_ocr_pipeline(doc_preprocessor_image))
+                seal_ocr_res = next(
+                    self.seal_ocr_pipeline(
+                        doc_preprocessor_image,
+                        text_det_limit_side_len=seal_det_limit_side_len,
+                        text_det_limit_type=seal_det_limit_type,
+                        text_det_thresh=seal_det_thresh,
+                        text_det_box_thresh=seal_det_box_thresh,
+                        text_det_unclip_ratio=seal_det_unclip_ratio,
+                        text_rec_score_thresh=seal_rec_score_thresh,
+                    )
+                )
                 seal_ocr_res["seal_region_id"] = seal_region_id
                 seal_res_list.append(seal_ocr_res)
                 seal_region_id += 1
             else:
-                if input_params["use_layout_detection"]:
+                if model_settings["use_layout_detection"]:
                     layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
 
                 for box_info in layout_det_res["boxes"]:
                     if box_info["label"].lower() in ["seal"]:
-                        crop_img_info = self._crop_by_boxes(image_array, [box_info])
+                        crop_img_info = self._crop_by_boxes(
+                            doc_preprocessor_image, [box_info]
+                        )
                         crop_img_info = crop_img_info[0]
                         seal_ocr_res = next(
-                            self.seal_ocr_pipeline(crop_img_info["img"])
+                            self.seal_ocr_pipeline(
+                                crop_img_info["img"],
+                                text_det_limit_side_len=seal_det_limit_side_len,
+                                text_det_limit_type=seal_det_limit_type,
+                                text_det_thresh=seal_det_thresh,
+                                text_det_box_thresh=seal_det_box_thresh,
+                                text_det_unclip_ratio=seal_det_unclip_ratio,
+                                text_rec_score_thresh=seal_rec_score_thresh,
+                            )
                         )
                         seal_ocr_res["seal_region_id"] = seal_region_id
                         seal_res_list.append(seal_ocr_res)
                         seal_region_id += 1
 
             single_img_res = {
-                "layout_det_res": layout_det_res,
+                "input_path": input_path,
                 "doc_preprocessor_res": doc_preprocessor_res,
+                "layout_det_res": layout_det_res,
                 "seal_res_list": seal_res_list,
-                "input_params": input_params,
-                "img_id": img_id,
+                "model_settings": model_settings,
             }
             yield SealRecognitionResult(single_img_res)
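
`check_model_settings_valid` now rejects an external `layout_det_res` while layout detection is still enabled, so the no-layout path is selected explicitly. A hedged sketch of that path, where the whole preprocessed image is treated as a single seal region:

```python
from paddlex import create_pipeline

# Sketch of the no-layout path: with use_layout_detection=False and no
# external layout_det_res, the whole (preprocessed) image becomes a single
# seal region (seal_region_id 1). The seal_det_* / seal_rec_* overrides
# are forwarded to the inner seal OCR pipeline.
pipeline = create_pipeline(pipeline="seal_recognition")
output = pipeline.predict(
    "./test_samples/seal_text_det.png",
    use_doc_orientation_classify=False,
    use_doc_unwarping=False,
    use_layout_detection=False,
    seal_det_unclip_ratio=0.5,
    seal_rec_score_thresh=0.9,
)
for res in output:
    res.save_to_img("./output")
    res.save_to_json("./output")
```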

+ 59 - 28
paddlex/inference/pipelines_new/seal_recognition/result.py

@@ -12,44 +12,75 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-from pathlib import Path
+from typing import Dict
+import numpy as np
+from ...common.result import BaseCVResult, StrMixin, JsonMixin
 
 
-class SealRecognitionResult(dict):
+class SealRecognitionResult(BaseCVResult):
     """Seal Recognition Result"""
 
-    def __init__(self, data) -> None:
-        """Initializes a new instance of the class with the specified data."""
-        super().__init__(data)
+    def _to_img(self) -> Dict[str, np.ndarray]:
+        res_img_dict = {}
+        layout_det_res = self["layout_det_res"]
+        if len(layout_det_res) > 0:
+            res_img_dict["layout_det_res"] = layout_det_res.img["res"]
 
-    def save_results(self, save_path: str) -> None:
-        """Save the layout parsing results to the specified directory.
+        model_settings = self["model_settings"]
+        if model_settings["use_doc_preprocessor"]:
+            res_img_dict.update(**self["doc_preprocessor_res"].img)
+
+        for sno in range(len(self["seal_res_list"])):
+            seal_res = self["seal_res_list"][sno]
+            seal_region_id = seal_res["seal_region_id"]
+            sub_seal_res_dict = seal_res.img
+            key = f"seal_res_region{seal_region_id}"
+            res_img_dict[key] = sub_seal_res_dict["ocr_res_img"]
+        return res_img_dict
+
+    def _to_str(self, *args, **kwargs) -> Dict[str, str]:
+        """Converts the instance's attributes to a dictionary and then to a string.
 
         Args:
-            save_path (str): The directory path to save the results.
-        """
+            *args: Additional positional arguments passed to the base class method.
+            **kwargs: Additional keyword arguments passed to the base class method.
 
-        if not os.path.isdir(save_path):
-            return
+        Returns:
+            Dict[str, str]: A dictionary with the instance's attributes converted to strings.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
+        if len(self["layout_det_res"]) > 0:
+            data["layout_det_res"] = self["layout_det_res"].str["res"]
+        data["seal_res_list"] = []
+        for sno in range(len(self["seal_res_list"])):
+            seal_res = self["seal_res_list"][sno]
+            data["seal_res_list"].append(seal_res.str["res"])
+        return StrMixin._to_str(data, *args, **kwargs)
 
-        img_id = self["img_id"]
-        layout_det_res = self["layout_det_res"]
-        if len(layout_det_res) > 0:
-            save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
-            layout_det_res.save_to_img(save_img_path)
+    def _to_json(self, *args, **kwargs) -> Dict[str, str]:
+        """
+        Converts the object's data to a JSON dictionary.
 
-        input_params = self["input_params"]
-        if input_params["use_doc_preprocessor"]:
-            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
-            self["doc_preprocessor_res"].save_to_img(save_img_path)
+        Args:
+            *args: Positional arguments passed to the JsonMixin._to_json method.
+            **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
 
+        Returns:
+            Dict[str, str]: A dictionary containing the object's data in JSON format.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
+        if len(self["layout_det_res"]) > 0:
+            data["layout_det_res"] = self["layout_det_res"].json["res"]
+        data["seal_res_list"] = []
         for sno in range(len(self["seal_res_list"])):
             seal_res = self["seal_res_list"][sno]
-            seal_region_id = seal_res["seal_region_id"]
-            save_img_path = (
-                Path(save_path) / f"seal_res_img{img_id}_region{seal_region_id}.jpg"
-            )
-            seal_res.save_to_img(save_img_path)
-
-        return
+            data["seal_res_list"].append(seal_res.json["res"])
+        return JsonMixin._to_json(data, *args, **kwargs)
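With SealRecognitionResult rebuilt on BaseCVResult, callers now go through the mixin entry points instead of the removed save_results(). A sketch of consuming one result, assuming res is an item yielded by the pipeline above:

    # Sketch: `res` is assumed to be one SealRecognitionResult from predict().
    res.print()                   # renders the dict assembled by _to_str()
    res.save_to_img("./output")   # writes each image in the _to_img() dict,
                                  # e.g. layout_det_res and seal_res_region<N>
    res.save_to_json("./output")  # serializes the dict assembled by _to_json()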

+ 128 - 59
paddlex/inference/pipelines_new/table_recognition/pipeline.py

@@ -18,7 +18,6 @@ import numpy as np
 import cv2
 from ..base import BasePipeline
 from ..components import CropByBoxes
-from ..layout_parsing.utils import convert_points_to_boxes
 from .utils import get_neighbor_boxes_idx
 from .table_recognition_post_processing import get_table_recognition_res
 from .result import SingleTableRecognitionResult, TableRecognitionResult
@@ -60,32 +59,38 @@ class TableRecognitionPipeline(BasePipeline):
             device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
         )
 
-        self.use_doc_preprocessor = False
-        if "use_doc_preprocessor" in config:
-            self.use_doc_preprocessor = config["use_doc_preprocessor"]
-
+        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
         if self.use_doc_preprocessor:
-            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            doc_preprocessor_config = config.get("SubPipelines", {}).get(
+                "DocPreprocessor",
+                {
+                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
+                },
+            )
             self.doc_preprocessor_pipeline = self.create_pipeline(
                 doc_preprocessor_config
             )
 
-        self.use_layout_detection = True
-        if "use_layout_detection" in config:
-            self.use_layout_detection = config["use_layout_detection"]
-
+        self.use_layout_detection = config.get("use_layout_detection", True)
         if self.use_layout_detection:
-            layout_det_config = config["SubModules"]["LayoutDetection"]
+            layout_det_config = config.get("SubModules", {}).get(
+                "LayoutDetection",
+                {"model_config_error": "config error for layout_det_model!"},
+            )
             self.layout_det_model = self.create_model(layout_det_config)
 
-        table_structure_config = config["SubModules"]["TableStructureRecognition"]
+        table_structure_config = config.get("SubModules", {}).get(
+            "TableStructureRecognition",
+            {"model_config_error": "config error for table_structure_model!"},
+        )
         self.table_structure_model = self.create_model(table_structure_config)
 
-        self.use_ocr_model = True
-        if "use_ocr_model" in config:
-            self.use_ocr_model = config["use_ocr_model"]
+        self.use_ocr_model = config.get("use_ocr_model", True)
         if self.use_ocr_model:
-            general_ocr_config = config["SubPipelines"]["GeneralOCR"]
+            general_ocr_config = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
+            )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
 
         self._crop_by_boxes = CropByBoxes()
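The constructor hunks above replace hard config[...] indexing with config.get(...) chains, so a missing sub-config produces a placeholder dict carrying a readable *_config_error message instead of an immediate KeyError. A standalone sketch of the pattern, with an illustrative config dict:

    # Standalone sketch of the defensive lookup; the config content is illustrative.
    config = {"SubModules": {}}  # "TableStructureRecognition" key is absent

    table_structure_config = config.get("SubModules", {}).get(
        "TableStructureRecognition",
        {"model_config_error": "config error for table_structure_model!"},
    )
    # create_model() then receives the placeholder dict and can surface the
    # "config error for table_structure_model!" message downstream.
    print(table_structure_config)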
@@ -93,14 +98,53 @@ class TableRecognitionPipeline(BasePipeline):
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
 
-    def check_input_params_valid(
-        self, input_params: Dict, overall_ocr_res: OCRResult, layout_det_res: DetResult
+    def get_model_settings(
+        self,
+        use_doc_orientation_classify: Optional[bool],
+        use_doc_unwarping: Optional[bool],
+        use_layout_detection: Optional[bool],
+        use_ocr_model: Optional[bool],
+    ) -> dict:
+        """
+        Get the model settings based on the provided parameters or default values.
+
+        Args:
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_layout_detection (Optional[bool]): Whether to use layout detection.
+            use_ocr_model (Optional[bool]): Whether to use the OCR model.
+
+        Returns:
+            dict: A dictionary containing the model settings.
+        """
+        if use_doc_orientation_classify is None and use_doc_unwarping is None:
+            use_doc_preprocessor = self.use_doc_preprocessor
+        else:
+            use_doc_preprocessor = True
+
+        if use_layout_detection is None:
+            use_layout_detection = self.use_layout_detection
+
+        if use_ocr_model is None:
+            use_ocr_model = self.use_ocr_model
+
+        return dict(
+            use_doc_preprocessor=use_doc_preprocessor,
+            use_layout_detection=use_layout_detection,
+            use_ocr_model=use_ocr_model,
+        )
+
+    def check_model_settings_valid(
+        self,
+        model_settings: Dict,
+        overall_ocr_res: OCRResult,
+        layout_det_res: DetResult,
     ) -> bool:
         """
         Check if the input parameters are valid based on the initialized models.
 
         Args:
-            input_params (Dict): A dictionary containing input parameters.
+            model_settings (Dict): A dictionary containing the model settings.
             overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
                 The overall OCR result with convert_points_to_boxes information.
             layout_det_res (DetResult): The layout detection result.
@@ -108,13 +152,13 @@ class TableRecognitionPipeline(BasePipeline):
             bool: True if all required models are initialized according to input parameters, False otherwise.
         """
 
-        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
             logging.error(
                 "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
             )
             return False
 
-        if input_params["use_layout_detection"]:
+        if model_settings["use_layout_detection"]:
             if layout_det_res is not None:
                 logging.error(
                     "The layout detection model has already been initialized, please set use_layout_detection=False"
@@ -127,7 +171,7 @@ class TableRecognitionPipeline(BasePipeline):
                 )
                 return False
 
-        if input_params["use_ocr_model"]:
+        if model_settings["use_ocr_model"]:
             if overall_ocr_res is not None:
                 logging.error(
                     "The OCR models have already been initialized, please set use_ocr_model=False"
@@ -139,7 +183,10 @@ class TableRecognitionPipeline(BasePipeline):
                     "Set use_ocr_model, but the models for OCR are not initialized."
                 )
                 return False
-
+        else:
+            if overall_ocr_res is None:
+                logging.error("Set use_ocr_model=False, but no OCR results were found.")
+                return False
         return True
 
     def predict_doc_preprocessor_res(
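The default resolution in the get_model_settings() hunk above is worth spelling out: leaving both document preprocessing flags as None keeps the pipeline's configured default, while passing either flag explicitly — even as False — switches the preprocessor on. A standalone sketch of that rule (class context omitted, default value illustrative):

    # Sketch of the tri-state rule in get_model_settings(), simplified to a function.
    from typing import Optional

    def resolve_use_doc_preprocessor(
        use_doc_orientation_classify: Optional[bool],
        use_doc_unwarping: Optional[bool],
        configured_default: bool = True,   # stands in for self.use_doc_preprocessor
    ) -> bool:
        if use_doc_orientation_classify is None and use_doc_unwarping is None:
            return configured_default      # caller expressed no preference
        return True                        # any explicit flag (even False) enables it

    assert resolve_use_doc_preprocessor(None, None) is True
    assert resolve_use_doc_preprocessor(False, None) is True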
@@ -198,23 +245,30 @@ class TableRecognitionPipeline(BasePipeline):
         neighbor_text = ""
         if flag_find_nei_text:
             match_idx_list = get_neighbor_boxes_idx(
-                overall_ocr_res["dt_boxes"], table_box
+                overall_ocr_res["rec_boxes"], table_box
             )
             if len(match_idx_list) > 0:
                 for idx in match_idx_list:
-                    neighbor_text += overall_ocr_res["rec_text"][idx] + "; "
+                    neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
         single_table_recognition_res["neighbor_text"] = neighbor_text
         return single_table_recognition_res
 
     def predict(
         self,
         input: str | list[str] | np.ndarray | list[np.ndarray],
-        use_layout_detection: bool = True,
-        use_doc_orientation_classify: bool = False,
-        use_doc_unwarping: bool = False,
-        overall_ocr_res: OCRResult = None,
-        layout_det_res: DetResult = None,
-        **kwargs
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_layout_detection: Optional[bool] = None,
+        use_ocr_model: Optional[bool] = None,
+        overall_ocr_res: Optional[OCRResult] = None,
+        layout_det_res: Optional[DetResult] = None,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_rec_score_thresh: Optional[float] = None,
+        **kwargs,
     ) -> TableRecognitionResult:
         """
         This function predicts the table recognition result for the given input.
@@ -234,42 +288,56 @@ class TableRecognitionPipeline(BasePipeline):
             TableRecognitionResult: The predicted table recognition result.
         """
 
-        input_params = {
-            "use_layout_detection": use_layout_detection,
-            "use_doc_preprocessor": self.use_doc_preprocessor,
-            "use_doc_orientation_classify": use_doc_orientation_classify,
-            "use_doc_unwarping": use_doc_unwarping,
-            "use_ocr_model": self.use_ocr_model,
-        }
-
-        if use_doc_orientation_classify or use_doc_unwarping:
-            input_params["use_doc_preprocessor"] = True
-        else:
-            input_params["use_doc_preprocessor"] = False
+        model_settings = self.get_model_settings(
+            use_doc_orientation_classify,
+            use_doc_unwarping,
+            use_layout_detection,
+            use_ocr_model,
+        )
 
-        if not self.check_input_params_valid(
-            input_params, overall_ocr_res, layout_det_res
+        if not self.check_model_settings_valid(
+            model_settings, overall_ocr_res, layout_det_res
         ):
-            yield None
+            yield {"error": "the input params for model settings are invalid!"}
 
         for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            if not isinstance(batch_data[0], str):
+                # TODO: add support for input_path when the input is an ndarray or PDF
+                input_path = f"{img_id}"
+            else:
+                input_path = batch_data[0]
+
             image_array = self.img_reader(batch_data)[0]
-            img_id += 1
 
-            doc_preprocessor_res, doc_preprocessor_image = (
-                self.predict_doc_preprocessor_res(image_array, input_params)
-            )
+            if model_settings["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+            else:
+                doc_preprocessor_res = {"output_img": image_array}
 
-            if self.use_ocr_model:
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+
+            if model_settings["use_ocr_model"]:
                 overall_ocr_res = next(
-                    self.general_ocr_pipeline(doc_preprocessor_image)
+                    self.general_ocr_pipeline(
+                        doc_preprocessor_image,
+                        text_det_limit_side_len=text_det_limit_side_len,
+                        text_det_limit_type=text_det_limit_type,
+                        text_det_thresh=text_det_thresh,
+                        text_det_box_thresh=text_det_box_thresh,
+                        text_det_unclip_ratio=text_det_unclip_ratio,
+                        text_rec_score_thresh=text_rec_score_thresh,
+                    )
                 )
-                dt_boxes = convert_points_to_boxes(overall_ocr_res["dt_polys"])
-                overall_ocr_res["dt_boxes"] = dt_boxes
 
             table_res_list = []
             table_region_id = 1
-            if not input_params["use_layout_detection"] and layout_det_res is None:
+            if not model_settings["use_layout_detection"] and layout_det_res is None:
                 layout_det_res = {}
                 img_height, img_width = doc_preprocessor_image.shape[:2]
                 table_box = [0, 0, img_width - 1, img_height - 1]
@@ -283,8 +351,9 @@ class TableRecognitionPipeline(BasePipeline):
                 table_res_list.append(single_table_rec_res)
                 table_region_id += 1
             else:
-                if input_params["use_layout_detection"]:
+                if model_settings["use_layout_detection"]:
                     layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+
                 for box_info in layout_det_res["boxes"]:
                     if box_info["label"].lower() in ["table"]:
                         crop_img_info = self._crop_by_boxes(image_array, [box_info])
@@ -300,11 +369,11 @@ class TableRecognitionPipeline(BasePipeline):
                         table_region_id += 1
 
             single_img_res = {
-                "layout_det_res": layout_det_res,
+                "input_path": input_path,
                 "doc_preprocessor_res": doc_preprocessor_res,
+                "layout_det_res": layout_det_res,
                 "overall_ocr_res": overall_ocr_res,
                 "table_res_list": table_res_list,
-                "input_params": input_params,
-                "img_id": img_id,
+                "model_settings": model_settings,
             }
             yield TableRecognitionResult(single_img_res)
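The new else branch in check_model_settings_valid() pairs with the reworked predict(): setting use_ocr_model=False is only valid when a precomputed overall_ocr_res is supplied. A hedged usage sketch reusing an external OCR result — pipeline names and sample paths are assumptions in the style of the api_examples:

    # Hedged sketch: reuse an OCR result computed by a separate pipeline.
    from paddlex import create_pipeline

    ocr_pipeline = create_pipeline(pipeline="OCR")
    table_pipeline = create_pipeline(pipeline="table_recognition")

    ocr_res = next(ocr_pipeline.predict("./test_samples/table_demo.png"))

    output = table_pipeline.predict(
        "./test_samples/table_demo.png",
        use_ocr_model=False,        # skip the built-in OCR sub-pipeline...
        overall_ocr_res=ocr_res,    # ...but then this argument is mandatory
    )
    for res in output:
        res.print()
        res.save_to_img("./output")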

+ 133 - 56
paddlex/inference/pipelines_new/table_recognition/result.py

@@ -17,95 +17,172 @@ from typing import Dict
 from pathlib import Path
 import numpy as np
 import cv2
-from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin
+import copy
+from ...common.result import BaseCVResult, HtmlMixin, XlsxMixin, StrMixin, JsonMixin
 
 
 class SingleTableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
-    """table recognition result"""
+    """single table recognition result"""
 
     def __init__(self, data: Dict) -> None:
-        """Initializes the object with given data and sets up mixins for HTML and XLSX processing."""
         super().__init__(data)
-        HtmlMixin.__init__(self)  # Initializes the HTML mixin functionality
-        XlsxMixin.__init__(self)  # Initializes the XLSX mixin functionality
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
 
-    def _to_html(self) -> str:
+    def _to_html(self) -> Dict[str, str]:
         """Converts the prediction to its corresponding HTML representation.
 
         Returns:
-            str: The HTML string representation of the prediction.
+            Dict[str, str]: The str type HTML representation result.
         """
-        return self["pred_html"]
+        return {"pred": self["pred_html"]}
 
-    def _to_xlsx(self) -> str:
+    def _to_xlsx(self) -> Dict[str, str]:
         """Converts the prediction HTML to an XLSX file path.
 
         Returns:
             str: The path to the XLSX file containing the prediction data.
         """
-        return self["pred_html"]
+        return {"pred": self["pred_html"]}
 
-    def _to_img(self) -> np.ndarray:
+    def _to_str(self, *args, **kwargs) -> Dict[str, str]:
+        """Converts the instance's attributes to a dictionary and then to a string.
+
+        Args:
+            *args: Additional positional arguments passed to the base class method.
+            **kwargs: Additional keyword arguments passed to the base class method.
+
+        Returns:
+            Dict[str, str]: A dictionary with the instance's attributes converted to strings.
         """
-        Convert the input image with table OCR predictions to an image with cell boundaries highlighted.
+        data = {}
+        data["cell_box_list"] = self["cell_box_list"]
+        data["pred_html"] = self["pred_html"]
+        data["table_ocr_pred"] = self["table_ocr_pred"]
+        return StrMixin._to_str(data, *args, **kwargs)
+
+    def _to_json(self, *args, **kwargs) -> Dict[str, str]:
+        """
+        Converts the object's data to a JSON dictionary.
+
+        Args:
+            *args: Positional arguments passed to the JsonMixin._to_json method.
+            **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
 
         Returns:
-            np.ndarray: The input image with cell boundaries highlighted in red.
+            Dict[str, str]: A dictionary containing the object's data in JSON format.
         """
-        input_img = self["table_ocr_pred"]["input_img"].copy()
-        cell_box_list = self["cell_box_list"]
-        for box in cell_box_list:
-            x1, y1, x2, y2 = [int(pos) for pos in box]
-            cv2.rectangle(input_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
-        return input_img
+        data = {}
+        data["cell_box_list"] = self["cell_box_list"]
+        data["pred_html"] = self["pred_html"]
+        data["table_ocr_pred"] = self["table_ocr_pred"]
+        return JsonMixin._to_json(data, *args, **kwargs)
 
 
-class TableRecognitionResult(dict):
-    """Layout Parsing Result"""
+class TableRecognitionResult(BaseCVResult, HtmlMixin, XlsxMixin):
+    """Table Recognition Result"""
 
-    def __init__(self, data) -> None:
-        """Initializes a new instance of the class with the specified data."""
+    def __init__(self, data: Dict) -> None:
         super().__init__(data)
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
+
+    def _to_img(self) -> Dict[str, np.ndarray]:
+        res_img_dict = {}
+        layout_det_res = self["layout_det_res"]
+        if len(layout_det_res) > 0:
+            res_img_dict["layout_det_res"] = layout_det_res.img["res"]
+
+        model_settings = self["model_settings"]
+        if model_settings["use_doc_preprocessor"]:
+            res_img_dict.update(**self["doc_preprocessor_res"].img)
+
+        res_img_dict.update(**self["overall_ocr_res"].img)
+
+        if len(self["table_res_list"]) > 0:
+            table_cell_img = copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
+            for sno in range(len(self["table_res_list"])):
+                table_res = self["table_res_list"][sno]
+                cell_box_list = table_res["cell_box_list"]
+                for box in cell_box_list:
+                    x1, y1, x2, y2 = [int(pos) for pos in box]
+                    cv2.rectangle(table_cell_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            res_img_dict["table_cell_img"] = table_cell_img
+        return res_img_dict
 
-    def save_results(self, save_path: str) -> None:
-        """Save the table recognition results to the specified directory.
+    def _to_str(self, *args, **kwargs) -> Dict[str, str]:
+        """Converts the instance's attributes to a dictionary and then to a string.
 
         Args:
-            save_path (str): The directory path to save the results.
+            *args: Additional positional arguments passed to the base class method.
+            **kwargs: Additional keyword arguments passed to the base class method.
+
+        Returns:
+            Dict[str, str]: A dictionary with the instance's attributes converted to strings.
         """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
+        if len(self["layout_det_res"]) > 0:
+            data["layout_det_res"] = self["layout_det_res"].str["res"]
+        data["overall_ocr_res"] = self["overall_ocr_res"].str["res"]
+        data["table_res_list"] = []
+        for sno in range(len(self["table_res_list"])):
+            table_res = self["table_res_list"][sno]
+            data["table_res_list"].append(table_res.str["res"])
+        return StrMixin._to_str(data, *args, **kwargs)
+
+    def _to_json(self, *args, **kwargs) -> Dict[str, str]:
+        """
+        Converts the object's data to a JSON dictionary.
 
-        if not os.path.isdir(save_path):
-            return
+        Args:
+            *args: Positional arguments passed to the JsonMixin._to_json method.
+            **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
 
-        img_id = self["img_id"]
-        layout_det_res = self["layout_det_res"]
-        if len(layout_det_res) > 0:
-            save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
-            layout_det_res.save_to_img(save_img_path)
+        Returns:
+            Dict[str, str]: A dictionary containing the object's data in JSON format.
+        """
+        data = {}
+        data["input_path"] = self["input_path"]
+        data["model_settings"] = self["model_settings"]
+        if self["model_settings"]["use_doc_preprocessor"]:
+            data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
+        if len(self["layout_det_res"]) > 0:
+            data["layout_det_res"] = self["layout_det_res"].json["res"]
+        data["overall_ocr_res"] = self["overall_ocr_res"].json["res"]
+        data["table_res_list"] = []
+        for sno in range(len(self["table_res_list"])):
+            table_res = self["table_res_list"][sno]
+            data["table_res_list"].append(table_res.json["res"])
+        return JsonMixin._to_json(data, *args, **kwargs)
+
+    def _to_html(self) -> Dict[str, str]:
+        """Converts the prediction to its corresponding HTML representation.
 
-        input_params = self["input_params"]
-        if input_params["use_doc_preprocessor"]:
-            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
-            self["doc_preprocessor_res"].save_to_img(save_img_path)
+        Returns:
+            Dict[str, str]: The str type HTML representation result.
+        """
+        res_html_dict = {}
+        for sno in range(len(self["table_res_list"])):
+            table_res = self["table_res_list"][sno]
+            table_region_id = table_res["table_region_id"]
+            key = f"table_{table_region_id}"
+            res_html_dict[key] = table_res.html["pred"]
+        return res_html_dict
 
-        save_img_path = Path(save_path) / f"overall_ocr_result_img{img_id}.jpg"
-        self["overall_ocr_res"].save_to_img(save_img_path)
+    def _to_xlsx(self) -> Dict[str, str]:
+        """Converts the prediction HTML to an XLSX file path.
 
-        for tno in range(len(self["table_res_list"])):
-            table_res = self["table_res_list"][tno]
+        Returns:
+            Dict[str, str]: The str type XLSX representation result.
+        """
+        res_xlsx_dict = {}
+        for sno in range(len(self["table_res_list"])):
+            table_res = self["table_res_list"][sno]
             table_region_id = table_res["table_region_id"]
-            save_img_path = (
-                Path(save_path)
-                / f"table_res_cell_img{img_id}_region{table_region_id}.jpg"
-            )
-            table_res.save_to_img(save_img_path)
-            save_html_path = (
-                Path(save_path) / f"table_res_img{img_id}_region{table_region_id}.html"
-            )
-            table_res.save_to_html(save_html_path)
-            save_xlsx_path = (
-                Path(save_path) / f"table_res_img{img_id}_region{table_region_id}.xlsx"
-            )
-            table_res.save_to_xlsx(save_xlsx_path)
-
-        return
+            key = f"table_{table_region_id}"
+            res_xlsx_dict[key] = table_res.xlsx["pred"]
+        return res_xlsx_dict
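Since the aggregate result now exposes one HTML/XLSX entry per detected table, keyed by table_<region_id>, downstream code can iterate the mixin dictionaries directly. A short sketch, assuming res is a TableRecognitionResult yielded by predict():

    # Sketch: `res` is assumed to be a TableRecognitionResult from predict().
    for key, html_str in res.html.items():   # keys look like "table_1", "table_2", ...
        print(key, html_str[:80])            # preview each table's predicted HTML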

+ 8 - 5
paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing.py

@@ -13,7 +13,8 @@
 # limitations under the License.
 from typing import Any, Dict, Optional
 import numpy as np
-from ..layout_parsing.utils import convert_points_to_boxes, get_sub_regions_ocr_res
+from ..layout_parsing.utils import get_sub_regions_ocr_res
+from ..components import convert_points_to_boxes
 from .result import SingleTableRecognitionResult
 from ..ocr.result import OCRResult
 
@@ -59,6 +60,7 @@ def convert_table_structure_pred_bbox(
     )
     ori_cell_points_list = np.reshape(ori_cell_points_list, (-1, 4, 2))
     cell_box_list = convert_points_to_boxes(ori_cell_points_list)
+
     img_height, img_width = img_shape
     cell_box_list = np.clip(
         cell_box_list, 0, [img_width, img_height, img_width, img_height]
@@ -225,17 +227,18 @@ def get_table_recognition_res(
     table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
 
     crop_start_point = [table_box[0][0], table_box[0][1]]
-    img_shape = overall_ocr_res["doc_preprocessor_image"].shape[0:2]
+    img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
 
     convert_table_structure_pred_bbox(table_structure_pred, crop_start_point, img_shape)
 
     structures = table_structure_pred["structure"]
     cell_box_list = table_structure_pred["cell_box_list"]
-    ocr_dt_boxes = table_ocr_pred["dt_boxes"]
-    ocr_text_res = table_ocr_pred["rec_text"]
+
+    ocr_dt_boxes = table_ocr_pred["rec_boxes"]
+    ocr_texts_res = table_ocr_pred["rec_texts"]
 
     matched_index = match_table_and_ocr(cell_box_list, ocr_dt_boxes)
-    pred_html = get_html_result(matched_index, ocr_text_res, structures)
+    pred_html = get_html_result(matched_index, ocr_texts_res, structures)
 
     single_img_res = {
         "cell_box_list": cell_box_list,