new pdx inference (#1953)

Tingquan Gao committed 1 year ago · commit c50a1e4999
33 changed files with 4026 additions and 2 deletions
  1. + 0 - 2     paddlex/__init__.py
  2. + 17 - 0    paddlex/inference/components/__init__.py
  3. + 229 - 0   paddlex/inference/components/base.py
  4. + 16 - 0    paddlex/inference/components/paddle_predictor/__init__.py
  5. + 170 - 0   paddlex/inference/components/paddle_predictor/base_predictor.py
  6. + 26 - 0    paddlex/inference/components/paddle_predictor/image_predictor.py
  7. + 171 - 0   paddlex/inference/components/paddle_predictor/option.py
  8. + 17 - 0    paddlex/inference/components/task_related/__init__.py
  9. + 84 - 0    paddlex/inference/components/task_related/clas.py
  10. + 543 - 0  paddlex/inference/components/task_related/text_det.py
  11. + 449 - 0  paddlex/inference/components/task_related/text_rec.py
  12. + 15 - 0   paddlex/inference/components/transforms/__init__.py
  13. + 520 - 0  paddlex/inference/components/transforms/image/__init__.py
  14. + 58 - 0   paddlex/inference/components/transforms/image/funcs.py
  15. + 16 - 0   paddlex/inference/pipelines/__init__.py
  16. + 48 - 0   paddlex/inference/pipelines/base.py
  17. + 33 - 0   paddlex/inference/pipelines/image_classification.py
  18. + 51 - 0   paddlex/inference/pipelines/ocr.py
  19. + 17 - 0   paddlex/inference/predictors/__init__.py
  20. + 67 - 0   paddlex/inference/predictors/base.py
  21. + 113 - 0  paddlex/inference/predictors/image_classification.py
  22. + 184 - 0  paddlex/inference/predictors/official_models.py
  23. + 123 - 0  paddlex/inference/predictors/text_detection.py
  24. + 100 - 0  paddlex/inference/predictors/text_recognition.py
  25. + 18 - 0   paddlex/inference/results/__init__.py
  26. + 150 - 0  paddlex/inference/results/ocr.py
  27. + 56 - 0   paddlex/inference/results/text_det.py
  28. + 43 - 0   paddlex/inference/results/text_rec.py
  29. + 127 - 0  paddlex/inference/results/topk.py
  30. + 89 - 0   paddlex/inference/utils/color_map.py
  31. + 17 - 0   paddlex/inference/utils/io/__init__.py
  32. + 233 - 0  paddlex/inference/utils/io/readers.py
  33. + 226 - 0  paddlex/inference/utils/io/writers.py

+ 0 - 2
paddlex/__init__.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from . import version
@@ -64,6 +63,5 @@ def _check_paddle_version():
 
 
 _initialize()
-_check_paddle_version()
 
 __version__ = version.get_pdx_version()

+ 17 - 0
paddlex/inference/components/__init__.py

@@ -0,0 +1,17 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .transforms import *
+from .paddle_predictor import *
+from .task_related import *

+ 229 - 0
paddlex/inference/components/base.py

@@ -0,0 +1,229 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from copy import deepcopy
+from abc import ABC
+from types import GeneratorType
+
+from ...utils import logging
+
+
+class BaseComponent(ABC):
+
+    INPUT_KEYS = None
+    OUTPUT_KEYS = None
+
+    def __init__(self):
+        self.inputs = getattr(self, "DEAULT_INPUTS", {})
+        self.outputs = getattr(self, "DEAULT_OUTPUTS", {})
+
+    def __call__(self, input_list):
+        # use list type for batched data
+        if not isinstance(input_list, list):
+            input_list = [input_list]
+
+        output_list = []
+        for args, input_ in self._check_input(input_list):
+            output = self.apply(**args)
+            if not output:
+                # nothing to post-process; pass the input data through unchanged
+                yield input_list
+                continue
+
+            # output is a generator when apply() uses yield
+            if isinstance(output, GeneratorType):
+                # consume the generator and yield each batch of output data one by one
+                for each_output in output:
+                    reassemble_data = self._check_output(each_output, input_)
+                    yield reassemble_data
+            else:
+                # otherwise collect the reassembled data and yield it as one batch below
+                reassemble_data = self._check_output(output, input_)
+                output_list.extend(reassemble_data)
+
+        # output_list is only populated when apply() is not a generator
+        if len(output_list) > 0:
+            yield output_list
+
+    def _check_input(self, input_list):
+        # check if the value of input data meets the requirements of apply(),
+        # and reassemble the parameters of apply() from input_list
+        def _check_type(input_):
+            if not isinstance(input_, dict):
+                if len(self.inputs) == 1:
+                    key = list(self.inputs.keys())[0]
+                    input_ = {key: input_}
+                else:
+                    raise Exception(
+                        f"The input must be a dict or a list of dict, unless the input of the component only requires one argument, but the component({self.__class__.__name__}) requires {list(self.inputs.keys())}!"
+                    )
+            return input_
+
+        def _check_args_key(args):
+            sig = inspect.signature(self.apply)
+            for param in sig.parameters.values():
+                if param.default == inspect.Parameter.empty and param.name not in args:
+                    raise Exception(
+                        f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
+                    )
+
+        if self.need_batch_input:
+            args = {}
+            for input_ in input_list:
+                input_ = _check_type(input_)
+                for k, v in self.inputs.items():
+                    if v not in input_:
+                        raise Exception(
+                            f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
+                        )
+                    if k not in args:
+                        args[k] = []
+                    args[k].append(input_.get(v))
+                _check_args_key(args)
+            reassemble_input = [(args, input_list)]
+        else:
+            reassemble_input = []
+            for input_ in input_list:
+                input_ = _check_type(input_)
+                args = {}
+                for k, v in self.inputs.items():
+                    if v not in input_:
+                        raise Exception(
+                            f"The value ({v}) is needed by {self.__class__.__name__}. But not found in Data ({input_.keys()})!"
+                        )
+                    args[k] = input_.get(v)
+                _check_args_key(args)
+                reassemble_input.append((args, input_))
+        return reassemble_input
+
+    def _check_output(self, output, ori_data):
+        # check if the value of apply() output data meets the requirements of setting
+        # when the output data is list type, reassemble each of that
+        if isinstance(output, list):
+            if self.need_batch_input:
+                assert isinstance(ori_data, list) and len(ori_data) == len(output)
+                output_list = []
+                for ori_item, output_item in zip(ori_data, output):
+                    data = ori_item.copy() if self.keep_ori else {}
+                    for k, v in self.outputs.items():
+                        if k not in output_item:
+                            raise Exception(
+                                f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
+                            )
+                        data.update({v: output_item[k]})
+                    output_list.append(data)
+                return output_list
+            else:
+                assert isinstance(ori_data, dict)
+                output_list = []
+                for output_item in output:
+                    data = ori_data.copy() if self.keep_ori else {}
+                    for k, v in self.outputs.items():
+                        if k not in output_item:
+                            raise Exception(
+                                f"The value ({k}) is needed by {self.__class__.__name__}. But not found in Data ({output_item.keys()})!"
+                            )
+                        data.update({v: output_item[k]})
+                    output_list.append(data)
+                return output_list
+        else:
+            assert isinstance(ori_data, dict) and isinstance(output, dict)
+            data = ori_data.copy() if self.keep_ori else {}
+            for k, v in self.outputs.items():
+                if k not in output:
+                    raise Exception(
+                        f"The value of key ({k}) is needed add to Data. But not found in output of {self.__class__.__name__}: ({output.keys()})!"
+                    )
+                data.update({v: output[k]})
+        return [data]
+
+    def set_inputs(self, inputs):
+        assert isinstance(inputs, dict)
+        input_keys = deepcopy(self.INPUT_KEYS)
+
+        # e.g., input_keys is None or []
+        if input_keys is None or (
+            isinstance(input_keys, list) and len(input_keys) == 0
+        ):
+            self.inputs = {}
+            if inputs:
+                raise Exception(
+                    f"{self.__class__.__name__} accepts no inputs, but got: {list(inputs.keys())}!"
+                )
+            return
+
+        # e.g., input_keys is 'img'
+        if not isinstance(input_keys, list):
+            input_keys = [[input_keys]]
+        # e.g., input_keys is ['img'] or [['img']]
+        elif len(input_keys) > 0:
+            # e.g., input_keys is ['img']
+            if not isinstance(input_keys[0], list):
+                input_keys = [input_keys]
+
+        ck_pass = False
+        for key_group in input_keys:
+            for key in key_group:
+                if key not in inputs:
+                    break
+            # check pass
+            else:
+                ck_pass = True
+            if ck_pass:
+                break
+        else:
+            raise Exception(
+                f"The input {input_keys} are needed by {self.__class__.__name__}. But only get: {list(inputs.keys())}"
+            )
+        self.inputs = inputs
+
+    def set_outputs(self, outputs):
+        assert isinstance(outputs, dict)
+        output_keys = deepcopy(self.OUTPUT_KEYS)
+        if not isinstance(output_keys, list):
+            output_keys = [output_keys]
+
+        for k in output_keys:
+            if k not in outputs:
+                logging.debug(
+                    f"The output ({k}) of {self.__class__.__name__} would be abandon!"
+                )
+        self.outputs = outputs
+
+    @classmethod
+    def get_input_keys(cls) -> list:
+        return cls.INPUT_KEYS
+
+    @classmethod
+    def get_output_keys(cls) -> list:
+        return cls.OUTPUT_KEYS
+
+    @property
+    def need_batch_input(self):
+        return getattr(self, "ENABLE_BATCH", False)
+
+    @property
+    def keep_ori(self):
+        return getattr(self, "KEEP_INPUT", True)
+
+
+class ComponentsEngine(object):
+    def __init__(self, ops):
+        self.ops = ops
+        self.keys = list(ops.keys())
+
+    def __call__(self, data, i=0):
+        data_gen = self.ops[self.keys[i]](data)
+        if i + 1 < len(self.ops):
+            for data in data_gen:
+                yield from self.__call__(data, i + 1)
+        else:
+            yield from data_gen

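To make the data flow above concrete, here is a minimal sketch of two toy components chained by a ComponentsEngine. The AddOne and Double components are hypothetical and exist only to illustrate the dict-in/dict-out protocol; the import path follows the layout of this commit:

    from paddlex.inference.components.base import BaseComponent, ComponentsEngine

    class AddOne(BaseComponent):
        INPUT_KEYS = ["x"]
        OUTPUT_KEYS = ["y"]
        DEAULT_INPUTS = {"x": "x"}   # apply() argument name -> key in the data dict
        DEAULT_OUTPUTS = {"y": "y"}  # apply() result key -> key in the data dict

        def apply(self, x):
            return {"y": x + 1}

    class Double(BaseComponent):
        INPUT_KEYS = ["y"]
        OUTPUT_KEYS = ["z"]
        DEAULT_INPUTS = {"y": "y"}
        DEAULT_OUTPUTS = {"z": "z"}

        def apply(self, y):
            return {"z": y * 2}

    engine = ComponentsEngine({"add": AddOne(), "double": Double()})
    for batch in engine([{"x": 1}, {"x": 2}]):
        print(batch)  # [{'x': 1, 'y': 2, 'z': 4}, {'x': 2, 'y': 3, 'z': 6}]
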
+ 16 - 0
paddlex/inference/components/paddle_predictor/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .option import PaddlePredictorOption
+from .image_predictor import ImagePredictor

+ 170 - 0
paddlex/inference/components/paddle_predictor/base_predictor.py

@@ -0,0 +1,170 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from abc import abstractmethod
+
+import paddle
+from paddle.inference import Config, create_predictor
+
+from ..base import BaseComponent
+from ....utils import logging
+
+
+class BasePaddlePredictor(BaseComponent):
+    """Predictor based on Paddle Inference"""
+
+    INPUT_KEYS = "imgs"
+    OUTPUT_KEYS = "pred"
+    DEAULT_INPUTS = {"x": "x"}
+    DEAULT_OUTPUTS = {"pred": "pred"}
+    ENABLE_BATCH = True
+
+    def __init__(self, model_dir, model_prefix, option):
+        super().__init__()
+        (
+            self.predictor,
+            self.inference_config,
+            self.input_names,
+            self.input_handlers,
+            self.output_handlers,
+        ) = self._create(model_dir, model_prefix, option)
+
+    def _create(self, model_dir, model_prefix, option):
+        """_create"""
+        use_pir = (
+            hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
+        )
+        model_postfix = ".json" if use_pir else ".pdmodel"
+        model_file = (model_dir / f"{model_prefix}{model_postfix}").as_posix()
+        params_file = (model_dir / f"{model_prefix}.pdiparams").as_posix()
+        config = Config(model_file, params_file)
+
+        if option.device == "gpu":
+            config.enable_use_gpu(200, option.device_id)
+            if paddle.is_compiled_with_rocm():
+                os.environ["FLAGS_conv_workspace_size_limit"] = "2000"
+            elif hasattr(config, "enable_new_ir"):
+                config.enable_new_ir(True)
+        elif option.device == "npu":
+            config.enable_custom_device("npu")
+            os.environ["FLAGS_npu_jit_compile"] = "0"
+            os.environ["FLAGS_use_stride_kernel"] = "0"
+            os.environ["FLAGS_allocator_strategy"] = "auto_growth"
+            os.environ["CUSTOM_DEVICE_BLACK_LIST"] = (
+                "pad3d,pad3d_grad,set_value,set_value_with_tensor"
+            )
+            os.environ["FLAGS_npu_scale_aclnn"] = "True"
+            os.environ["FLAGS_npu_split_aclnn"] = "True"
+        elif option.device == "xpu":
+            os.environ["BKCL_FORCE_SYNC"] = "1"
+            os.environ["BKCL_TIMEOUT"] = "1800"
+            os.environ["FLAGS_use_stride_kernel"] = "0"
+        elif option.device == "mlu":
+            config.enable_custom_device("mlu")
+            os.environ["FLAGS_use_stride_kernel"] = "0"
+        else:
+            assert option.device == "cpu"
+            config.disable_gpu()
+            if "mkldnn" in option.run_mode:
+                try:
+                    config.enable_mkldnn()
+                    config.set_cpu_math_library_num_threads(option.cpu_threads)
+                    if "bf16" in option.run_mode:
+                        config.enable_mkldnn_bfloat16()
+                except Exception as e:
+                    logging.warning(
+                        "MKL-DNN is not available. We will disable MKL-DNN."
+                    )
+
+        precision_map = {
+            "trt_int8": Config.Precision.Int8,
+            "trt_fp32": Config.Precision.Float32,
+            "trt_fp16": Config.Precision.Half,
+        }
+        if option.run_mode in precision_map:
+            config.enable_tensorrt_engine(
+                workspace_size=(1 << 25) * option.batch_size,
+                max_batch_size=option.batch_size,
+                min_subgraph_size=option.min_subgraph_size,
+                precision_mode=precision_map[option.run_mode],
+                trt_use_static=option.trt_use_static,
+                use_calib_mode=option.trt_calib_mode,
+            )
+
+            if option.shape_info_filename is not None:
+                if not os.path.exists(option.shape_info_filename):
+                    config.collect_shape_range_info(option.shape_info_filename)
+                    logging.info(
+                        f"Dynamic shape info is collected into: {option.shape_info_filename}"
+                    )
+                else:
+                    logging.info(
+                        f"A dynamic shape info file ( {option.shape_info_filename} ) already exists. \
+No need to generate again."
+                    )
+                config.enable_tuned_tensorrt_dynamic_shape(
+                    option.shape_info_filename, True
+                )
+
+        # Disable paddle inference logging
+        config.disable_glog_info()
+        for del_p in option.delete_pass:
+            config.delete_pass(del_p)
+        # Enable shared memory
+        config.enable_memory_optim()
+        config.switch_ir_optim(True)
+        # Disable feed, fetch OP, needed by zero_copy_run
+        config.switch_use_feed_fetch_ops(False)
+        predictor = create_predictor(config)
+
+        # Get input and output handlers
+        input_names = predictor.get_input_names()
+        input_handlers = []
+        output_handlers = []
+        for input_name in input_names:
+            input_handler = predictor.get_input_handle(input_name)
+            input_handlers.append(input_handler)
+        output_names = predictor.get_output_names()
+        for output_name in output_names:
+            output_handler = predictor.get_output_handle(output_name)
+            output_handlers.append(output_handler)
+        return predictor, config, input_names, input_handlers, output_handlers
+
+    def get_input_names(self):
+        """get input names"""
+        return self.input_names
+
+    def apply(self, imgs):
+        x = self.to_batch(imgs)
+        for idx in range(len(x)):
+            self.input_handlers[idx].reshape(x[idx].shape)
+            self.input_handlers[idx].copy_from_cpu(x[idx])
+
+        self.predictor.run()
+
+        output = []
+        for out_tensor in self.output_handlers:
+            out_arr = out_tensor.copy_to_cpu()
+            output.append(out_arr)
+
+        return self.format_output(output)
+
+    @abstractmethod
+    def to_batch(self, imgs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def format_output(self, output):
+        raise NotImplementedError

+ 26 - 0
paddlex/inference/components/paddle_predictor/image_predictor.py

@@ -0,0 +1,26 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from .base_predictor import BasePaddlePredictor
+
+
+class ImagePredictor(BasePaddlePredictor):
+
+    def to_batch(self, imgs):
+        return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]
+
+    def format_output(self, output):
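+        # Example (illustrative): for a classifier run with batch size 2 and
+        # 1000 classes, output[0] has shape (2, 1000), so this returns two
+        # dicts, each holding a (1000,)-shaped "pred" array.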
+        return [{"pred": np.array(res)} for res in output[0].tolist()]

+ 171 - 0
paddlex/inference/components/paddle_predictor/option.py

@@ -0,0 +1,171 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import wraps, partial
+
+from ....utils import logging
+
+
+def register(register_map, key):
+    """register the option setting func"""
+
+    def decorator(func):
+        register_map[key] = func
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+class PaddlePredictorOption(object):
+    """Paddle Inference Engine Option"""
+
+    SUPPORT_RUN_MODE = (
+        "paddle",
+        "trt_fp32",
+        "trt_fp16",
+        "trt_int8",
+        "mkldnn",
+        "mkldnn_bf16",
+    )
+    SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu")
+    _REGISTER_MAP = {}
+
+    register2self = partial(register, _REGISTER_MAP)
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self._cfg = {}
+        self._init_option(**kwargs)
+
+    def _init_option(self, **kwargs):
+        for k, v in kwargs.items():
+            if k not in self._REGISTER_MAP:
+                raise Exception(
+                    f"{k} is not supported to set! The supported option is: \
+{list(self._REGISTER_MAP.keys())}"
+                )
+            self._REGISTER_MAP.get(k)(self, v)
+        for k, v in self._get_default_config().items():
+            self._cfg.setdefault(k, v)
+
+    @classmethod
+    def _get_default_config(cls):
+        """get default config"""
+        return {
+            "run_mode": "paddle",
+            "batch_size": 1,
+            "device": "gpu",
+            "device_id": 0,
+            "min_subgraph_size": 3,
+            "shape_info_filename": None,
+            "trt_calib_mode": False,
+            "cpu_threads": 1,
+            "trt_use_static": False,
+            "delete_pass": [],
+        }
+
+    @register2self("run_mode")
+    def set_run_mode(self, run_mode: str):
+        """set run mode"""
+        if run_mode not in self.SUPPORT_RUN_MODE:
+            support_run_mode_str = ", ".join(self.SUPPORT_RUN_MODE)
+            raise ValueError(
+                f"`run_mode` must be {support_run_mode_str}, but received {repr(run_mode)}."
+            )
+        self._cfg["run_mode"] = run_mode
+
+    @register2self("batch_size")
+    def set_batch_size(self, batch_size: int):
+        """set batch size"""
+        if not isinstance(batch_size, int) or batch_size < 1:
+            raise Exception(f"batch_size must be a positive integer, but got {repr(batch_size)}.")
+        self._cfg["batch_size"] = batch_size
+
+    @register2self("device")
+    def set_device(self, device_setting: str):
+        """set device"""
+        parts = device_setting.split(":")
+        if len(parts) == 1:
+            device = parts[0]
+            device_id = 0
+        else:
+            assert len(parts) == 2
+            device = parts[0]
+            device_id = parts[1].split(",")[0]
+            logging.warning(f"The device id has been set to {device_id}.")
+
+        if device.lower() not in self.SUPPORT_DEVICE:
+            support_device_str = ", ".join(self.SUPPORT_DEVICE)
+            raise ValueError(
+                f"`device` must be one of {support_device_str}, but received {repr(device)}."
+            )
+        self._cfg["device"] = device.lower()
+        self._cfg["device_id"] = int(device_id)
+
+    @register2self("min_subgraph_size")
+    def set_min_subgraph_size(self, min_subgraph_size: int):
+        """set min subgraph size"""
+        if not isinstance(min_subgraph_size, int):
+            raise Exception(f"min_subgraph_size must be an integer, but got {repr(min_subgraph_size)}.")
+        self._cfg["min_subgraph_size"] = min_subgraph_size
+
+    @register2self("shape_info_filename")
+    def set_shape_info_filename(self, shape_info_filename: str):
+        """set shape info filename"""
+        self._cfg["shape_info_filename"] = shape_info_filename
+
+    @register2self("trt_calib_mode")
+    def set_trt_calib_mode(self, trt_calib_mode):
+        """set trt calib mode"""
+        self._cfg["trt_calib_mode"] = trt_calib_mode
+
+    @register2self("cpu_threads")
+    def set_cpu_threads(self, cpu_threads):
+        """set cpu threads"""
+        if not isinstance(cpu_threads, int) or cpu_threads < 1:
+            raise Exception(f"cpu_threads must be a positive integer, but got {repr(cpu_threads)}.")
+        self._cfg["cpu_threads"] = cpu_threads
+
+    @register2self("trt_use_static")
+    def set_trt_use_static(self, trt_use_static):
+        """set trt use static"""
+        self._cfg["trt_use_static"] = trt_use_static
+
+    @register2self("delete_pass")
+    def set_delete_pass(self, delete_pass):
+        self._cfg["delete_pass"] = delete_pass
+
+    def get_support_run_mode(self):
+        """get supported run mode"""
+        return self.SUPPORT_RUN_MODE
+
+    def get_support_device(self):
+        """get supported device"""
+        return self.SUPPORT_DEVICE
+
+    def get_device(self):
+        """get device"""
+        return f"{self._cfg['device']}:{self._cfg['device_id']}"
+
+    def __str__(self):
+        return ",  ".join([f"{k}: {v}" for k, v in self._cfg.items()])
+
+    def __getattr__(self, key):
+        if key not in self._cfg:
+            raise Exception(f"The key ({key}) is not found in cfg: \n {self._cfg}")
+        return self._cfg.get(key)

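A short usage sketch of the option object (the argument values are illustrative; each constructor keyword must match one of the setters registered above):

    from paddlex.inference.components.paddle_predictor import PaddlePredictorOption

    option = PaddlePredictorOption(device="gpu:0", run_mode="trt_fp16", batch_size=8)
    print(option.run_mode)      # "trt_fp16", resolved through __getattr__
    print(option.get_device())  # "gpu:0"
    option.set_cpu_threads(4)   # the registered setters can also be called directly
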
+ 17 - 0
paddlex/inference/components/task_related/__init__.py

@@ -0,0 +1,17 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .clas import Topk
+from .text_det import DetResizeForTest, NormalizeImage, DBPostProcess, CropByPolys
+from .text_rec import OCRReisizeNormImg, CTCLabelDecode

+ 84 - 0
paddlex/inference/components/task_related/clas.py

@@ -0,0 +1,84 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ....utils import logging
+from ...results import TopkResult
+from ..base import BaseComponent
+
+
+__all__ = ["Topk", "NormalizeFeatures"]
+
+
+def _parse_class_id_map(class_ids):
+    """parse class id to label map file"""
+    if class_ids is None:
+        return None
+    class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
+    return class_id_map
+
+
+class Topk(BaseComponent):
+    """Topk Transform"""
+
+    INPUT_KEYS = ["pred", "img_path"]
+    OUTPUT_KEYS = ["topk_res"]
+    DEAULT_INPUTS = {"pred": "pred", "img_path": "img_path"}
+    DEAULT_OUTPUTS = {"topk_res": "topk_res"}
+
+    def __init__(self, topk, class_ids=None):
+        super().__init__()
+        assert isinstance(topk, (int,))
+        self.topk = topk
+        self.class_id_map = _parse_class_id_map(class_ids)
+
+    def apply(self, pred, img_path):
+        """apply"""
+        cls_pred = pred
+        class_id_map = self.class_id_map
+
+        index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
+        clas_id_list = []
+        score_list = []
+        label_name_list = []
+        for i in index:
+            clas_id_list.append(i.item())
+            score_list.append(cls_pred[i].item())
+            if class_id_map is not None:
+                label_name_list.append(class_id_map[i.item()])
+        result = {
+            "img_path": img_path,
+            "class_ids": clas_id_list,
+            "scores": np.around(score_list, decimals=5).tolist(),
+        }
+        if label_name_list:
+            result["label_names"] = label_name_list
+
+        return {"topk_res": TopkResult(result)}
+
+
+class NormalizeFeatures(BaseComponent):
+    """Normalize Features Transform"""
+
+    INPUT_KEYS = ["cls_pred"]
+    OUTPUT_KEYS = ["cls_res"]
+    DEAULT_INPUTS = {"cls_res": "cls_res"}
+    DEAULT_OUTPUTS = {"cls_pred": "cls_pred"}
+
+    def apply(self, cls_pred):
+        """apply"""
+        feas_norm = np.sqrt(np.sum(np.square(cls_pred), axis=0, keepdims=True))
+        cls_res = np.divide(cls_pred, feas_norm)
+        return {"cls_res": cls_res}

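A worked example of the Top-k transform (the numbers are made up; TopkResult, defined in paddlex/inference/results/topk.py and not reproduced in this excerpt, wraps the result dict):

    import numpy as np
    from paddlex.inference.components.task_related.clas import Topk

    topk = Topk(topk=2, class_ids=["cat", "dog", "bird"])
    pred = np.array([0.1, 0.7, 0.2])
    out = topk.apply(pred=pred, img_path="demo.jpg")
    # out["topk_res"] wraps {"img_path": "demo.jpg", "class_ids": [1, 2],
    #                        "scores": [0.7, 0.2], "label_names": ["dog", "bird"]}
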
+ 543 - 0
paddlex/inference/components/task_related/text_det.py

@@ -0,0 +1,543 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import sys
+import cv2
+import copy
+import math
+import pyclipper
+import numpy as np
+from PIL import Image
+from shapely.geometry import Polygon
+
+from ...utils.io import ImageReader
+from ....utils import logging
+from ...results import TextDetResult
+from ..base import BaseComponent
+
+
+__all__ = ["DetResizeForTest", "NormalizeImage", "DBPostProcess", "CropByPolys"]
+
+
+class DetResizeForTest(BaseComponent):
+    """DetResizeForTest"""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_shape"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_shape": "img_shape"}
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.resize_type = 0
+        self.keep_ratio = False
+        if "image_shape" in kwargs:
+            self.image_shape = kwargs["image_shape"]
+            self.resize_type = 1
+            if "keep_ratio" in kwargs:
+                self.keep_ratio = kwargs["keep_ratio"]
+        elif "limit_side_len" in kwargs:
+            self.limit_side_len = kwargs["limit_side_len"]
+            self.limit_type = kwargs.get("limit_type", "min")
+        elif "resize_long" in kwargs:
+            self.resize_type = 2
+            self.resize_long = kwargs.get("resize_long", 960)
+        else:
+            self.limit_side_len = 736
+            self.limit_type = "min"
+
+    def apply(self, img):
+        """apply"""
+        src_h, src_w, _ = img.shape
+        if sum([src_h, src_w]) < 64:
+            img = self.image_padding(img)
+
+        if self.resize_type == 0:
+            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+        elif self.resize_type == 2:
+            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
+        else:
+            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
+        return {"img": img, "img_shape": np.array([src_h, src_w, ratio_h, ratio_w])}
+
+    def image_padding(self, im, value=0):
+        """padding image"""
+        h, w, c = im.shape
+        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
+        im_pad[:h, :w, :] = im
+        return im_pad
+
+    def resize_image_type1(self, img):
+        """resize the image"""
+        resize_h, resize_w = self.image_shape
+        ori_h, ori_w = img.shape[:2]  # (h, w, c)
+        if self.keep_ratio is True:
+            resize_w = ori_w * resize_h / ori_h
+            N = math.ceil(resize_w / 32)
+            resize_w = N * 32
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type0(self, img):
+        """
+        resize image to a size multiple of 32 which is required by the network
+        args:
+            img(array): array with shape [h, w, c]
+        return(tuple):
+            img, (ratio_h, ratio_w)
+        """
+        limit_side_len = self.limit_side_len
+        h, w, c = img.shape
+
+        # limit the max side
+        if self.limit_type == "max":
+            if max(h, w) > limit_side_len:
+                if h > w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.0
+        elif self.limit_type == "min":
+            if min(h, w) < limit_side_len:
+                if h < w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.0
+        elif self.limit_type == "resize_long":
+            ratio = float(limit_side_len) / max(h, w)
+        else:
+            raise Exception("not support limit type, image ")
+        resize_h = int(h * ratio)
+        resize_w = int(w * ratio)
+
+        resize_h = max(int(round(resize_h / 32) * 32), 32)
+        resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+        try:
+            if int(resize_w) <= 0 or int(resize_h) <= 0:
+                return None, (None, None)
+            img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        except Exception:
+            logging.error(
+                f"Resize failed: img shape {img.shape}, target ({resize_w}, {resize_h})"
+            )
+            sys.exit(1)
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+        return img, [ratio_h, ratio_w]
+
+    def resize_image_type2(self, img):
+        """resize image size"""
+        h, w, _ = img.shape
+
+        resize_w = w
+        resize_h = h
+
+        if resize_h > resize_w:
+            ratio = float(self.resize_long) / resize_h
+        else:
+            ratio = float(self.resize_long) / resize_w
+
+        resize_h = int(resize_h * ratio)
+        resize_w = int(resize_w * ratio)
+
+        max_stride = 128
+        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
+        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+
+        return img, [ratio_h, ratio_w]
+
+
+class NormalizeImage(BaseComponent):
+    """normalize image such as substract mean, divide std"""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        super().__init__()
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def apply(self, img):
+        """apply"""
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        img = (img.astype("float32") * self.scale - self.mean) / self.std
+        return {"img": img}
+
+
+class DBPostProcess(BaseComponent):
+    """
+    The post process for Differentiable Binarization (DB).
+    """
+
+    INPUT_KEYS = ["pred", "img_shape", "img_path"]
+    OUTPUT_KEYS = ["text_det_res"]
+    DEAULT_INPUTS = {"pred": "pred", "img_shape": "img_shape", "img_path": "img_path"}
+    DEAULT_OUTPUTS = {"text_det_res": "text_det_res"}
+
+    def __init__(
+        self,
+        thresh=0.3,
+        box_thresh=0.7,
+        max_candidates=1000,
+        unclip_ratio=2.0,
+        use_dilation=False,
+        score_mode="fast",
+        box_type="quad",
+        **kwargs
+    ):
+        super().__init__()
+        self.thresh = thresh
+        self.box_thresh = box_thresh
+        self.max_candidates = max_candidates
+        self.unclip_ratio = unclip_ratio
+        self.min_size = 3
+        self.score_mode = score_mode
+        self.box_type = box_type
+        assert score_mode in [
+            "slow",
+            "fast",
+        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
+
+        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
+
+    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+        """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
+
+        bitmap = _bitmap
+        height, width = bitmap.shape
+
+        boxes = []
+        scores = []
+
+        contours, _ = cv2.findContours(
+            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+        )
+
+        for contour in contours[: self.max_candidates]:
+            epsilon = 0.002 * cv2.arcLength(contour, True)
+            approx = cv2.approxPolyDP(contour, epsilon, True)
+            points = approx.reshape((-1, 2))
+            if points.shape[0] < 4:
+                continue
+
+            score = self.box_score_fast(pred, points.reshape(-1, 2))
+            if self.box_thresh > score:
+                continue
+
+            if points.shape[0] > 2:
+                box = self.unclip(points, self.unclip_ratio)
+                if len(box) > 1:
+                    continue
+            else:
+                continue
+            box = box.reshape(-1, 2)
+
+            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+            if sside < self.min_size + 2:
+                continue
+
+            box = np.array(box)
+            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
+            box[:, 1] = np.clip(
+                np.round(box[:, 1] / height * dest_height), 0, dest_height
+            )
+            boxes.append(box.tolist())
+            scores.append(score)
+        return boxes, scores
+
+    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+        """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
+
+        bitmap = _bitmap
+        height, width = bitmap.shape
+
+        outs = cv2.findContours(
+            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
+        )
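+        # OpenCV 3.x returns (image, contours, hierarchy) while OpenCV 4.x
+        # returns (contours, hierarchy), hence the branch on len(outs) below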
+        if len(outs) == 3:
+            img, contours, _ = outs[0], outs[1], outs[2]
+        elif len(outs) == 2:
+            contours, _ = outs[0], outs[1]
+
+        num_contours = min(len(contours), self.max_candidates)
+
+        boxes = []
+        scores = []
+        for index in range(num_contours):
+            contour = contours[index]
+            points, sside = self.get_mini_boxes(contour)
+            if sside < self.min_size:
+                continue
+            points = np.array(points)
+            if self.score_mode == "fast":
+                score = self.box_score_fast(pred, points.reshape(-1, 2))
+            else:
+                score = self.box_score_slow(pred, contour)
+            if self.box_thresh > score:
+                continue
+
+            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
+            box, sside = self.get_mini_boxes(box)
+            if sside < self.min_size + 2:
+                continue
+            box = np.array(box)
+
+            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
+            box[:, 1] = np.clip(
+                np.round(box[:, 1] / height * dest_height), 0, dest_height
+            )
+            boxes.append(box.astype(np.int16))
+            scores.append(score)
+        return np.array(boxes, dtype=np.int16), scores
+
+    def unclip(self, box, unclip_ratio):
+        """unclip"""
+        poly = Polygon(box)
+        distance = poly.area * unclip_ratio / poly.length
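+        # The offset distance A * r / L is the box-expansion rule from the DB
+        # paper: e.g. a 100x20 box (area 2000, perimeter 240) with
+        # unclip_ratio 2.0 is grown outward by 2000 * 2 / 240 ≈ 16.7 px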
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        expanded = np.array(offset.Execute(distance))
+        return expanded
+
+    def get_mini_boxes(self, contour):
+        """get mini boxes"""
+        bounding_box = cv2.minAreaRect(contour)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_1 = 0
+            index_4 = 1
+        else:
+            index_1 = 1
+            index_4 = 0
+        if points[3][1] > points[2][1]:
+            index_2 = 2
+            index_3 = 3
+        else:
+            index_2 = 3
+            index_3 = 2
+
+        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
+        return box, min(bounding_box[1])
+
+    def box_score_fast(self, bitmap, _box):
+        """box_score_fast: use bbox mean score as the mean score"""
+        h, w = bitmap.shape[:2]
+        box = _box.copy()
+        xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
+        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
+        ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
+        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+        box[:, 0] = box[:, 0] - xmin
+        box[:, 1] = box[:, 1] - ymin
+        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+    def box_score_slow(self, bitmap, contour):
+        """box_score_slow: use polyon mean score as the mean score"""
+        h, w = bitmap.shape[:2]
+        contour = contour.copy()
+        contour = np.reshape(contour, (-1, 2))
+
+        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
+        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
+        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
+        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
+
+        contour[:, 0] = contour[:, 0] - xmin
+        contour[:, 1] = contour[:, 1] - ymin
+
+        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
+        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
+
+    def apply(self, pred, img_shape, img_path):
+        """apply"""
+        pred = pred[0, :, :]
+        segmentation = pred > self.thresh
+
+        src_h, src_w, ratio_h, ratio_w = img_shape
+        if self.dilation_kernel is not None:
+            mask = cv2.dilate(
+                np.array(segmentation).astype(np.uint8),
+                self.dilation_kernel,
+            )
+        else:
+            mask = segmentation
+        if self.box_type == "poly":
+            boxes, scores = self.polygons_from_bitmap(pred, mask, src_w, src_h)
+        elif self.box_type == "quad":
+            boxes, scores = self.boxes_from_bitmap(pred, mask, src_w, src_h)
+        else:
+            raise ValueError("box_type can only be one of ['quad', 'poly']")
+
+        text_det_res = TextDetResult(
+            {"img_path": img_path, "dt_polys": boxes, "dt_scores": scores}
+        )
+        return {"text_det_res": text_det_res}
+
+
+class CropByPolys(BaseComponent):
+    """Crop Image by Polys"""
+
+    INPUT_KEYS = ["img_path", "dt_polys"]
+    OUTPUT_KEYS = ["img"]
+    DEAULT_INPUTS = {"img_path": "img_path", "dt_polys": "dt_polys"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, det_box_type="quad"):
+        super().__init__()
+        self.det_box_type = det_box_type
+        self._reader = ImageReader(backend="opencv")
+
+    def apply(self, img_path, dt_polys):
+        """apply"""
+        img = self._reader.read(img_path)
+        dt_boxes = np.array(dt_polys)
+        # TODO
+        # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
+        output_list = []
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.det_box_type == "quad":
+                img_crop = self.get_rotate_crop_image(img, tmp_box)
+            else:
+                img_crop = self.get_minarea_rect_crop(img, tmp_box)
+            output_list.append(
+                {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
+            )
+        return output_list
+
+    def sorted_boxes(self, dt_boxes):
+        """
+        Sort text boxes in order from top to bottom, left to right
+        args:
+            dt_boxes(array):detected text boxes with shape [4, 2]
+        return:
+            sorted boxes(array) with shape [4, 2]
+        """
+        dt_boxes = np.array(dt_boxes)
+        num_boxes = dt_boxes.shape[0]
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+                    _boxes[j + 1][0][0] < _boxes[j][0][0]
+                ):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return _boxes
+
+    def get_minarea_rect_crop(self, img, points):
+        """get_minarea_rect_crop"""
+        bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img
+
+    def get_rotate_crop_image(self, img, points):
+        """
+        img_height, img_width = img.shape[0:2]
+        left = int(np.min(points[:, 0]))
+        right = int(np.max(points[:, 0]))
+        top = int(np.min(points[:, 1]))
+        bottom = int(np.max(points[:, 1]))
+        img_crop = img[top:bottom, left:right, :].copy()
+        points[:, 0] = points[:, 0] - left
+        points[:, 1] = points[:, 1] - top
+        """
+        assert len(points) == 4, "shape of points must be 4*2"
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3]),
+            )
+        )
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2]),
+            )
+        )
+        pts_std = np.float32(
+            [
+                [0, 0],
+                [img_crop_width, 0],
+                [img_crop_width, img_crop_height],
+                [0, img_crop_height],
+            ]
+        )
+        # getPerspectiveTransform requires float32 points; detected boxes may be int16
+        M = cv2.getPerspectiveTransform(np.float32(points), pts_std)
+        dst_img = cv2.warpPerspective(
+            img,
+            M,
+            (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC,
+        )
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_height * 1.0 / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        return dst_img

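A minimal sketch of the detection pre-processing above, run on random data (the predictor call that would sit between resizing and DBPostProcess is omitted, and the constructor arguments are illustrative):

    import numpy as np
    from paddlex.inference.components.task_related.text_det import (
        DetResizeForTest,
        NormalizeImage,
    )

    img = np.random.randint(0, 255, (405, 697, 3), dtype=np.uint8)
    resized = DetResizeForTest(limit_side_len=960, limit_type="max").apply(img)
    # both sides already fit within 960, so they are only rounded to multiples
    # of 32: (405, 697) -> (416, 704); "img_shape" keeps [src_h, src_w, ratio_h, ratio_w]
    normed = NormalizeImage(order="hwc").apply(resized["img"])
    print(normed["img"].shape)  # (416, 704, 3), float32
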
+ 449 - 0
paddlex/inference/components/task_related/text_rec.py

@@ -0,0 +1,449 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+
+import re
+import numpy as np
+from PIL import Image
+import cv2
+import math
+import paddle
+import json
+import tempfile
+from tokenizers import Tokenizer as TokenizerFast
+
+from ....utils import logging
+from ...results import TextRecResult
+from ..base import BaseComponent
+
+__all__ = [
+    "OCRReisizeNormImg",
+    # "LaTeXOCRReisizeNormImg",
+    "CTCLabelDecode",
+    # "LaTeXOCRDecode",
+]
+
+
+class OCRReisizeNormImg(BaseComponent):
+    """for ocr image resize and normalization"""
+
+    INPUT_KEYS = ["img", "img_size"]
+    OUTPUT_KEYS = ["img"]
+    DEAULT_INPUTS = {"img": "img", "img_size": "img_size"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, rec_image_shape=[3, 48, 320]):
+        super().__init__()
+        self.rec_image_shape = rec_image_shape
+
+    def resize_norm_img(self, img, max_wh_ratio):
+        """resize and normalize the img"""
+        imgC, imgH, imgW = self.rec_image_shape
+        assert imgC == img.shape[2]
+        imgW = int((imgH * max_wh_ratio))
+
+        h, w = img.shape[:2]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
+        resized_image = resized_image.astype("float32")
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+        padding_im[:, :, 0:resized_w] = resized_image
+        return padding_im
+
+    def apply(self, img, img_size):
+        """apply"""
+        imgC, imgH, imgW = self.rec_image_shape
+        max_wh_ratio = imgW / imgH
+        w, h = img_size[:2]
+        wh_ratio = w * 1.0 / h
+        max_wh_ratio = max(max_wh_ratio, wh_ratio)
+        img = self.resize_norm_img(img, max_wh_ratio)
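+        # Example (illustrative): with rec_image_shape [3, 48, 320], a 100x400
+        # crop has wh_ratio 4.0 < 320 / 48, so the target width stays 320; the
+        # crop is resized to 48x192 and zero-padded on the right to 48x320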
+        return {"img": img}
+
+
+# class LaTeXOCRReisizeNormImg(BaseComponent):
+#     """for ocr image resize and normalization"""
+
+#     def __init__(self, rec_image_shape=[3, 48, 320]):
+#         super().__init__()
+#         self.rec_image_shape = rec_image_shape
+
+#     def pad_(self, img, divable=32):
+#         threshold = 128
+#         data = np.array(img.convert("LA"))
+#         if data[..., -1].var() == 0:
+#             data = (data[..., 0]).astype(np.uint8)
+#         else:
+#             data = (255 - data[..., -1]).astype(np.uint8)
+#         data = (data - data.min()) / (data.max() - data.min()) * 255
+#         if data.mean() > threshold:
+#             # To invert the text to white
+#             gray = 255 * (data < threshold).astype(np.uint8)
+#         else:
+#             gray = 255 * (data > threshold).astype(np.uint8)
+#             data = 255 - data
+
+#         coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+#         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+#         rect = data[b : b + h, a : a + w]
+#         im = Image.fromarray(rect).convert("L")
+#         dims = []
+#         for x in [w, h]:
+#             div, mod = divmod(x, divable)
+#             dims.append(divable * (div + (1 if mod > 0 else 0)))
+#         padded = Image.new("L", dims, 255)
+#         padded.paste(im, (0, 0, im.size[0], im.size[1]))
+#         return padded
+
+#     def minmax_size_(
+#         self,
+#         img,
+#         max_dimensions,
+#         min_dimensions,
+#     ):
+#         if max_dimensions is not None:
+#             ratios = [a / b for a, b in zip(img.size, max_dimensions)]
+#             if any([r > 1 for r in ratios]):
+#                 size = np.array(img.size) // max(ratios)
+#                 img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
+#         if min_dimensions is not None:
+#             # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
+#             padded_size = [
+#                 max(img_dim, min_dim)
+#                 for img_dim, min_dim in zip(img.size, min_dimensions)
+#             ]
+#             if padded_size != list(img.size):  # assert hypothesis
+#                 padded_im = Image.new("L", padded_size, 255)
+#                 padded_im.paste(img, img.getbbox())
+#                 img = padded_im
+#         return img
+
+#     def norm_img_latexocr(self, img):
+#         # CAN only predict gray scale image
+#         shape = (1, 1, 3)
+#         mean = [0.7931, 0.7931, 0.7931]
+#         std = [0.1738, 0.1738, 0.1738]
+#         scale = np.float32(1.0 / 255.0)
+#         min_dimensions = [32, 32]
+#         max_dimensions = [672, 192]
+#         mean = np.array(mean).reshape(shape).astype("float32")
+#         std = np.array(std).reshape(shape).astype("float32")
+
+#         im_h, im_w = img.shape[:2]
+#         if (
+#             min_dimensions[0] <= im_w <= max_dimensions[0]
+#             and min_dimensions[1] <= im_h <= max_dimensions[1]
+#         ):
+#             pass
+#         else:
+#             img = Image.fromarray(np.uint8(img))
+#             img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
+#             img = np.array(img)
+#             im_h, im_w = img.shape[:2]
+#             img = np.dstack([img, img, img])
+#         img = (img.astype("float32") * scale - mean) / std
+#         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+#         divide_h = math.ceil(im_h / 16) * 16
+#         divide_w = math.ceil(im_w / 16) * 16
+#         img = np.pad(
+#             img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
+#         )
+#         img = img[:, :, np.newaxis].transpose(2, 0, 1)
+#         img = img.astype("float32")
+#         return img
+
+#     def apply(self, data):
+#         """apply"""
+#         data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
+#         return data
+
+#     @classmethod
+#     def get_input_keys(cls):
+#         """get input keys"""
+#         return [K.IMAGE, K.ORI_IM_SIZE]
+
+#     @classmethod
+#     def get_output_keys(cls):
+#         """get output keys"""
+#         return [K.IMAGE]
+
+
+class BaseRecLabelDecode(BaseComponent):
+    """Convert between text-label and text-index"""
+
+    INPUT_KEYS = ["pred", "img_path"]
+    OUTPUT_KEYS = ["text_rec_res"]
+    DEAULT_INPUTS = {"pred": "pred", "img_path": "img_path"}
+    DEAULT_OUTPUTS = {"text_rec_res": "text_rec_res"}
+
+    ENABLE_BATCH = True
+
+    def __init__(self, character_str=None, use_space_char=True):
+        super().__init__()
+        self.reverse = False
+        character_list = (
+            list(character_str)
+            if character_str is not None
+            else list("0123456789abcdefghijklmnopqrstuvwxyz")
+        )
+        if use_space_char:
+            character_list.append(" ")
+
+        character_list = self.add_special_char(character_list)
+        self.dict = {}
+        for i, char in enumerate(character_list):
+            self.dict[char] = i
+        self.character = character_list
+
+    def pred_reverse(self, pred):
+        """pred_reverse"""
+        pred_re = []
+        c_current = ""
+        for c in pred:
+            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
+                if c_current != "":
+                    pred_re.append(c_current)
+                pred_re.append(c)
+                c_current = ""
+            else:
+                c_current += c
+        if c_current != "":
+            pred_re.append(c_current)
+
+        return "".join(pred_re[::-1])
+
+    def add_special_char(self, character_list):
+        """add_special_char"""
+        return character_list
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """convert text-index into text-label."""
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+
+            char_list = [
+                self.character[text_id] for text_id in text_index[batch_idx][selection]
+            ]
+            if text_prob is not None:
+                conf_list = text_prob[batch_idx][selection]
+            else:
+                conf_list = [1] * len(selection)
+            if len(conf_list) == 0:
+                conf_list = [0]
+
+            text = "".join(char_list)
+
+            if self.reverse:  # for arabic rec
+                text = self.pred_reverse(text)
+
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    def get_ignored_tokens(self):
+        """get_ignored_tokens"""
+        return [0]  # for ctc blank
+
+    def apply(self, pred, img_path):
+        """apply"""
+        if isinstance(pred, (tuple, list)):
+            pred = pred[-1]
+        preds = np.array(pred)
+        preds_idx = preds.argmax(axis=2)
+        preds_prob = preds.max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        return [
+            {
+                "text_rec_res": TextRecResult(
+                    {"img_path": path, "rec_text": t[0], "rec_score": t[1]}
+                )
+            }
+            for path, t in zip(img_path, text)
+        ]
+
+
+class CTCLabelDecode(BaseRecLabelDecode):
+    """Convert between text-label and text-index"""
+
+    def __init__(self, character_list=None, use_space_char=True):
+        super().__init__(character_list, use_space_char=use_space_char)
+
+    def apply(self, pred, img_path):
+        """apply"""
+        preds = np.array(pred)
+        preds_idx = preds.argmax(axis=2)
+        preds_prob = preds.max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        return [
+            {
+                "text_rec_res": TextRecResult(
+                    {"img_path": path, "rec_text": t[0], "rec_score": t[1]}
+                )
+            }
+            for path, t in zip(img_path, text)
+        ]
+
+    def add_special_char(self, character_list):
+        """add_special_char"""
+        character_list = ["blank"] + character_list
+        return character_list
+
+
+# class LaTeXOCRDecode(object):
+#     """Convert between latex-symbol and symbol-index"""
+
+#     def __init__(self, post_process_cfg=None, **kwargs):
+#         assert post_process_cfg["name"] == "LaTeXOCRDecode"
+
+#         super(LaTeXOCRDecode, self).__init__()
+#         character_list = post_process_cfg["character_dict"]
+#         temp_path = tempfile.gettempdir()
+#         rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
+#         try:
+#             with open(rec_char_dict_path, "w") as f:
+#                 json.dump(character_list, f)
+#         except Exception as e:
+#             print(f"创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}")
+#         self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
+
+#     def post_process(self, s):
+#         text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+#         letter = "[a-zA-Z]"
+#         noletter = r"[\W_^\d]"
+#         names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
+#         s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+#         news = s
+#         while True:
+#             s = news
+#             news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+#             news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+#             news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
+#             if news == s:
+#                 break
+#         return s
+
+#     def decode(self, tokens):
+#         if len(tokens.shape) == 1:
+#             tokens = tokens[None, :]
+
+#         dec = [self.tokenizer.decode(tok) for tok in tokens]
+#         dec_str_list = [
+#             "".join(detok.split(" "))
+#             .replace("Ġ", " ")
+#             .replace("[EOS]", "")
+#             .replace("[BOS]", "")
+#             .replace("[PAD]", "")
+#             .strip()
+#             for detok in dec
+#         ]
+#         return [str(self.post_process(dec_str)) for dec_str in dec_str_list]
+
+#     def __call__(self, data):
+#         preds = data[K.REC_PROBS]
+#         text = self.decode(preds)
+#         data[K.REC_TEXT] = text[0]
+#         return data
+
+
+# class SaveTextRecResults(BaseComponent):
+#     """SaveTextRecResults"""
+
+#     _TEXT_REC_RES_SUFFIX = "_text_rec"
+#     _FILE_EXT = ".txt"
+
+#     def __init__(self, save_dir):
+#         super().__init__()
+#         self.save_dir = save_dir
+#         # We use python backend to save text object
+#         self._writer = TextWriter(backend="python")
+
+#     def apply(self, data):
+#         """apply"""
+#         ori_path = data[K.IM_PATH]
+#         file_name = os.path.basename(ori_path)
+#         file_name = self._replace_ext(file_name, self._FILE_EXT)
+#         text_rec_res_save_path = os.path.join(self.save_dir, file_name)
+#         rec_res = ""
+#         for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]):
+#             line = text + "\t" + str(score) + "\n"
+#             rec_res += line
+#         text_rec_res_save_path = self._add_suffix(
+#             text_rec_res_save_path, self._TEXT_REC_RES_SUFFIX
+#         )
+#         self._write_txt(text_rec_res_save_path, rec_res)
+#         return data
+
+#     @classmethod
+#     def get_input_keys(cls):
+#         """get_input_keys"""
+#         return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]
+
+#     @classmethod
+#     def get_output_keys(cls):
+#         """get_output_keys"""
+#         return []
+
+#     def _write_txt(self, path, txt_str):
+#         """_write_txt"""
+#         if os.path.exists(path):
+#             logging.warning(f"{path} already exists. Overwriting it.")
+#         self._writer.write(path, txt_str)
+
+#     @staticmethod
+#     def _add_suffix(path, suffix):
+#         """_add_suffix"""
+#         stem, ext = os.path.splitext(path)
+#         return stem + suffix + ext
+
+#     @staticmethod
+#     def _replace_ext(path, new_ext):
+#         """_replace_ext"""
+#         stem, _ = os.path.splitext(path)
+#         return stem + new_ext
+
+
+# class PrintResult(BaseComponent):
+#     """Print Result Transform"""
+
+#     def apply(self, data):
+#         """apply"""
+#         logging.info("The prediction result is:")
+#         logging.info(data[K.REC_TEXT])
+#         return data
+
+#     @classmethod
+#     def get_input_keys(cls):
+#         """get input keys"""
+#         return [K.REC_TEXT]
+
+#     @classmethod
+#     def get_output_keys(cls):
+#         """get output keys"""
+#         return []

+ 15 - 0
paddlex/inference/components/transforms/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .image import *

+ 520 - 0
paddlex/inference/components/transforms/image/__init__.py

@@ -0,0 +1,520 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+from pathlib import Path
+
+import numpy as np
+import cv2
+
+from .....utils.download import download
+from .....utils.cache import CACHE_DIR
+from ....utils.io import ImageReader, ImageWriter
+from ...base import BaseComponent
+from . import funcs as F
+
+__all__ = [
+    "ReadImage",
+    "Flip",
+    "Crop",
+    "Resize",
+    "ResizeByLong",
+    "ResizeByShort",
+    "Pad",
+    "Normalize",
+    "ToCHWImage",
+]
+
+
+def _check_image_size(input_):
+    """check image size"""
+    if not (
+        isinstance(input_, (list, tuple))
+        and len(input_) == 2
+        and isinstance(input_[0], int)
+        and isinstance(input_[1], int)
+    ):
+        raise TypeError(f"{input_} cannot represent a valid image size.")
+
+
+class ReadImage(BaseComponent):
+    """Load image from the file."""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_path": "img_path", "img_size": "img_size"}
+
+    _FLAGS_DICT = {
+        "BGR": cv2.IMREAD_COLOR,
+        "RGB": cv2.IMREAD_COLOR,
+        "GRAY": cv2.IMREAD_GRAYSCALE,
+    }
+    SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
+
+    def __init__(self, batch_size=1, format="BGR"):
+        """
+        Initialize the instance.
+
+        Args:
+            batch_size (int, optional): Number of images to yield per batch.
+                Default: 1.
+            format (str, optional): Target color format to convert the image to.
+                Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
+        """
+        super().__init__()
+        self.batch_size = batch_size
+        self.format = format
+        flags = self._FLAGS_DICT[self.format]
+        self._reader = ImageReader(backend="opencv", flags=flags)
+        self._writer = ImageWriter(backend="opencv")
+
+    def apply(self, img):
+        """apply"""
+        if not isinstance(img, str):
+            img_path = (Path(CACHE_DIR) / "predict_input" / "tmp_img.jpg").as_posix()
+            self._writer.write(img_path, img)
+            yield [
+                {
+                    "img_path": img_path,
+                    "img": img,
+                    "img_size": [img.shape[1], img.shape[0]],
+                }
+            ]
+        else:
+            img_path = img
+            # XXX: auto download for url
+            img_path = self._download_from_url(img_path)
+            image_list = self._get_image_list(img_path)
+            batch = []
+            for img_path in image_list:
+                img = self._read_img(img_path)
+                batch.append(img)
+                if len(batch) >= self.batch_size:
+                    yield batch
+                    batch = []
+            if len(batch) > 0:
+                yield batch
+
+    def _read_img(self, img_path):
+        blob = self._reader.read(img_path)
+        if blob is None:
+            raise Exception("Image read Error")
+
+        if self.format == "RGB":
+            if blob.ndim != 3:
+                raise RuntimeError("Array is not 3-dimensional.")
+            # BGR to RGB
+            blob = blob[..., ::-1]
+        return {
+            "img_path": img_path,
+            "img": blob,
+            "img_size": [blob.shape[1], blob.shape[0]],
+        }
+
+    def _download_from_url(self, in_path):
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+
+    def _get_image_list(self, img_file):
+        imgs_lists = []
+        if img_file is None or not os.path.exists(img_file):
+            raise Exception("not found any img file in {}".format(img_file))
+
+        if os.path.isfile(img_file) and img_file.split(".")[-1] in self.SUFFIX:
+            imgs_lists.append(img_file)
+        elif os.path.isdir(img_file):
+            for root, dirs, files in os.walk(img_file):
+                for single_file in files:
+                    if single_file.split(".")[-1] in self.SUFFIX:
+                        imgs_lists.append(os.path.join(root, single_file))
+        if len(imgs_lists) == 0:
+            raise Exception("not found any img file in {}".format(img_file))
+        imgs_lists = sorted(imgs_lists)
+        return imgs_lists
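A hedged usage sketch for ReadImage, assuming a local directory of JPEG/PNG files (apply() also accepts a single file path, an http(s) URL, or an in-memory array):

    reader = ReadImage(batch_size=2, format="RGB")
    for batch in reader.apply("path/to/images"):  # apply() is a generator of batches
        for item in batch:
            print(item["img_path"], item["img_size"])  # img_size is [width, height]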
+
+
+class GetImageInfo(BaseComponent):
+    """Get Image Info"""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img_size"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img_size": "img_size"}
+
+    def __init__(self):
+        super().__init__()
+
+    def apply(self, img):
+        """apply"""
+        return {"img_size": [img.shape[1], img.shape[0]]}
+
+
+class Flip(BaseComponent):
+    """Flip the image vertically or horizontally."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, mode="H"):
+        """
+        Initialize the instance.
+
+        Args:
+            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
+                flipping. Default: 'H'.
+        """
+        super().__init__()
+        if mode not in ("H", "V"):
+            raise ValueError("`mode` should be 'H' or 'V'.")
+        self.mode = mode
+
+    def apply(self, img):
+        """apply"""
+        if self.mode == "H":
+            img = F.flip_h(img)
+        elif self.mode == "V":
+            img = F.flip_v(img)
+        return {"img": img}
+
+
+class Crop(BaseComponent):
+    """Crop region from the image."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, crop_size, mode="C"):
+        """
+        Initialize the instance.
+
+        Args:
+            crop_size (list|tuple|int): Width and height of the region to crop.
+            mode (str, optional): 'C' for cropping the center part and 'TL' for
+                cropping the top left part. Default: 'C'.
+        """
+        super().__init__()
+        if isinstance(crop_size, int):
+            crop_size = [crop_size, crop_size]
+        _check_image_size(crop_size)
+
+        self.crop_size = crop_size
+
+        if mode not in ("C", "TL"):
+            raise ValueError("Unsupported interpolation method")
+        self.mode = mode
+
+    def apply(self, img):
+        """apply"""
+        h, w = img.shape[:2]
+        cw, ch = self.crop_size
+        if w < cw or h < ch:
+            raise ValueError(
+                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+            )
+        if self.mode == "C":
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+        elif self.mode == "TL":
+            x1, y1 = 0, 0
+        x2 = min(w, x1 + cw)
+        y2 = min(h, y1 + ch)
+        img = F.slice(img, coords=(x1, y1, x2, y2))
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class _BaseResize(BaseComponent):
+    _INTERP_DICT = {
+        "NEAREST": cv2.INTER_NEAREST,
+        "LINEAR": cv2.INTER_LINEAR,
+        "CUBIC": cv2.INTER_CUBIC,
+        "AREA": cv2.INTER_AREA,
+        "LANCZOS4": cv2.INTER_LANCZOS4,
+    }
+
+    def __init__(self, size_divisor, interp):
+        super().__init__()
+
+        if size_divisor is not None:
+            assert isinstance(
+                size_divisor, int
+            ), "`size_divisor` should be None or int."
+        self.size_divisor = size_divisor
+
+        try:
+            interp = self._INTERP_DICT[interp]
+        except KeyError:
+            raise ValueError(
+                "`interp` should be one of {}.".format(self._INTERP_DICT.keys())
+            )
+        self.interp = interp
+
+    @staticmethod
+    def _rescale_size(img_size, target_size):
+        """rescale size"""
+        scale = min(max(target_size) / max(img_size), min(target_size) / min(img_size))
+        rescaled_size = [round(i * scale) for i in img_size]
+        return rescaled_size, scale
+
+
+class Resize(_BaseResize):
+    """Resize the image."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size", "scale_factors"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {
+        "img": "img",
+        "img_size": "img_size",
+        "scale_factors": "scale_factors",
+    }
+
+    def __init__(
+        self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
+    ):
+        """
+        Initialize the instance.
+
+        Args:
+            target_size (list|tuple|int): Target width and height.
+            keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
+                image. Default: False.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        _check_image_size(target_size)
+        self.target_size = target_size
+
+        self.keep_ratio = keep_ratio
+
+    def apply(self, img):
+        """apply"""
+        target_size = self.target_size
+        original_size = img.shape[:2]
+
+        if self.keep_ratio:
+            h, w = img.shape[0:2]
+            target_size, _ = self._rescale_size((w, h), self.target_size)
+
+        if self.size_divisor:
+            target_size = [
+                math.ceil(i / self.size_divisor) * self.size_divisor
+                for i in target_size
+            ]
+
+        # original_size is (h, w); target_size is (w, h)
+        img_scale_w, img_scale_h = [
+            target_size[0] / original_size[1],
+            target_size[1] / original_size[0],
+        ]
+        img = F.resize(img, target_size, interp=self.interp)
+        return {
+            "img": img,
+            "img_size": [img.shape[1], img.shape[0]],
+            "scale_factors": [img_scale_w, img_scale_h],
+        }
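A short sketch of the keep-ratio path: the target is first rescaled so the image fits inside it, then size_divisor rounds each side up:

    import numpy as np

    resize = Resize(target_size=[224, 224], keep_ratio=True, size_divisor=32)
    out = resize.apply(np.zeros((100, 200, 3), dtype=np.uint8))  # h=100, w=200
    print(out["img_size"])       # [224, 128]: scale = min(224/200, 224/100) = 1.12
    print(out["scale_factors"])  # [1.12, 1.28], resized size over original size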
+
+
+class ResizeByLong(_BaseResize):
+    """
+    Proportionally resize the image by specifying the target length of the
+    longest side.
+    """
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
+        """
+        Initialize the instance.
+
+        Args:
+            target_long_edge (int): Target length of the longest side of image.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+        self.target_long_edge = target_long_edge
+
+    def apply(self, img):
+        """apply"""
+        h, w = img.shape[:2]
+        scale = self.target_long_edge / max(h, w)
+        h_resize = round(h * scale)
+        w_resize = round(w * scale)
+        if self.size_divisor is not None:
+            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
+            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
+
+        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class ResizeByShort(_BaseResize):
+    """
+    Proportionally resize the image by specifying the target length of the
+    shortest side.
+    """
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
+        """
+        Initialize the instance.
+
+        Args:
+            target_short_edge (int): Target length of the shortest side of image.
+            size_divisor (int|None, optional): Divisor of resized image size.
+                Default: None.
+            interp (str, optional): Interpolation method. Choices are 'NEAREST',
+                'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
+        """
+        super().__init__(size_divisor=size_divisor, interp=interp)
+        self.target_short_edge = target_short_edge
+
+    def apply(self, img):
+        """apply"""
+        h, w = img.shape[:2]
+        scale = self.target_short_edge / min(h, w)
+        h_resize = round(h * scale)
+        w_resize = round(w * scale)
+        if self.size_divisor is not None:
+            h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
+            w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor
+
+        img = F.resize(img, (w_resize, h_resize), interp=self.interp)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class Pad(BaseComponent):
+    """Pad the image."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, target_size, val=127.5):
+        """
+        Initialize the instance.
+
+        Args:
+            target_size (list|tuple|int): Target width and height of the image after
+                padding.
+            val (float, optional): Value to fill the padded area. Default: 127.5.
+        """
+        super().__init__()
+
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        _check_image_size(target_size)
+        self.target_size = target_size
+
+        self.val = val
+
+    def apply(self, img):
+        """apply"""
+        h, w = img.shape[:2]
+        tw, th = self.target_size
+        ph = th - h
+        pw = tw - w
+
+        if ph < 0 or pw < 0:
+            raise ValueError(
+                f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
+            )
+        else:
+            img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
+        return {"img": img, "img_size": [img.shape[1], img.shape[0]]}
+
+
+class Normalize(BaseComponent):
+    """Normalize the image."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, scale=1.0 / 255, mean=0.5, std=0.5, preserve_dtype=False):
+        """
+        Initialize the instance.
+
+        Args:
+            scale (float, optional): Scaling factor to apply to the image before
+                applying normalization. Default: 1/255.
+            mean (float|tuple|list, optional): Means for each channel of the image.
+                Default: 0.5.
+            std (float|tuple|list, optional): Standard deviations for each channel
+                of the image. Default: 0.5.
+            preserve_dtype (bool, optional): Whether to preserve the original dtype
+                of the image. Default: False.
+        """
+        super().__init__()
+
+        self.scale = np.float32(scale)
+        if isinstance(mean, float):
+            mean = [mean]
+        self.mean = np.asarray(mean).astype("float32")
+        if isinstance(std, float):
+            std = [std]
+        self.std = np.asarray(std).astype("float32")
+        self.preserve_dtype = preserve_dtype
+
+    def apply(self, img):
+        """apply"""
+        old_type = img.dtype
+        # XXX: If `old_type` has higher precision than float32,
+        # we will lose some precision.
+        img = img.astype("float32", copy=False)
+        img *= self.scale
+        img -= self.mean
+        img /= self.std
+        if self.preserve_dtype:
+            img = img.astype(old_type, copy=False)
+        return {"img": img}
+
+
+class ToCHWImage(BaseComponent):
+    """Reorder the dimensions of the image from HWC to CHW."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def apply(self, img):
+        """apply"""
+        img = img.transpose((2, 0, 1))
+        return {"img": img}

+ 58 - 0
paddlex/inference/components/transforms/image/funcs.py

@@ -0,0 +1,58 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+
+
+def resize(im, target_size, interp):
+    """resize image to target size"""
+    w, h = target_size
+    im = cv2.resize(im, (w, h), interpolation=interp)
+    return im
+
+
+def flip_h(im):
+    """flip image horizontally"""
+    if len(im.shape) == 3:
+        im = im[:, ::-1, :]
+    elif len(im.shape) == 2:
+        im = im[:, ::-1]
+    return im
+
+
+def flip_v(im):
+    """flip image vertically"""
+    if len(im.shape) == 3:
+        im = im[::-1, :, :]
+    elif len(im.shape) == 2:
+        im = im[::-1, :]
+    return im
+
+
+def slice(im, coords):
+    """slice the image"""
+    x1, y1, x2, y2 = coords
+    im = im[y1:y2, x1:x2, ...]
+    return im
+
+
+def pad(im, pad, val):
+    """padding image by value"""
+    if isinstance(pad, int):
+        pad = [pad] * 4
+    if len(pad) != 4:
+        raise ValueError("`pad` should be an int or a sequence of 4 ints.")
+    chns = 1 if im.ndim == 2 else im.shape[2]
+    im = cv2.copyMakeBorder(im, *pad, cv2.BORDER_CONSTANT, value=(val,) * chns)
+    return im
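Note the padding order: the tuple is forwarded verbatim to cv2.copyMakeBorder, so it reads (top, bottom, left, right). A quick check:

    import numpy as np

    im = np.zeros((10, 10, 3), dtype=np.uint8)
    padded = pad(im, pad=(0, 5, 0, 5), val=127.5)  # 5 px on bottom and right only
    assert padded.shape[:2] == (15, 15)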

+ 16 - 0
paddlex/inference/pipelines/__init__.py

@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .image_classification import ClasPipeline
+from .ocr import OCRPipeline

+ 48 - 0
paddlex/inference/pipelines/base.py

@@ -0,0 +1,48 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+from ...utils.misc import AutoRegisterABCMetaClass
+
+
+def create_pipeline(
+    pipeline_name: str,
+    model_list: list,
+    model_dir_list: list,
+    output: str,
+    device: str,
+) -> "BasePipeline":
+    """build model evaluater
+
+    Args:
+        pipeline_name (str): the pipeline name, that is name of pipeline class
+
+    Returns:
+        BasePipeline: the pipeline, which is subclass of BasePipeline.
+    """
+    pipeline = BasePipeline.get(pipeline_name)(output=output, device=device)
+    pipeline.update_model(model_list, model_dir_list)
+    pipeline.load_model()
+    return pipeline
+
+
+class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
+    """Base Pipeline"""
+
+    __is_base = True
+
+    # alias the __call__() to predict()
+    def __call__(self, *args, **kwargs):
+        yield from self.predict(*args, **kwargs)

+ 33 - 0
paddlex/inference/pipelines/image_classification.py

@@ -0,0 +1,33 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BasePipeline
+from ..predictors import ClasPredictor
+
+
+class ClasPipeline(BasePipeline):
+    """Cls Pipeline"""
+
+    entities = "image_classification"
+
+    def __init__(self, model, batch_size=1, device="gpu"):
+        super().__init__()
+        self._predict = ClasPredictor(model, batch_size=batch_size)
+
+    def predict(self, x):
+        self._check_input(x)
+        yield from self._predict(x)
+
+    def _check_input(self, x):
+        pass
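A hedged usage sketch, assuming the predictor yields batches of dicts keyed by "topk_res" (per ClasPredictor's DEAULT_OUTPUTS) and using an official model name from official_models.py:

    pipeline = ClasPipeline(model="PP-LCNet_x1_0", batch_size=2)
    for batch in pipeline.predict("path/to/images"):
        for res in batch:
            print(res["topk_res"])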

+ 51 - 0
paddlex/inference/pipelines/ocr.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BasePipeline
+from ..predictors import TextDetPredictor, TextRecPredictor
+from ..components import CropByPolys
+from ..results import OCRResult
+
+
+class OCRPipeline(BasePipeline):
+    """OCR Pipeline"""
+
+    entities = "ocr"
+
+    def __init__(self, det_model, rec_model, det_batch_size, rec_batch_size, **kwargs):
+        super().__init__()
+        self._det_predict = TextDetPredictor(det_model, batch_size=det_batch_size)
+        self._rec_predict = TextRecPredictor(rec_model, batch_size=rec_batch_size)
+        # TODO: foo
+        self._crop_by_polys = CropByPolys(det_box_type="foo")
+
+    def predict(self, x):
+        batch_ocr_res = []
+        for batch_det_res in self._det_predict(x):
+            for det_res in batch_det_res:
+                single_img_res = det_res["text_det_res"]
+                single_img_res["rec_text"] = []
+                single_img_res["rec_score"] = []
+                all_subs_of_img = list(self._crop_by_polys(single_img_res))
+                for batch_rec_res in self._rec_predict(all_subs_of_img):
+                    for rec_res in batch_rec_res:
+                        single_img_res["rec_text"].append(
+                            rec_res["text_rec_res"]["rec_text"]
+                        )
+                        single_img_res["rec_score"].append(
+                            rec_res["text_rec_res"]["rec_score"]
+                        )
+                # TODO(gaotingquan): using "ocr_res" or new a component or dict only?
+                batch_ocr_res.append({"ocr_res": OCRResult(single_img_res)})
+                # batch_ocr_res.append(OCRResult(single_img_res))
+        yield batch_ocr_res
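A usage sketch under the same caveats, with the official PP-OCRv4 model names from official_models.py and assuming OCRResult supports dict-style access:

    pipeline = OCRPipeline(
        det_model="PP-OCRv4_mobile_det",
        rec_model="PP-OCRv4_mobile_rec",
        det_batch_size=1,
        rec_batch_size=1,
    )
    for batch in pipeline("path/to/images"):  # __call__ delegates to predict()
        for res in batch:
            print(res["ocr_res"]["rec_text"], res["ocr_res"]["rec_score"])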

+ 17 - 0
paddlex/inference/predictors/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .image_classification import ClasPredictor
+from .text_detection import TextDetPredictor
+from .text_recognition import TextRecPredictor

+ 67 - 0
paddlex/inference/predictors/base.py

@@ -0,0 +1,67 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+import codecs
+from pathlib import Path
+from abc import abstractmethod
+
+from ...utils.misc import AutoRegisterABCMetaClass
+from ..components.base import BaseComponent, ComponentsEngine
+from .official_models import official_models
+
+
+class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
+    __is_base = True
+
+    INPUT_KEYS = "x"
+    OUTPUT_KEYS = None
+
+    KEEP_INPUT = False
+
+    MODEL_FILE_PREFIX = "inference"
+
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model_dir = self._check_model(model)
+        self.kwargs = kwargs
+        self.config = self._load_config()
+        self.components = self._build_components()
+        self.engine = ComponentsEngine(self.components)
+        # alias predict() to the __call__()
+        self.predict = self.__call__
+
+    def _check_model(self, model):
+        if Path(model).exists():
+            return Path(model)
+        elif model in official_models:
+            return official_models[model]
+        else:
+            raise Exception(
+                f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
+            )
+
+    def _load_config(self):
+        config_path = self.model_dir / f"{self.MODEL_FILE_PREFIX}.yml"
+        with codecs.open(config_path, "r", "utf-8") as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+        return dic
+
+    def apply(self, x):
+        """predict"""
+        yield from self.engine(x)
+
+    @abstractmethod
+    def _build_components(self):
+        raise NotImplementedError
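A minimal sketch of a concrete subclass, mirroring the pattern the predictors below use — entities drives auto-registration, and _build_components returns the named component dict ("MyDemoModel" and the single-component pipeline here are hypothetical):

    class DemoPredictor(BasePredictor):
        entities = ["MyDemoModel"]  # hypothetical model name

        def _build_components(self):
            # assemble whatever components the model needs, keyed by name;
            # assumes ReadImage is importable from ..components
            return {"ReadImage": ReadImage(batch_size=self.kwargs.get("batch_size", 1))}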

+ 113 - 0
paddlex/inference/predictors/image_classification.py

@@ -0,0 +1,113 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from functools import partial, wraps
+
+from ...modules.image_classification.model_list import MODELS
+from ..components import *
+from .base import BasePredictor
+
+
+def register(register_map, key):
+    """register the option setting func"""
+
+    def decorator(func):
+        register_map[key] = func
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
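The decorator stores the undecorated function in the given map, so a builder can later be looked up by its config key and called with the predictor instance — a minimal sketch with a hypothetical Demo class:

    _MAP = {}
    reg = partial(register, _MAP)

    class Demo:
        @reg("Bar")
        def build_bar(self, n=1):
            return "bar" * n

    assert _MAP["Bar"](Demo(), n=2) == "barbar"  # looked up by key, self passed manually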
+
+
+class ClasPredictor(BasePredictor):
+
+    entities = MODELS
+
+    INPUT_KEYS = "x"
+    OUTPUT_KEYS = "topk_res"
+    DEAULT_INPUTS = {"x": "x"}
+    DEAULT_OUTPUTS = {"topk_res": "topk_res"}
+
+    _REGISTER_MAP = {}
+    register2self = partial(register, _REGISTER_MAP)
+
+    def _build_components(self):
+        ops = {}
+        ops["ReadImage"] = ReadImage(batch_size=self.kwargs.get("batch_size", 1))
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            func = self._REGISTER_MAP.get(tf_key)
+            args = cfg.get(tf_key, {})
+            op = func(self, **args) if args else func(self)
+            ops[tf_key] = op
+
+        kernel_option = PaddlePredictorOption()
+        # kernel_option.set_device(self.device)
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=kernel_option,
+        )
+        predictor.set_inputs({"imgs": "img"})
+        ops["predictor"] = predictor
+
+        post_processes = self.config["PostProcess"]
+        for key in post_processes:
+            func = self._REGISTER_MAP.get(key)
+            args = post_processes.get(key, {})
+            op = func(self, **args) if args else func(self)
+            ops[key] = op
+        return ops
+
+    @register2self("ResizeImage")
+    def build_resize(self, resize_short=None, size=None):
+        assert resize_short or size
+        if resize_short:
+            op = ResizeByShort(
+                target_short_edge=resize_short, size_divisor=None, interp="LINEAR"
+            )
+        else:
+            op = Resize(target_size=size)
+        return op
+
+    @register2self("CropImage")
+    def build_crop(self, size=224):
+        return Crop(crop_size=size)
+
+    @register2self("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        scale=1 / 255,
+        order="",
+        channel_num=3,
+    ):
+        assert channel_num == 3
+        assert order == ""
+        return Normalize(mean=mean, std=std)
+
+    @register2self("ToCHWImage")
+    def build_to_chw(self):
+        return ToCHWImage()
+
+    @register2self("Topk")
+    def build_topk(self, topk, label_list=None):
+        return Topk(topk=int(topk), class_ids=label_list)

+ 184 - 0
paddlex/inference/predictors/official_models.py

@@ -0,0 +1,184 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+from ...utils.cache import CACHE_DIR
+from ...utils.download import download_and_extract
+
+OFFICIAL_MODELS = {
+    "ResNet18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet18_infer.tar",
+    "ResNet18_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet18_vd_infer.tar",
+    "ResNet34": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet34_infer.tar",
+    "ResNet34_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet34_vd_infer.tar",
+    "ResNet50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet50_infer.tar",
+    "ResNet50_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet50_vd_infer.tar",
+    "ResNet101": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet101_infer.tar",
+    "ResNet101_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet101_vd_infer.tar",
+    "ResNet152": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet152_infer.tar",
+    "ResNet152_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet152_vd_infer.tar",
+    "ResNet200_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet200_vd_infer.tar",
+    "PP-LCNet_x0_25": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x0_25_infer.tar",
+    "PP-LCNet_x0_35": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x0_35_infer.tar",
+    "PP-LCNet_x0_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x0_5_infer.tar",
+    "PP-LCNet_x0_75": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x0_75_infer.tar",
+    "PP-LCNet_x1_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x1_0_infer.tar",
+    "PP-LCNet_x1_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x1_5_infer.tar",
+    "PP-LCNet_x2_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x2_5_infer.tar",
+    "PP-LCNet_x2_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x2_0_infer.tar",
+    "PP-LCNetV2_small": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNetV2_small_infer.tar",
+    "PP-LCNetV2_base": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNetV2_base_infer.tar",
+    "PP-LCNetV2_large": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNetV2_large_infer.tar",
+    "MobileNetV3_large_x0_35": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_large_x0_35_infer.tar",
+    "MobileNetV3_large_x0_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_large_x0_5_infer.tar",
+    "MobileNetV3_large_x0_75": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_large_x0_75_infer.tar",
+    "MobileNetV3_large_x1_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_large_x1_0_infer.tar",
+    "MobileNetV3_large_x1_25": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_large_x1_25_infer.tar",
+    "MobileNetV3_small_x0_35": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_small_x0_35_infer.tar",
+    "MobileNetV3_small_x0_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_small_x0_5_infer.tar",
+    "MobileNetV3_small_x0_75": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_small_x0_75_infer.tar",
+    "MobileNetV3_small_x1_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_small_x1_0_infer.tar",
+    "MobileNetV3_small_x1_25": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV3_small_x1_25_infer.tar",
+    "ConvNeXt_tiny": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ConvNeXt_tiny_infer.tar",
+    "ConvNeXt_small": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ConvNeXt_small_infer.tar",
+    "ConvNeXt_base_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ConvNeXt_base_224_infer.tar",
+    "ConvNeXt_base_384": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ConvNeXt_base_384_infer.tar",
+    "ConvNeXt_large_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ConvNeXt_large_224_infer.tar",
+    "ConvNeXt_large_384": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ConvNeXt_large_384_infer.tar",
+    "MobileNetV2_x0_25": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV2_x0_25_infer.tar",
+    "MobileNetV2_x0_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/MobileNetV2_x0_5_infer.tar",
+    "MobileNetV2_x1_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/MobileNetV2_x1_0_infer.tar",
+    "MobileNetV2_x1_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/MobileNetV2_x1_5_infer.tar",
+    "MobileNetV2_x2_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/MobileNetV2_x2_0_infer.tar",
+    "MobileNetV1_x0_25": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV1_x0_25_infer.tar",
+    "MobileNetV1_x0_5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV1_x0_5_infer.tar",
+    "MobileNetV1_x0_75": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV1_x0_75_infer.tar",
+    "MobileNetV1_x1_0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+MobileNetV1_x1_0_infer.tar",
+    "SwinTransformer_tiny_patch4_window7_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+SwinTransformer_tiny_patch4_window7_224_infer.tar",
+    "SwinTransformer_small_patch4_window7_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+SwinTransformer_small_patch4_window7_224_infer.tar",
+    "SwinTransformer_base_patch4_window7_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+SwinTransformer_base_patch4_window7_224_infer.tar",
+    "SwinTransformer_base_patch4_window12_384": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+SwinTransformer_base_patch4_window12_384_infer.tar",
+    "SwinTransformer_large_patch4_window7_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+SwinTransformer_large_patch4_window7_224_infer.tar",
+    "SwinTransformer_large_patch4_window12_384": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+SwinTransformer_large_patch4_window12_384_infer.tar",
+    "PP-HGNet_tiny": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNet_tiny_infer.tar",
+    "PP-HGNet_small": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNet_small_infer.tar",
+    "PP-HGNet_base": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNet_base_infer.tar",
+    "PP-HGNetV2-B0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B0_infer.tar",
+    "PP-HGNetV2-B1": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B1_infer.tar",
+    "PP-HGNetV2-B2": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B2_infer.tar",
+    "PP-HGNetV2-B3": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B3_infer.tar",
+    "PP-HGNetV2-B4": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B4_infer.tar",
+    "PP-HGNetV2-B5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B5_infer.tar",
+    "PP-HGNetV2-B6": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B6_infer.tar",
+    "CLIP_vit_base_patch16_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+CLIP_vit_base_patch16_224_infer.tar",
+    "CLIP_vit_large_patch14_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+CLIP_vit_large_patch14_224_infer.tar",
+    "PP-LCNet_x1_0_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x1_0_ML_infer.tar",
+    "PP-HGNetV2-B0_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B0_ML_infer.tar",
+    "PP-HGNetV2-B4_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B4_ML_infer.tar",
+    "PP-HGNetV2-B6_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B6_ML_infer.tar",
+    "ResNet50_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/ResNet50_ML_infer.tar",
+    "CLIP_vit_base_patch16_448_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/CLIP_vit_base_patch16_448_ML_infer.tar",
+    "PP-YOLOE_plus-X": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-YOLOE_plus-X_infer.tar",
+    "PP-YOLOE_plus-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-YOLOE_plus-L_infer.tar",
+    "PP-YOLOE_plus-M": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-YOLOE_plus-M_infer.tar",
+    "PP-YOLOE_plus-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-YOLOE_plus-S_infer.tar",
+    "RT-DETR-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/RT-DETR-L_infer.tar",
+    "RT-DETR-H": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/RT-DETR-H_infer.tar",
+    "RT-DETR-X": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/RT-DETR-X_infer.tar",
+    "YOLOv3-DarkNet53": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOv3-DarkNet53_infer.tar",
+    "YOLOv3-MobileNetV3": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOv3-MobileNetV3_infer.tar",
+    "YOLOv3-ResNet50_vd_DCN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOv3-ResNet50_vd_DCN_infer.tar",
+    "YOLOX-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOX-L_infer.tar",
+    "YOLOX-M": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOX-M_infer.tar",
+    "YOLOX-N": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOX-N_infer.tar",
+    "YOLOX-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOX-S_infer.tar",
+    "YOLOX-T": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOX-T_infer.tar",
+    "YOLOX-X": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/YOLOX-X_infer.tar",
+    "RT-DETR-R18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/RT-DETR-R18_infer.tar",
+    "RT-DETR-R50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/RT-DETR-R50_infer.tar",
+    "PicoDet-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-S.tar",
+    "PicoDet-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-L.tar",
+    "Deeplabv3-R50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/Deeplabv3-R50_infer.tar",
+    "Deeplabv3-R101": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/Deeplabv3-R101_infer.tar",
+    "Deeplabv3_Plus-R50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+Deeplabv3_Plus-R50_infer.tar",
+    "Deeplabv3_Plus-R101": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+Deeplabv3_Plus-R101_infer.tar",
+    "PP-LiteSeg-T": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LiteSeg-T_infer.tar",
+    "OCRNet_HRNet-W48": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/OCRNet_HRNet-W48_infer.tar",
+    "OCRNet_HRNet-W18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/OCRNet_HRNet-W18_infer.tar",
+    "SegFormer-B0": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SegFormer-B0_infer.tar",
+    "SegFormer-B1": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SegFormer-B1_infer.tar",
+    "SegFormer-B2": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SegFormer-B2_infer.tar",
+    "SegFormer-B3": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SegFormer-B3_infer.tar",
+    "SegFormer-B4": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SegFormer-B4_infer.tar",
+    "SegFormer-B5": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SegFormer-B5_infer.tar",
+    "SeaFormer_tiny": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SeaFormer_tiny_infer.tar",
+    "SeaFormer_small": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SeaFormer_small_infer.tar",
+    "SeaFormer_base": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SeaFormer_base_infer.tar",
+    "SeaFormer_large": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SeaFormer_large_infer.tar",
+    "Mask-RT-DETR-H": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/Mask-RT-DETR-H_infer.tar",
+    "Mask-RT-DETR-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/Mask-RT-DETR-L_infer.tar",
+    "PP-OCRv4_server_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_server_rec_infer.tar",
+    "PP-OCRv4_mobile_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_mobile_rec_infer.tar",
+    "PP-OCRv4_server_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_server_det_infer.tar",
+    "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_mobile_det_infer.tar",
+    "RepSVTR_mobile_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+openatom_rec_repsvtr_ch_infer.tar",
+    "SVTRv2_server_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+openatom_rec_svtrv2_ch_infer.tar",
+    "PicoDet_layout_1x": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-L_layout_infer.tar",
+    "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SLANet_infer.tar",
+    "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/LaTeX_OCR_rec_infer.tar",
+}
+
+
+class OfficialModelsDict(dict):
+    """Dict that downloads and extracts an official model on first access."""
+
+    def __getitem__(self, key):
+        url = super().__getitem__(key)
+        save_dir = Path(CACHE_DIR) / "official_models"
+        download_and_extract(url, save_dir, key, overwrite=False)
+        return save_dir / key
+
+
+official_models = OfficialModelsDict(OFFICIAL_MODELS)
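
A minimal usage sketch of the lazy-download dict above (module path as in this diff; the key is any entry of OFFICIAL_MODELS):

    from paddlex.inference.predictors.official_models import official_models

    # First access downloads and extracts the tar into the cache directory
    # (overwrite=False makes later accesses reuse the extracted files).
    model_dir = official_models["PP-OCRv4_mobile_det"]
    print(model_dir)  # <CACHE_DIR>/official_models/PP-OCRv4_mobile_det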

+ 123 - 0
paddlex/inference/predictors/text_detection.py

@@ -0,0 +1,123 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import partial
+
+from ...modules.text_detection.model_list import MODELS
+
+from ..components import *
+from .base import BasePredictor
+
+
+def register(register_map, key):
+    """Register a builder method in register_map under the given config key."""
+
+    def decorator(func):
+        register_map[key] = func
+        return func
+
+    return decorator
+
+
+class TextDetPredictor(BasePredictor):
+
+    entities = MODELS
+
+    INPUT_KEYS = "x"
+    OUTPUT_KEYS = "text_det_res"
+    DEAULT_INPUTS = {"x": "x"}
+    DEAULT_OUTPUTS = {"text_det_res": "text_det_res"}
+
+    _REGISTER_MAP = {}
+    register2self = partial(register, _REGISTER_MAP)
+
+    def _build_components(self):
+        ops = {}
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            assert tf_key in self._REGISTER_MAP
+            func = self._REGISTER_MAP.get(tf_key)
+            args = cfg.get(tf_key, {})
+            op = func(self, **args) if args else func(self)
+            if op:
+                ops[tf_key] = op
+
+        kernel_option = PaddlePredictorOption()
+        # kernel_option.set_device(self.device)
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=kernel_option,
+        )
+        predictor.set_inputs({"imgs": "img"})
+        ops["predictor"] = predictor
+
+        key, op = self.build_postprocess(**self.config["PostProcess"])
+        ops[key] = op
+        return ops
+
+    @register2self("DecodeImage")
+    def build_readimg(self, channel_first, img_mode):
+        assert not channel_first
+        return ReadImage(format=img_mode, batch_size=self.kwargs.get("batch_size", 1))
+
+    @register2self("DetResizeForTest")
+    def build_resize(self, resize_long=960):
+        return DetResizeForTest(limit_side_len=resize_long, limit_type="max")
+
+    @register2self("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        scale=1 / 255,
+        order="",
+        channel_num=3,
+    ):
+        return NormalizeImage(
+            mean=mean, std=std, scale=scale, order=order, channel_num=channel_num
+        )
+
+    @register2self("ToCHWImage")
+    def build_to_chw(self):
+        return ToCHWImage()
+
+    def build_postprocess(self, **kwargs):
+        if kwargs.get("name") == "DBPostProcess":
+            return "DBPostProcess", DBPostProcess(
+                thresh=kwargs.get("thresh", 0.3),
+                box_thresh=kwargs.get("box_thresh", 0.7),
+                max_candidates=kwargs.get("max_candidates", 1000),
+                unclip_ratio=kwargs.get("unclip_ratio", 2.0),
+                use_dilation=kwargs.get("use_dilation", False),
+                score_mode=kwargs.get("score_mode", "fast"),
+                box_type=kwargs.get("box_type", "quad"),
+            )
+
+        else:
+            raise ValueError(
+                f"Unsupported PostProcess op: {kwargs.get('name')}"
+            )
+
+    @register2self("DetLabelEncode")
+    def build_det_label_encode(self, *args, **kwargs):
+        # Training-only op; skipped at inference time.
+        return None
+
+    @register2self("KeepKeys")
+    def build_keep_keys(self, *args, **kwargs):
+        # Training-only op; skipped at inference time.
+        return None
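
The register/_REGISTER_MAP pattern above binds config keys to builder methods at class-definition time. A stripped-down, self-contained sketch of the same mechanism (names here are illustrative):

    from functools import partial

    _REGISTER_MAP = {}

    def register(register_map, key):
        def decorator(func):
            register_map[key] = func
            return func
        return decorator

    register2cls = partial(register, _REGISTER_MAP)

    class Demo:
        @register2cls("DecodeImage")
        def build_readimg(self, img_mode="BGR"):
            return f"ReadImage(format={img_mode})"

    # Dispatch by config key, as _build_components() does:
    op = _REGISTER_MAP["DecodeImage"](Demo(), img_mode="RGB")
    print(op)  # ReadImage(format=RGB)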

+ 100 - 0
paddlex/inference/predictors/text_recognition.py

@@ -0,0 +1,100 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from functools import partial
+
+from ...modules.text_recognition.model_list import MODELS
+
+from ..components import *
+from .base import BasePredictor
+
+
+def register(register_map, key):
+    """Register a builder method in register_map under the given config key."""
+
+    def decorator(func):
+        register_map[key] = func
+        return func
+
+    return decorator
+
+
+class TextRecPredictor(BasePredictor):
+
+    entities = MODELS
+
+    INPUT_KEYS = "x"
+    OUTPUT_KEYS = "text_rec_res"
+    DEAULT_INPUTS = {"x": "x"}
+    DEAULT_OUTPUTS = {"text_rec_res": "text_rec_res"}
+
+    _REGISTER_MAP = {}
+    register2self = partial(register, _REGISTER_MAP)
+
+    def _build_components(self):
+        ops = {}
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            assert tf_key in self._REGISTER_MAP
+            func = self._REGISTER_MAP.get(tf_key)
+            args = cfg.get(tf_key, {})
+            op = func(self, **args) if args else func(self)
+            if op:
+                ops[tf_key] = op
+
+        kernel_option = PaddlePredictorOption()
+        # kernel_option.set_device(self.device)
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=kernel_option,
+        )
+        predictor.set_inputs({"imgs": "img"})
+        ops["predictor"] = predictor
+
+        key, op = self.build_postprocess(**self.config["PostProcess"])
+        ops[key] = op
+        return ops
+
+    @register2self("DecodeImage")
+    def build_readimg(self, channel_first, img_mode):
+        assert not channel_first
+        return ReadImage(format=img_mode, batch_size=self.kwargs.get("batch_size", 1))
+
+    @register2self("RecResizeImg")
+    def build_resize(self, image_shape):
+        return OCRReisizeNormImg(rec_image_shape=image_shape)
+
+    def build_postprocess(self, **kwargs):
+        if kwargs.get("name") == "CTCLabelDecode":
+            return "CTCLabelDecode", CTCLabelDecode(
+                character_list=kwargs.get("character_dict"),
+            )
+        else:
+            raise ValueError(
+                f"Unsupported PostProcess op: {kwargs.get('name')}"
+            )
+
+    @register2self("MultiLabelEncode")
+    def build_multi_label_encode(self, *args, **kwargs):
+        # Training-only op; skipped at inference time.
+        return None
+
+    @register2self("KeepKeys")
+    def build_keep_keys(self, *args, **kwargs):
+        # Training-only op; skipped at inference time.
+        return None
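
Both predictors are driven by a PaddleOCR-style inference config. The fragment below is hypothetical (real values come from the exported model's config file) but shows the shape of the data that _build_components() iterates over:

    config = {
        "PreProcess": {
            "transform_ops": [
                {"DecodeImage": {"channel_first": False, "img_mode": "BGR"}},
                {"RecResizeImg": {"image_shape": [3, 48, 320]}},
                {"KeepKeys": {"keep_keys": ["image"]}},
            ]
        },
        "PostProcess": {"name": "CTCLabelDecode", "character_dict": ["a", "b"]},
    }

    for cfg in config["PreProcess"]["transform_ops"]:
        tf_key = list(cfg.keys())[0]  # each entry is a single-key dict
        args = cfg.get(tf_key) or {}
        print(tf_key, args)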

+ 18 - 0
paddlex/inference/results/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .topk import TopkResult
+from .text_det import TextDetResult
+from .text_rec import TextRecResult
+from .ocr import OCRResult

+ 150 - 0
paddlex/inference/results/ocr.py

@@ -0,0 +1,150 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import json
+import math
+import random
+import numpy as np
+import cv2
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+
+from ...utils import logging
+from ...utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..utils.io import JsonWriter, ImageWriter, ImageReader
+
+
+class OCRResult(dict):
+    def __init__(self, data):
+        super().__init__(data)
+        self._json_writer = JsonWriter()
+        self._img_reader = ImageReader(backend="opencv")
+        self._img_writer = ImageWriter(backend="opencv")
+
+    def save_json(self, save_path, indent=4, ensure_ascii=False):
+        if not str(save_path).endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
+        self._json_writer.write(
+            save_path, self, indent=indent, ensure_ascii=ensure_ascii
+        )
+
+    def save_img(self, save_path):
+        save_path = Path(save_path)
+        if save_path.suffix.lower() not in (".jpg", ".png"):
+            save_path = save_path / f"{Path(self['img_path']).stem}.jpg"
+        res_img = self._draw_ocr_box_txt(
+            self["img_path"], self["dt_polys"], self["rec_text"], self["rec_score"]
+        )
+        self._img_writer.write(save_path.as_posix(), res_img)
+        logging.info(f"The result has been saved in {save_path}.")
+
+    def print(self, json_format=True, indent=4, ensure_ascii=False):
+        str_ = self
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)
+
+    def _draw_ocr_box_txt(
+        self,
+        img_path,
+        boxes,
+        txts=None,
+        scores=None,
+        drop_score=0.5,
+        font_path=PINGFANG_FONT_FILE_PATH,
+    ):
+        """draw ocr result"""
+        img = self._img_reader.read(img_path)
+        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+        h, w = image.height, image.width
+        img_left = image.copy()
+        img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+        random.seed(0)
+
+        draw_left = ImageDraw.Draw(img_left)
+        if txts is None or len(txts) != len(boxes):
+            txts = [None] * len(boxes)
+        for idx, (box, txt) in enumerate(zip(boxes, txts)):
+            if scores is not None and scores[idx] < drop_score:
+                continue
+            color = (
+                random.randint(0, 255),
+                random.randint(0, 255),
+                random.randint(0, 255),
+            )
+            draw_left.polygon(box, fill=color)
+            img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+            pts = np.array(box, np.int32).reshape((-1, 1, 2))
+            cv2.polylines(img_right_text, [pts], True, color, 1)
+            img_right = cv2.bitwise_and(img_right, img_right_text)
+        img_left = Image.blend(image, img_left, 0.5)
+        img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+        return np.array(img_show)
+
+
+def draw_box_txt_fine(img_size, box, txt, font_path=PINGFANG_FONT_FILE_PATH):
+    """draw box text"""
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+
+    if box_height > 2 * box_width and box_height > 30:
+        img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_height, box_width), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+        img_text = img_text.transpose(Image.ROTATE_270)
+    else:
+        img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_width, box_height), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
+    pts1 = np.float32(
+        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+    )
+    pts2 = np.array(box, dtype=np.float32)
+    M = cv2.getPerspectiveTransform(pts1, pts2)
+
+    img_text = np.array(img_text, dtype=np.uint8)
+    img_right_text = cv2.warpPerspective(
+        img_text,
+        M,
+        img_size,
+        flags=cv2.INTER_NEAREST,
+        borderMode=cv2.BORDER_CONSTANT,
+        borderValue=(255, 255, 255),
+    )
+    return img_right_text
+
+
+def create_font(txt, sz, font_path=PINGFANG_FONT_FILE_PATH):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
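
A usage sketch of OCRResult; the payload is hypothetical, with keys matching those read by save_img() above, and demo.jpg standing in for a real input image:

    res = OCRResult(
        {
            "img_path": "demo.jpg",
            "dt_polys": [[[10, 10], [200, 10], [200, 50], [10, 50]]],
            "rec_text": ["hello"],
            "rec_score": [0.98],
        }
    )
    res.print()                # logs the result as pretty-printed JSON
    res.save_json("./output")  # -> ./output/demo.json
    res.save_img("./output")   # -> ./output/demo.jpg, boxes left / texts right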

+ 56 - 0
paddlex/inference/results/text_det.py

@@ -0,0 +1,56 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import json
+import numpy as np
+import cv2
+
+from ...utils import logging
+from ..utils.io import JsonWriter, ImageWriter, ImageReader
+
+
+class TextDetResult(dict):
+    def __init__(self, data):
+        super().__init__(data)
+        self._json_writer = JsonWriter()
+        self._img_reader = ImageReader(backend="opencv")
+        self._img_writer = ImageWriter(backend="opencv")
+
+    def save_json(self, save_path, indent=4, ensure_ascii=False):
+        if not str(save_path).endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
+        self._json_writer.write(
+            save_path, self, indent=indent, ensure_ascii=ensure_ascii
+        )
+
+    def save_img(self, save_path):
+        save_path = Path(save_path)
+        if save_path.suffix.lower() not in (".jpg", ".png"):
+            save_path = save_path / f"{Path(self['img_path']).stem}.jpg"
+        res_img = self._draw_rectangle(self["img_path"], self["dt_polys"])
+        self._img_writer.write(save_path.as_posix(), res_img)
+
+    def print(self, json_format=True, indent=4, ensure_ascii=False):
+        str_ = self
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)
+
+    def _draw_rectangle(self, img_path, boxes):
+        """draw rectangle"""
+        boxes = np.array(boxes)
+        img = self._img_reader.read(img_path)
+        img_show = img.copy()
+        for box in boxes.astype(int):
+            box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+            cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
+        return img_show

+ 43 - 0
paddlex/inference/results/text_rec.py

@@ -0,0 +1,43 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import json
+
+from ...utils import logging
+from ..utils.io import JsonWriter, ImageWriter, ImageReader
+
+
+class TextRecResult(dict):
+    def __init__(self, data):
+        super().__init__(data)
+        self._json_writer = JsonWriter()
+        self._img_reader = ImageReader(backend="opencv")
+        self._img_writer = ImageWriter(backend="opencv")
+
+    def save_json(self, save_path, indent=4, ensure_ascii=False):
+        if not str(save_path).endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
+        self._json_writer.write(
+            save_path, self, indent=indent, ensure_ascii=ensure_ascii
+        )
+
+    def save_img(self, save_path):
+        raise NotImplementedError(
+            "TextRecResult has no image visualization to save."
+        )
+
+    def print(self, json_format=True, indent=4, ensure_ascii=False):
+        str_ = self
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)

+ 127 - 0
paddlex/inference/results/topk.py

@@ -0,0 +1,127 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import json
+import PIL
+from PIL import ImageDraw, ImageFont
+import numpy as np
+
+from ...utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...utils import logging
+from ..utils.io import JsonWriter, ImageWriter, ImageReader
+from ..utils.color_map import get_colormap
+
+
+class TopkResult(dict):
+    def __init__(self, data):
+        super().__init__(data)
+        self._json_writer = JsonWriter()
+        self._img_reader = ImageReader(backend="pil")
+        self._img_writer = ImageWriter(backend="pillow")
+
+    def save_json(self, save_path, indent=4, ensure_ascii=False):
+        if not str(save_path).endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
+        self._json_writer.write(
+            save_path, self, indent=indent, ensure_ascii=ensure_ascii
+        )
+
+    def save_img(self, save_path):
+        save_path = Path(save_path)
+        if save_path.suffix.lower() not in (".jpg", ".png"):
+            save_path = save_path / f"{Path(self['img_path']).stem}.jpg"
+        labels = self.get("label_names", self["class_ids"])
+        res_img = self._draw_label(self["img_path"], self["scores"], labels)
+        self._img_writer.write(save_path, res_img)
+
+    def print(self, json_format=True, indent=4, ensure_ascii=False):
+        str_ = self
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)
+
+    def _draw_label(self, img_path, scores, class_ids):
+        """Draw label on image"""
+        label_str = f"{class_ids[0]} {scores[0]:.2f}"
+
+        image = self._img_reader.read(img_path)
+        image = image.convert("RGB")
+        image_size = image.size
+        draw = ImageDraw.Draw(image)
+        min_font_size = int(image_size[0] * 0.02)
+        max_font_size = int(image_size[0] * 0.05)
+        for font_size in range(max_font_size, min_font_size - 1, -1):
+            font = ImageFont.truetype(
+                PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8"
+            )
+            # draw.textsize() was removed in Pillow 10; fall back to textbbox()
+            if tuple(map(int, PIL.__version__.split(".")[:3])) < (10, 0, 0):
+                text_width_tmp, text_height_tmp = draw.textsize(label_str, font)
+            else:
+                left, top, right, bottom = draw.textbbox((0, 0), label_str, font)
+                text_width_tmp, text_height_tmp = right - left, bottom - top
+            if text_width_tmp <= image_size[0]:
+                break
+            else:
+                font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, min_font_size)
+        color_list = get_colormap(rgb=True)
+        color = tuple(color_list[0])
+        font_color = tuple(self._get_font_colormap(3))
+        if tuple(map(int, PIL.__version__.split(".")[:3])) < (10, 0, 0):
+            text_width, text_height = draw.textsize(label_str, font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), label_str, font)
+            text_width, text_height = right - left, bottom - top
+
+        rect_left = 3
+        rect_top = 3
+        rect_right = rect_left + text_width + 3
+        rect_bottom = rect_top + text_height + 6
+
+        draw.rectangle([(rect_left, rect_top), (rect_right, rect_bottom)], fill=color)
+
+        text_x = rect_left + 3
+        text_y = rect_top
+        draw.text((text_x, text_y), label_str, fill=font_color, font=font)
+        return image
+
+    def _get_font_colormap(self, color_index):
+        """
+        Get font colormap
+        """
+        dark = np.array([0x14, 0x0E, 0x35])
+        light = np.array([0xFF, 0xFF, 0xFF])
+        light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+        if color_index in light_indexs:
+            return light.astype("int32")
+        else:
+            return dark.astype("int32")
+
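
The version gate above is the usual Pillow compatibility shim: textsize() was removed in Pillow 10, where textbbox() is the replacement. A small standalone helper capturing the same logic:

    import PIL
    from PIL import ImageDraw

    def measure_text(draw, text, font):
        """Return (width, height) of rendered text across Pillow versions."""
        if tuple(map(int, PIL.__version__.split(".")[:3])) < (10, 0, 0):
            return draw.textsize(text, font)
        left, top, right, bottom = draw.textbbox((0, 0), text, font)
        return right - left, bottom - top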

+ 89 - 0
paddlex/inference/utils/color_map.py

@@ -0,0 +1,89 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def get_colormap(rgb=False):
+    """
+    Get colormap
+    """
+    # 20 RGB triplets, one color per row.
+    color_list = np.array(
+        [
+            0xFF, 0x00, 0x00,
+            0xCC, 0xFF, 0x00,
+            0x00, 0xFF, 0x66,
+            0x00, 0x66, 0xFF,
+            0xCC, 0x00, 0xFF,
+            0xFF, 0x4D, 0x00,
+            0x80, 0xFF, 0x00,
+            0x00, 0xFF, 0xB2,
+            0x00, 0x1A, 0xFF,
+            0xFF, 0x00, 0xE5,
+            0xFF, 0x99, 0x00,
+            0x33, 0xFF, 0x00,
+            0x00, 0xFF, 0xFF,
+            0x33, 0x00, 0xFF,
+            0xFF, 0x00, 0x99,
+            0xFF, 0xE5, 0x00,
+            0x00, 0xFF, 0x1A,
+            0x00, 0xB2, 0xFF,
+            0x80, 0x00, 0xFF,
+            0xFF, 0x00, 0x4D,
+        ]
+    ).astype(np.float32)
+    if not rgb:
+        color_list = color_list[:, ::-1]
+    return color_list.astype("int32")
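
For reference, the flat list above decodes to a palette of 20 colors; a quick sanity check, assuming this module is importable:

    colors = get_colormap(rgb=True)           # ndarray of shape (20, 3), int32
    assert colors.shape == (20, 3)
    first = tuple(int(c) for c in colors[0])  # (255, 0, 0)
    bgr = get_colormap(rgb=False)             # same palette, channels reversed for OpenCV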

+ 17 - 0
paddlex/inference/utils/io/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .readers import ImageReader, VideoReader, ReaderType
+from .writers import ImageWriter, TextWriter, JsonWriter, WriterType

+ 233 - 0
paddlex/inference/utils/io/readers.py

@@ -0,0 +1,233 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import enum
+import itertools
+import cv2
+from PIL import Image, ImageOps
+
+__all__ = ["ImageReader", "VideoReader", "ReaderType"]
+
+
+class ReaderType(enum.Enum):
+    """ReaderType"""
+
+    IMAGE = 1
+    GENERATIVE = 2
+    POINT_CLOUD = 3
+
+
+class _BaseReader(object):
+    """_BaseReader"""
+
+    def __init__(self, backend, **bk_args):
+        super().__init__()
+        if len(bk_args) == 0:
+            bk_args = self.get_default_backend_args()
+        self.bk_type = backend
+        self.bk_args = bk_args
+        self._backend = self.get_backend()
+
+    def read(self, in_path):
+        """read file from path"""
+        raise NotImplementedError
+
+    def get_backend(self, bk_args=None):
+        """get the backend"""
+        if bk_args is None:
+            bk_args = self.bk_args
+        return self._init_backend(self.bk_type, bk_args)
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        raise NotImplementedError
+
+    def get_type(self):
+        """get type"""
+        raise NotImplementedError
+
+    def get_default_backend_args(self):
+        """get default backend arguments"""
+        return {}
+
+
+class ImageReader(_BaseReader):
+    """ImageReader"""
+
+    def __init__(self, backend="opencv", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def read(self, in_path):
+        """read the image file from path"""
+        arr = self._backend.read_file(in_path)
+        return arr
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "opencv":
+            return OpenCVImageReaderBackend(**bk_args)
+        elif bk_type == "pil":
+            return PILImageReaderBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return ReaderType.IMAGE
+
+
+class _GenerativeReader(_BaseReader):
+    """_GenerativeReader"""
+
+    def get_type(self):
+        """get type"""
+        return ReaderType.GENERATIVE
+
+
+def is_generative_reader(reader):
+    """is_generative_reader"""
+    return isinstance(reader, _GenerativeReader)
+
+
+class VideoReader(_GenerativeReader):
+    """VideoReader"""
+
+    def __init__(
+        self,
+        backend="opencv",
+        st_frame_id=0,
+        max_num_frames=None,
+        auto_close=True,
+        **bk_args,
+    ):
+        super().__init__(backend=backend, **bk_args)
+        self.st_frame_id = st_frame_id
+        self.max_num_frames = max_num_frames
+        self.auto_close = auto_close
+
+    def read(self, in_path):
+        """read video file from path"""
+        self._backend.set_pos(self.st_frame_id)
+        gen = self._backend.read_file(in_path)
+        if self.max_num_frames is not None:
+            gen = itertools.islice(gen, self.max_num_frames)
+        yield from gen
+        if self.auto_close:
+            self._backend.close()
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "opencv":
+            return OpenCVVideoReaderBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+
+class _BaseReaderBackend(object):
+    """_BaseReaderBackend"""
+
+    def read_file(self, in_path):
+        """read file from path"""
+        raise NotImplementedError
+
+
+class _ImageReaderBackend(_BaseReaderBackend):
+    """_ImageReaderBackend"""
+
+    pass
+
+
+class OpenCVImageReaderBackend(_ImageReaderBackend):
+    """OpenCVImageReaderBackend"""
+
+    def __init__(self, flags=cv2.IMREAD_COLOR):
+        super().__init__()
+        self.flags = flags
+
+    def read_file(self, in_path):
+        """read image file from path by OpenCV"""
+        return cv2.imread(in_path, flags=self.flags)
+
+
+class PILImageReaderBackend(_ImageReaderBackend):
+    """PILImageReaderBackend"""
+
+    def __init__(self):
+        super().__init__()
+
+    def read_file(self, in_path):
+        """read image file from path by PIL"""
+        return ImageOps.exif_transpose(Image.open(in_path))
+
+
+class _VideoReaderBackend(_BaseReaderBackend):
+    """_VideoReaderBackend"""
+
+    def set_pos(self, pos):
+        """set pos"""
+        raise NotImplementedError
+
+    def close(self):
+        """close io"""
+        raise NotImplementedError
+
+
+class OpenCVVideoReaderBackend(_VideoReaderBackend):
+    """OpenCVVideoReaderBackend"""
+
+    def __init__(self, **bk_args):
+        super().__init__()
+        self.cap_init_args = bk_args
+        self._cap = None
+        self._pos = 0
+
+    def read_file(self, in_path):
+        """read vidio file from path"""
+        if self._cap is not None:
+            self._cap_release()
+        self._cap = self._cap_open(in_path)
+        if self._pos is not None:
+            self._cap_set_pos()
+        return self._read_frames(self._cap)
+
+    def _read_frames(self, cap):
+        """read frames"""
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            yield frame
+        self._cap_release()
+
+    def _cap_open(self, video_path):
+        self._cap = cv2.VideoCapture(video_path, **self.cap_init_args)
+        if not self._cap.isOpened():
+            raise RuntimeError(f"Failed to open {video_path}")
+        return self._cap
+
+    def _cap_release(self):
+        self._cap.release()
+
+    def _cap_set_pos(self):
+        self._cap.set(cv2.CAP_PROP_POS_FRAMES, self._pos)
+
+    def set_pos(self, pos):
+        self._pos = pos
+
+    def close(self):
+        if self._cap is not None:
+            self._cap_release()
+            self._cap = None
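
A usage sketch of the generative reader, assuming an OpenCV-readable file at demo.mp4:

    reader = VideoReader(st_frame_id=0, max_num_frames=16)
    for frame in reader.read("demo.mp4"):
        # Frames are BGR numpy arrays straight from cv2.VideoCapture.
        print(frame.shape)
    # With auto_close=True (the default) the capture is released once the
    # generator is exhausted.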

+ 226 - 0
paddlex/inference/utils/io/writers.py

@@ -0,0 +1,226 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import enum
+import json
+from pathlib import Path
+
+import cv2
+import numpy as np
+from PIL import Image
+
+__all__ = ["ImageWriter", "TextWriter", "JsonWriter", "WriterType"]
+
+
+class WriterType(enum.Enum):
+    """WriterType"""
+
+    IMAGE = 1
+    VIDEO = 2
+    TEXT = 3
+    JSON = 4
+
+
+class _BaseWriter(object):
+    """_BaseWriter"""
+
+    def __init__(self, backend, **bk_args):
+        super().__init__()
+        if len(bk_args) == 0:
+            bk_args = self.get_default_backend_args()
+        self.bk_type = backend
+        self.bk_args = bk_args
+        self._backend = self.get_backend()
+
+    def write(self, out_path, obj):
+        """write"""
+        raise NotImplementedError
+
+    def get_backend(self, bk_args=None):
+        """get backend"""
+        if bk_args is None:
+            bk_args = self.bk_args
+        return self._init_backend(self.bk_type, bk_args)
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        raise NotImplementedError
+
+    def get_type(self):
+        """get type"""
+        raise NotImplementedError
+
+    def get_default_backend_args(self):
+        """get default backend arguments"""
+        return {}
+
+
+class ImageWriter(_BaseWriter):
+    """ImageWriter"""
+
+    def __init__(self, backend="opencv", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj):
+        """write"""
+        return self._backend.write_obj(out_path, obj)
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "opencv":
+            return OpenCVImageWriterBackend(**bk_args)
+        elif bk_type == "pillow":
+            return PILImageWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.IMAGE
+
+
+class TextWriter(_BaseWriter):
+    """TextWriter"""
+
+    def __init__(self, backend="python", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj):
+        """write"""
+        return self._backend.write_obj(out_path, obj)
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "python":
+            return TextWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.TEXT
+
+
+class JsonWriter(_BaseWriter):
+    def __init__(self, backend="json", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj, **bk_args):
+        return self._backend.write_obj(out_path, obj, **bk_args)
+
+    def _init_backend(self, bk_type, bk_args):
+        if bk_type == "json":
+            return JsonWriterBackend(**bk_args)
+        elif bk_type == "ujson":
+            return UJsonWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.JSON
+
+
+class _BaseWriterBackend(object):
+    """_BaseWriterBackend"""
+
+    def write_obj(self, out_path, obj):
+        """write object"""
+        out_dir = os.path.dirname(out_path)
+        os.makedirs(out_dir, exist_ok=True)
+        return self._write_obj(out_path, obj)
+
+    def _write_obj(self, out_path, obj):
+        """write object"""
+        raise NotImplementedError
+
+
+class TextWriterBackend(_BaseWriterBackend):
+    """TextWriterBackend"""
+
+    def __init__(self, mode="w", encoding="utf-8"):
+        super().__init__()
+        self.mode = mode
+        self.encoding = encoding
+
+    def _write_obj(self, out_path, obj):
+        """write text object"""
+        with open(out_path, mode=self.mode, encoding=self.encoding) as f:
+            f.write(obj)
+
+
+class _ImageWriterBackend(_BaseWriterBackend):
+    """_ImageWriterBackend"""
+
+    pass
+
+
+class OpenCVImageWriterBackend(_ImageWriterBackend):
+    """OpenCVImageWriterBackend"""
+
+    def _write_obj(self, out_path, obj):
+        """write image object by OpenCV"""
+        if isinstance(obj, Image.Image):
+            arr = np.asarray(obj)
+        elif isinstance(obj, np.ndarray):
+            arr = obj
+        else:
+            raise TypeError("Unsupported object type")
+        return cv2.imwrite(out_path, arr)
+
+
+class PILImageWriterBackend(_ImageWriterBackend):
+    """PILImageWriterBackend"""
+
+    def __init__(self, format_=None):
+        super().__init__()
+        self.format = format_
+
+    def _write_obj(self, out_path, obj):
+        """write image object by PIL"""
+        if isinstance(obj, Image.Image):
+            img = obj
+        elif isinstance(obj, np.ndarray):
+            img = Image.fromarray(obj)
+        else:
+            raise TypeError("Unsupported object type")
+        return img.save(out_path, format=self.format)
+
+
+class _BaseJsonWriterBackend(object):
+    def __init__(self, indent=4, ensure_ascii=False):
+        super().__init__()
+        self.indent = indent
+        self.ensure_ascii = ensure_ascii
+
+    def write_obj(self, out_path, obj, **bk_args):
+        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+        return self._write_obj(out_path, obj, **bk_args)
+
+    def _write_obj(self, out_path, obj):
+        raise NotImplementedError
+
+
+class JsonWriterBackend(_BaseJsonWriterBackend):
+    def _write_obj(self, out_path, obj, **bk_args):
+        # Fall back to the defaults given at construction time.
+        bk_args.setdefault("indent", self.indent)
+        bk_args.setdefault("ensure_ascii", self.ensure_ascii)
+        with open(out_path, "w", encoding="utf-8") as f:
+            json.dump(obj, f, **bk_args)
+
+
+class UJsonWriterBackend(_BaseJsonWriterBackend):
+    # TODO: implement with ujson once it is added as a dependency.
+    def _write_obj(self, out_path, obj, **bk_args):
+        raise NotImplementedError("The ujson backend is not implemented yet.")
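
A combined usage sketch of the writers above; the output paths are hypothetical, and parent directories are created by the backends:

    import numpy as np

    img_writer = ImageWriter(backend="opencv")
    img_writer.write("output/blank.png", np.zeros((64, 64, 3), dtype=np.uint8))

    txt_writer = TextWriter()
    txt_writer.write("output/note.txt", "hello\n")

    json_writer = JsonWriter()
    # Per-call kwargs are forwarded to json.dump by the backend.
    json_writer.write("output/meta.json", {"ok": True}, indent=2)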