瀏覽代碼

[WIP] hpi upgrade (#2724)

* hpi upgrade

* adapter

* fix

* add select_backend strategy

* add instance_seg and semantic_seg

* add ml_classification

* add object_detection

* adapt model_new pytest

* add ts modules

* update ts modules pytest

* add text_detection and text_recognition, fix trt cache

* add face_rec pytest

* add table_recognition

* fix

* update model_info_collection
zhang-prog 10 月之前
父節點
當前提交
bc218c5ffb
共有 32 個文件被更改，包括 996 次插入和 365 次刪除
  1. 15 16
      libs/paddlex-hpi/src/paddlex_hpi/_config.py
  2. 91 0
      libs/paddlex-hpi/src/paddlex_hpi/_strategy.py
  3. 255 192
      libs/paddlex-hpi/src/paddlex_hpi/model_info_collection.json
  4. 64 25
      libs/paddlex-hpi/src/paddlex_hpi/models/base.py
  5. 23 8
      libs/paddlex-hpi/src/paddlex_hpi/models/image_classification.py
  6. 12 2
      libs/paddlex-hpi/src/paddlex_hpi/models/instance_segmentation.py
  7. 45 7
      libs/paddlex-hpi/src/paddlex_hpi/models/multilabel_classification.py
  8. 7 2
      libs/paddlex-hpi/src/paddlex_hpi/models/object_detection.py
  9. 26 3
      libs/paddlex-hpi/src/paddlex_hpi/models/semantic_segmentation.py
  10. 3 0
      libs/paddlex-hpi/src/paddlex_hpi/models/table_recognition.py
  11. 75 32
      libs/paddlex-hpi/src/paddlex_hpi/models/text_detection.py
  12. 27 20
      libs/paddlex-hpi/src/paddlex_hpi/models/ts_ad.py
  13. 25 18
      libs/paddlex-hpi/src/paddlex_hpi/models/ts_cls.py
  14. 27 20
      libs/paddlex-hpi/src/paddlex_hpi/models/ts_fc.py
  15. 68 14
      libs/paddlex-hpi/tests/models/base.py
  16. 2 0
      libs/paddlex-hpi/tests/models/test_anomaly_detection.py
  17. 51 0
      libs/paddlex-hpi/tests/models/test_face_recognition.py
  18. 2 0
      libs/paddlex-hpi/tests/models/test_formula_recognition.py
  19. 4 2
      libs/paddlex-hpi/tests/models/test_general_recognition.py
  20. 32 0
      libs/paddlex-hpi/tests/models/test_image_classification.py
  21. 2 0
      libs/paddlex-hpi/tests/models/test_image_unwarping.py
  22. 32 0
      libs/paddlex-hpi/tests/models/test_instance_segmentation.py
  23. 32 0
      libs/paddlex-hpi/tests/models/test_multilabel_classification.py
  24. 32 0
      libs/paddlex-hpi/tests/models/test_object_detection.py
  25. 30 0
      libs/paddlex-hpi/tests/models/test_semantic_segmentation.py
  26. 2 0
      libs/paddlex-hpi/tests/models/test_table_recognition.py
  27. 2 0
      libs/paddlex-hpi/tests/models/test_text_recognition.py
  28. 2 0
      libs/paddlex-hpi/tests/models/test_ts_ad.py
  29. 2 0
      libs/paddlex-hpi/tests/models/test_ts_cls.py
  30. 2 0
      libs/paddlex-hpi/tests/models/test_ts_fc.py
  31. 1 1
      paddlex/inference/models_new/base/predictor/base_predictor.py
  32. 3 3
      paddlex/inference/models_new/common/static_infer.py

+ 15 - 16
libs/paddlex-hpi/src/paddlex_hpi/_config.py

@@ -22,6 +22,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import Annotated, TypeAlias, TypedDict, assert_never
 
 from paddlex_hpi._model_info import get_model_info
+from paddlex_hpi._strategy import SelectSpecificStrategy, SelectFirstStrategy
 from paddlex_hpi._utils.typing import Backend, DeviceType
 
 
@@ -81,7 +82,9 @@ class TensorRTConfig(_BackendConfig):
 
     def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
         option.use_trt_backend()
-        option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
+        option.trt_option.serialize_file = str(
+            model_dir / f"trt_serialized_{self.precision}.trt"
+        )
         if self.precision == "FP16":
             option.trt_option.enable_fp16 = True
         if self.dynamic_shapes is not None:
@@ -152,7 +155,7 @@ class HPIConfig(BaseModel):
     ] = None
 
     def get_backend_and_config(
-        self, model_name: str, device_type: DeviceType
+        self, model_name: str, device_type: DeviceType, onnx_format: bool
     ) -> Tuple[Backend, BackendConfig]:
         # Do we need an extensible selector?
         model_info = get_model_info(model_name, device_type)
@@ -160,21 +163,17 @@ class HPIConfig(BaseModel):
             backend_config_pairs = model_info["backend_config_pairs"]
         else:
             backend_config_pairs = []
-        config_dict: Dict[str, Any] = {}
-        if self.selected_backends and device_type in self.selected_backends:
-            backend = self.selected_backends[device_type]
-            for pair in backend_config_pairs:
-                # Use the first one
-                if pair[0] == self.selected_backends[device_type]:
-                    config_dict.update(pair[1])
-                    break
+
+        use_specific_backend = (
+            self.selected_backends and device_type in self.selected_backends
+        )
+        if use_specific_backend:
+            specified_backend = self.selected_backends[device_type]
+            strategy = SelectSpecificStrategy(onnx_format, specified_backend)
         else:
-            if backend_config_pairs:
-                # Currently we select the first one
-                backend = backend_config_pairs[0][0]
-                config_dict.update(backend_config_pairs[0][1])
-            else:
-                backend = "paddle_infer"
+            strategy = SelectFirstStrategy(onnx_format)
+
+        backend, config_dict = strategy.select_backend_and_config(backend_config_pairs)
         if self.backend_configs and backend in self.backend_configs:
             config_dict.update(
                 self.backend_configs[backend].model_dump(exclude_unset=True)

+ 91 - 0
libs/paddlex-hpi/src/paddlex_hpi/_strategy.py

@@ -0,0 +1,91 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import List, Dict, Tuple, Union, Final
+
+
+class BackendSelectionStrategy(metaclass=abc.ABCMeta):
+    ONNX_ALLOWED_BACKENDS: Final[List[str]] = ["onnx_runtime", "tensorrt", "openvino"]
+    PADDLE_ALLOWED_BACKENDS: Final[List[str]] = ONNX_ALLOWED_BACKENDS + ["paddle_infer"]
+
+    onnx_format: bool
+
+    def backend_filter(self, backend_config_pairs: List[Tuple[str, Dict]]) -> List[str]:
+        allowed_backends = (
+            self.ONNX_ALLOWED_BACKENDS
+            if self.onnx_format
+            else self.PADDLE_ALLOWED_BACKENDS
+        )
+        filtered_backends = [
+            backend
+            for backend in backend_config_pairs
+            if backend[0] in allowed_backends
+        ]
+
+        return filtered_backends
+
+    @abc.abstractmethod
+    def select_backend_and_config(
+        self, backend_config_pairs: List[Tuple[str, Dict]]
+    ) -> Union[str, Dict]:
+        raise NotImplementedError
+
+
+class SelectSpecificStrategy(BackendSelectionStrategy):
+
+    def __init__(self, onnx_format: bool, specified_backend: str):
+        self.onnx_format = onnx_format
+        self.specified_backend = specified_backend
+
+    def select_backend_and_config(
+        self, backend_config_pairs: List[Tuple[str, Dict]]
+    ) -> Union[str, Dict]:
+        filtered_backends = self.backend_filter(backend_config_pairs)
+
+        for backend, config_dict in filtered_backends:
+            if backend == self.specified_backend:
+                return backend, config_dict
+
+        if self.onnx_format:
+            raise ValueError(
+                f"Unspported backend: {self.specified_backend}. Supported backends are: {', '.join(i[0] for i in filtered_backends)}"
+            )
+        else:
+            return "paddle_infer", {}
+
+
+class SelectFirstStrategy(BackendSelectionStrategy):
+
+    def __init__(self, onnx_format: bool):
+        self.onnx_format = onnx_format
+
+    def select_backend_and_config(
+        self, backend_config_pairs: List[Tuple[str, Dict]]
+    ) -> Union[str, Dict]:
+        filtered_backends = self.backend_filter(backend_config_pairs)
+
+        if filtered_backends:
+            backend, config_dict = filtered_backends[0]
+        else:
+            if self.onnx_format:
+                raise ValueError(
+                    "There is no supported backend for the ONNX model. Please use Paddle model instead."
+                )
+            else:
+                backend, config_dict = (
+                    filtered_backends[0] if filtered_backends else ["paddle_infer", {}]
+                )
+
+        return backend, config_dict

文件差異過大導致無法顯示
+ 255 - 192
libs/paddlex-hpi/src/paddlex_hpi/model_info_collection.json


+ 64 - 25
libs/paddlex-hpi/src/paddlex_hpi/models/base.py

@@ -27,7 +27,7 @@ from typing import (
 
 import ultra_infer as ui
 from ultra_infer.model import BaseUltraInferModel
-from paddlex.inference.common.reader import ReadImage
+from paddlex.inference.common.reader import ReadImage, ReadTS
 from paddlex.inference.models_new import BasePredictor
 from paddlex.inference.utils.new_ir_blacklist import NEWIR_BLOCKLIST
 from paddlex.utils import device as device_helper
@@ -55,37 +55,47 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
         model_dir: Union[str, PathLike],
         config: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
+        use_onnx_model: Optional[bool] = None,
         hpi_params: Optional[HPIParams] = None,
     ) -> None:
         super().__init__(model_dir=model_dir, config=config)
         self._device = device or device_helper.get_default_device()
+        self._onnx_format = use_onnx_model
+        self._check_and_choose_model_format()
         self._hpi_params = hpi_params or {}
         self._hpi_config = self._get_hpi_config()
         self._ui_model = self.build_ui_model()
         self._data_reader = self._build_data_reader()
 
-    def __call__(self, input: Any, **kwargs: dict[str, Any]) -> Iterator[Any]:
-        self.set_predictor(**kwargs)
-        yield from self.apply(input)
+    def __call__(
+        self,
+        input: Any,
+        batch_size: int = None,
+        device: str = None,
+        **kwargs: dict[str, Any],
+    ) -> Iterator[Any]:
+        self.set_predictor(batch_size, device)
+        yield from self.apply(input, **kwargs)
 
     @property
     def model_path(self) -> Path:
-        return self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdmodel"
+        if self._onnx_format:
+            return self.model_dir / f"{self.MODEL_FILE_PREFIX}.onnx"
+        else:
+            return self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdmodel"
 
     @property
-    def params_path(self) -> Path:
-        return self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdiparams"
-
-    def set_predictor(self, **kwargs: Any) -> None:
-        if "device" in kwargs:
-            device = kwargs.pop("device")
-            if device is not None:
-                if device != self._device:
-                    raise RuntimeError("Currently, changing devices is not supported.")
-        if "batch_size" in kwargs:
-            self.batch_sampler.batch_size = kwargs.pop("batch_size")
-        if kwargs:
-            raise TypeError(f"Unexpected arguments: {kwargs}")
+    def params_path(self) -> Union[Path, None]:
+        if self._onnx_format:
+            return None
+        else:
+            return self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdiparams"
+
+    def set_predictor(self, batch_size: int = None, device: str = None) -> None:
+        if device and device != self._device:
+            raise RuntimeError("Currently, changing devices is not supported.")
+        if batch_size:
+            self.batch_sampler.batch_size = batch_size
 
     def build_ui_model(self) -> BaseUltraInferModel:
         option = self._create_ui_option()
@@ -102,11 +112,6 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
         )
         return hpi_config
 
-    def _get_selected_backend(self) -> Backend:
-        device_type, _ = device_helper.parse_device(self._device)
-        backend = self._hpi_config.get_selected_backend(self.model_name, device_type)
-        return backend
-
     def _create_ui_option(self) -> ui.RuntimeOption:
         option = ui.RuntimeOption()
         # HACK: Disable new IR for models that are known to have issues with the
@@ -128,13 +133,47 @@ class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
         else:
             assert_never(device_type)
         backend, backend_config = self._hpi_config.get_backend_and_config(
-            model_name=self.model_name, device_type=device_type
+            model_name=self.model_name,
+            device_type=device_type,
+            onnx_format=self._onnx_format,
         )
         logging.info("Backend: %s", backend)
         logging.info("Backend config: %s", backend_config)
         backend_config.update_ui_option(option, self.model_dir)
         return option
 
+    def _check_and_choose_model_format(self) -> None:
+        has_onnx_model = any(self.model_dir.glob(f"{self.MODEL_FILE_PREFIX}.onnx"))
+        has_pd_model = any(self.model_dir.glob(f"{self.MODEL_FILE_PREFIX}.pdmodel"))
+        if self._onnx_format is None:
+            if has_onnx_model and has_pd_model:
+                logging.warning(
+                    "Both ONNX and Paddle models are detected, but no preference is set. Default model (.pdmodel) will be used."
+                )
+            elif has_pd_model:
+                logging.warning(
+                    "Only Paddle model is detected. Paddle model will be used by default."
+                )
+            elif has_onnx_model:
+                self._onnx_format = True
+                logging.warning(
+                    "Only ONNX model is detected. ONNX model will be used by default."
+                )
+            else:
+                raise RuntimeError(
+                    "No models are detected. Please ensure the model file exists."
+                )
+        elif self._onnx_format:
+            if not has_onnx_model:
+                raise RuntimeError(
+                    "ONNX model is specified but not detected. Please ensure the ONNX model file exists."
+                )
+        else:
+            if not has_pd_model:
+                raise RuntimeError(
+                    "Paddle model is specified but not detected. Please ensure the Paddle model file exists."
+                )
+
     @abc.abstractmethod
     def _build_ui_model(self, option: ui.RuntimeOption) -> BaseUltraInferModel:
         raise NotImplementedError
@@ -151,4 +190,4 @@ class CVPredictor(HPPredictor):
 
 class TSPredictor(HPPredictor):
     def _build_data_reader(self):
-        return None
+        return ReadTS()

+ 23 - 8
libs/paddlex-hpi/src/paddlex_hpi/models/image_classification.py

@@ -38,16 +38,19 @@ class ClasPredictor(CVPredictor):
         model_dir: Union[str, os.PathLike],
         config: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
+        use_onnx_model: Optional[bool] = None,
         hpi_params: Optional[HPIParams] = None,
+        topk: Union[int, None] = None,
     ) -> None:
         super().__init__(
             model_dir=model_dir,
             config=config,
             device=device,
+            use_onnx_model=use_onnx_model,
             hpi_params=hpi_params,
         )
         self._pp_params = self._get_pp_params()
-        self._ui_model.postprocessor.topk = self._pp_params.topk
+        self._topk = topk or self._pp_params.topk
 
     def _build_batch_sampler(self) -> ImageBatchSampler:
         return ImageBatchSampler()
@@ -58,17 +61,29 @@ class ClasPredictor(CVPredictor):
     def _build_ui_model(
         self, option: ui.RuntimeOption
     ) -> ui.vision.classification.PaddleClasModel:
-        model = ui.vision.classification.PaddleClasModel(
-            str(self.model_path),
-            str(self.params_path),
-            str(self.config_path),
-            runtime_option=option,
-        )
+        if self._onnx_format:
+            model = ui.vision.classification.PaddleClasModel(
+                str(self.model_path),
+                str(self.params_path),
+                str(self.config_path),
+                runtime_option=option,
+                model_format=ui.ModelFormat.ONNX,
+            )
+        else:
+            model = ui.vision.classification.PaddleClasModel(
+                str(self.model_path),
+                str(self.params_path),
+                str(self.config_path),
+                runtime_option=option,
+            )
         return model
 
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+    def process(
+        self, batch_data: List[Any], topk: Union[int, None] = None
+    ) -> Dict[str, List[Any]]:
         batch_raw_imgs = self._data_reader(imgs=batch_data)
         imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
+        self._ui_model.postprocessor.topk = topk or self._topk
         ui_results = self._ui_model.batch_predict(imgs)
 
         class_ids_list = []

+ 12 - 2
libs/paddlex-hpi/src/paddlex_hpi/models/instance_segmentation.py

@@ -39,6 +39,7 @@ class InstanceSegPredictor(CVPredictor):
         config: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
         hpi_params: Optional[HPIParams] = None,
+        threshold: Optional[float] = None,
     ) -> None:
         super().__init__(
             model_dir=model_dir,
@@ -46,7 +47,10 @@ class InstanceSegPredictor(CVPredictor):
             device=device,
             hpi_params=hpi_params,
         )
+        if threshold and self.model_name == "SOLOv2":
+            raise TypeError("SOLOv2 does not support `threshold` in PaddleX HPI.")
         self._pp_params = self._get_pp_params()
+        self._threshold = threshold or self._pp_params.threshold
 
     def _build_ui_model(
         self, option: ui.RuntimeOption
@@ -65,9 +69,15 @@ class InstanceSegPredictor(CVPredictor):
     def _get_result_class(self) -> type:
         return InstanceSegResult
 
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+    def process(
+        self, batch_data: List[Any], threshold: Optional[float] = None
+    ) -> Dict[str, List[Any]]:
+        if threshold and self.model_name == "SOLOv2":
+            raise TypeError("SOLOv2 does not support `threshold` in PaddleX HPI.")
+
         batch_raw_imgs = self._data_reader(imgs=batch_data)
         imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
+        threshold = threshold or self._threshold
         ui_results = self._ui_model.batch_predict(imgs)
 
         boxes_list = []
@@ -78,7 +88,7 @@ class InstanceSegPredictor(CVPredictor):
                 key=ui_result.scores.__getitem__,
                 reverse=True,
             )
-            inds = [i for i in inds if ui_result.scores[i] > self._pp_params.threshold]
+            inds = [i for i in inds if ui_result.scores[i] > threshold]
             inds = [i for i in inds if ui_result.label_ids[i] > -1]
             ids = [ui_result.label_ids[i] for i in inds]
             scores = [ui_result.scores[i] for i in inds]

+ 45 - 7
libs/paddlex-hpi/src/paddlex_hpi/models/multilabel_classification.py

@@ -17,6 +17,9 @@ from typing import Any, Dict, List, Optional, Union
 
 import ultra_infer as ui
 import numpy as np
+from pathlib import Path
+import tempfile
+import yaml
 from paddlex.inference.common.batch_sampler import ImageBatchSampler
 from paddlex.inference.results import MLClassResult
 from paddlex.modules.multilabel_classification.model_list import MODELS
@@ -33,7 +36,9 @@ class MLClasPredictor(CVPredictor):
         config: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
         hpi_params: Optional[HPIParams] = None,
+        threshold: Union[float, dict, list, None] = None,
     ) -> None:
+        self._threshold = threshold
         super().__init__(
             model_dir=model_dir,
             config=config,
@@ -45,12 +50,36 @@ class MLClasPredictor(CVPredictor):
     def _build_ui_model(
         self, option: ui.RuntimeOption
     ) -> ui.vision.classification.PyOnlyMultilabelClassificationModel:
-        model = ui.vision.classification.PyOnlyMultilabelClassificationModel(
-            str(self.model_path),
-            str(self.params_path),
-            str(self.config_path),
-            runtime_option=option,
-        )
+        if self._threshold:
+            if isinstance(self._threshold, (dict, list)):
+                raise TypeError("`threshold` must be float or None in PaddleX HPI")
+
+            with open(self.config_path, "r") as file:
+                config = yaml.safe_load(file)
+
+            config["PostProcess"]["MultiLabelThreshOutput"][
+                "threshold"
+            ] = self._threshold
+
+            temp_dir = os.path.dirname(self.config_path)
+            with tempfile.NamedTemporaryFile(
+                delete=False, dir=temp_dir, suffix=".yml", mode="w", encoding="utf-8"
+            ) as temp_file:
+                temp_file_path = temp_file.name
+                yaml.safe_dump(config, temp_file, default_flow_style=False)
+                model = ui.vision.classification.PyOnlyMultilabelClassificationModel(
+                    str(self.model_path),
+                    str(self.params_path),
+                    str(Path(temp_file_path)),
+                    runtime_option=option,
+                )
+        else:
+            model = ui.vision.classification.PyOnlyMultilabelClassificationModel(
+                str(self.model_path),
+                str(self.params_path),
+                str(self.config_path),
+                runtime_option=option,
+            )
         return model
 
     def _build_batch_sampler(self) -> ImageBatchSampler:
@@ -59,7 +88,16 @@ class MLClasPredictor(CVPredictor):
     def _get_result_class(self) -> type:
         return MLClassResult
 
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+    def process(
+        self,
+        batch_data: List[Any],
+        threshold: Union[float, dict, list, None] = None,
+    ) -> Dict[str, List[Any]]:
+        if threshold:
+            raise TypeError(
+                "`threshold` is not supported for multilabel classification in PaddleX HPI"
+            )
+
         batch_raw_imgs = self._data_reader(imgs=batch_data)
         imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
         ui_results = self._ui_model.batch_predict(imgs)

+ 7 - 2
libs/paddlex-hpi/src/paddlex_hpi/models/object_detection.py

@@ -39,6 +39,7 @@ class DetPredictor(CVPredictor):
         config: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
         hpi_params: Optional[HPIParams] = None,
+        threshold: Optional[float] = None,
     ) -> None:
         super().__init__(
             model_dir=model_dir,
@@ -47,6 +48,7 @@ class DetPredictor(CVPredictor):
             hpi_params=hpi_params,
         )
         self._pp_params = self._get_pp_params()
+        self._threshold = threshold or self._pp_params.threshold
 
     def _build_ui_model(
         self, option: ui.RuntimeOption
@@ -65,9 +67,12 @@ class DetPredictor(CVPredictor):
     def _get_result_class(self) -> type:
         return DetResult
 
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+    def process(
+        self, batch_data: List[Any], threshold: Optional[float] = None
+    ) -> Dict[str, List[Any]]:
         batch_raw_imgs = self._data_reader(imgs=batch_data)
         imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
+        threshold = threshold or self._threshold
         ui_results = self._ui_model.batch_predict(imgs)
 
         boxes_list = []
@@ -77,7 +82,7 @@ class DetPredictor(CVPredictor):
                 key=ui_result.scores.__getitem__,
                 reverse=True,
             )
-            inds = [i for i in inds if ui_result.scores[i] > self._pp_params.threshold]
+            inds = [i for i in inds if ui_result.scores[i] > threshold]
             inds = [i for i in inds if ui_result.label_ids[i] > -1]
             ids = [ui_result.label_ids[i] for i in inds]
             scores = [ui_result.scores[i] for i in inds]

+ 26 - 3
libs/paddlex-hpi/src/paddlex_hpi/models/semantic_segmentation.py

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import ultra_infer as ui
 import numpy as np
@@ -20,12 +21,29 @@ from paddlex.inference.common.batch_sampler import ImageBatchSampler
 from paddlex.inference.results import SegResult
 from paddlex.modules.semantic_segmentation.model_list import MODELS
 
-from paddlex_hpi.models.base import CVPredictor
+from paddlex_hpi.models.base import CVPredictor, HPIParams
 
 
 class SegPredictor(CVPredictor):
     entities = MODELS
 
+    def __init__(
+        self,
+        model_dir: Union[str, os.PathLike],
+        config: Optional[Dict[str, Any]] = None,
+        device: Optional[str] = None,
+        hpi_params: Optional[HPIParams] = None,
+        target_size: Union[int, Tuple[int], None] = None,
+    ) -> None:
+        if target_size:
+            raise TypeError("`target_size` is not supported in PaddleX HPI.")
+        super().__init__(
+            model_dir=model_dir,
+            config=config,
+            device=device,
+            hpi_params=hpi_params,
+        )
+
     def _build_ui_model(
         self, option: ui.RuntimeOption
     ) -> ui.vision.segmentation.PaddleSegModel:
@@ -43,7 +61,12 @@ class SegPredictor(CVPredictor):
     def _get_result_class(self) -> type:
         return SegResult
 
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+    def process(
+        self, batch_data: List[Any], target_size: Union[int, Tuple[int], None] = None
+    ) -> Dict[str, List[Any]]:
+        if target_size:
+            raise TypeError("`target_size` is not supported in PaddleX HPI.")
+
         batch_raw_imgs = self._data_reader(imgs=batch_data)
         imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
         ui_results = self._ui_model.batch_predict(imgs)

+ 3 - 0
libs/paddlex-hpi/src/paddlex_hpi/models/table_recognition.py

@@ -63,13 +63,16 @@ class TablePredictor(CVPredictor):
 
         bbox_list = []
         structure_list = []
+        structure_score_list = []
         for ui_result in ui_results:
             bbox_list.append(ui_result.table_boxes)
             structure_list.append(ui_result.table_structure)
+            structure_score_list.append(0.0)
 
         return {
             "input_path": batch_data,
             "input_img": batch_raw_imgs,
             "bbox": bbox_list,
             "structure": structure_list,
+            "structure_score": structure_score_list,
         }

+ 75 - 32
libs/paddlex-hpi/src/paddlex_hpi/models/text_detection.py

@@ -34,13 +34,33 @@ class TextDetPredictor(CVPredictor):
         config: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
         hpi_params: Optional[HPIParams] = None,
+        limit_side_len: Union[int, None] = None,
+        limit_type: Union[str, None] = None,
+        thresh: Union[float, None] = None,
+        box_thresh: Union[float, None] = None,
+        max_candidates: Union[int, None] = None,
+        unclip_ratio: Union[float, None] = None,
+        use_dilation: Union[bool, None] = None,
     ) -> None:
+        if limit_type is not None:
+            raise TypeError(
+                "The default value for `limit_type` is max, and cannot be set in PaddleX HPI."
+            )
+        if max_candidates is not None:
+            raise TypeError(
+                "The default value for `max_candidates` is 1000, and cannot be set in PaddleX HPI."
+            )
         super().__init__(
             model_dir=model_dir,
             config=config,
             device=device,
             hpi_params=hpi_params,
         )
+        self._limit_side_len = limit_side_len or self._max_side_len
+        self._thresh = thresh or self._changeable_params["thresh"]
+        self._box_thresh = box_thresh or self._changeable_params["thresh"]
+        self._unclip_ratio = unclip_ratio or self._changeable_params["unclip_ratio"]
+        self._use_dilation = use_dilation or self._changeable_params["use_dilation"]
 
     def _build_batch_sampler(self) -> ImageBatchSampler:
         return ImageBatchSampler()
@@ -68,11 +88,47 @@ class TextDetPredictor(CVPredictor):
                 str(self.params_path),
                 runtime_option=option,
             )
-        self._config_ui_preprocessor(model)
-        self._config_ui_postprocessor(model)
+        self._config_ui_preprocessor()
+        self._config_ui_postprocessor()
         return model
 
-    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
+    def process(
+        self,
+        batch_data: List[Any],
+        limit_side_len: Union[int, None] = None,
+        limit_type: Union[str, None] = None,
+        thresh: Union[float, None] = None,
+        box_thresh: Union[float, None] = None,
+        max_candidates: Union[int, None] = None,
+        unclip_ratio: Union[float, None] = None,
+        use_dilation: Union[bool, None] = None,
+    ) -> Dict[str, List[Any]]:
+        if limit_type is not None:
+            raise TypeError(
+                "The default value for `limit_type` is max, and cannot be set in PaddleX HPI."
+            )
+        if max_candidates is not None:
+            raise TypeError(
+                "The default value for `max_candidates` is 1000, and cannot be set in PaddleX HPI."
+            )
+        self._ui_model.preprocessor.set_normalize(self._mean, self._std, True)
+        self._ui_model.preprocessor.max_side_len = (
+            limit_side_len or self._limit_side_len
+        )
+        postprocessor = self._ui_model.postprocessor
+        postprocessor.det_db_thresh = thresh or self._thresh
+        postprocessor.det_db_box_thresh = box_thresh or self._box_thresh
+        postprocessor.det_db_unclip_ratio = unclip_ratio or self._unclip_ratio
+        postprocessor.use_dilation = use_dilation or self._use_dilation
+        postprocessor.det_db_score_mode = self._changeable_params["score_mode"]
+        if self._is_curve_model:
+            if self._changeable_params["box_type"] not in ("quad", "poly"):
+                raise RuntimeError("Invalid value of `DBPostProcess.box_type`.")
+            if self._changeable_params["box_type"] == "quad":
+                postprocessor.det_db_box_type = "bbox"
+            else:
+                postprocessor.det_db_box_type = "poly"
+
         batch_raw_imgs = self._data_reader(imgs=batch_data)
         imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
         ui_results = self._ui_model.batch_predict(imgs)
@@ -86,7 +142,7 @@ class TextDetPredictor(CVPredictor):
             # temporarily use dummy scores here.
             dummy_scores = [0.0 for _ in ui_result.boxes]
             dt_scores_list.append(dummy_scores)
-
+        # breakpoint()
         return {
             "input_path": batch_data,
             "input_img": batch_raw_imgs,
@@ -94,9 +150,8 @@ class TextDetPredictor(CVPredictor):
             "dt_scores": dt_scores_list,
         }
 
-    def _config_ui_preprocessor(self, model: ui.vision.ocr.DBDetector) -> None:
+    def _config_ui_preprocessor(self) -> None:
         pp_config = self.config["PreProcess"]
-        preprocessor = model.preprocessor
         for item in pp_config["transform_ops"]:
             op_name = next(iter(item))
             op_config = item[op_name]
@@ -108,7 +163,7 @@ class TextDetPredictor(CVPredictor):
                         "`DecodeImage.channel_first` must be set to False."
                     )
             elif op_name == "DetResizeForTest":
-                preprocessor.max_side_len = op_config.get("resize_long", 960)
+                self._max_side_len = op_config.get("resize_long", 960)
             elif op_name == "NormalizeImage":
                 if "scale" in op_config and not (
                     abs(parse_scale(op_config["scale"]) - 1 / 255) < 1e-9
@@ -116,11 +171,10 @@ class TextDetPredictor(CVPredictor):
                     raise RuntimeError("`NormalizeImage.scale` must be set to 1/255.")
                 if "channel_num" in op_config and op_config["channel_num"] != 3:
                     raise RuntimeError("`NormalizeImage.channel_num` must be set to 3.")
-                preprocessor.set_normalize(
-                    op_config.get("mean", [0.485, 0.456, 0.406]),
-                    op_config.get("std", [0.229, 0.224, 0.225]),
-                    True,
-                )
+
+                self._mean = op_config.get("mean", [0.485, 0.456, 0.406])
+                self._std = op_config.get("std", [0.229, 0.224, 0.225])
+
             elif op_name == "ToCHWImage":
                 # Do nothing
                 pass
@@ -131,44 +185,33 @@ class TextDetPredictor(CVPredictor):
             else:
                 raise RuntimeError(f"Unkown preprocessing operator: {op_name}")
 
-    def _config_ui_postprocessor(self, model: ui.vision.ocr.DBDetector) -> None:
+    def _config_ui_postprocessor(self) -> None:
         pp_config = self.config["PostProcess"]
         # XXX: Default values copied from
         # `paddlex.inference.models.TextDetPredictor`
-        changeable_params: Dict[str, Any] = {
+        self._changeable_params: Dict[str, Any] = {
             "thresh": 0.3,
             "box_thresh": 0.7,
             "unclip_ratio": 2.0,
             "score_mode": "fast",
             "use_dilation": False,
         }
-        unchangeable_params: Dict[str, Any] = {
+        self._unchangeable_params: Dict[str, Any] = {
             "max_candidates": 1000,
             "box_type": "quad",
         }
         if self._is_curve_model:
-            changeable_params["box_type"] = unchangeable_params.pop("box_type")
+            self._changeable_params["box_type"] = self._unchangeable_params.pop(
+                "box_type"
+            )
         if "name" in pp_config and pp_config["name"] == "DBPostProcess":
-            for name in changeable_params:
+            for name in self._changeable_params:
                 if name in pp_config:
-                    changeable_params[name] = pp_config[name]
-            for name, val in unchangeable_params.items():
+                    self._changeable_params[name] = pp_config[name]
+            for name, val in self._unchangeable_params.items():
                 if name in pp_config and pp_config[name] != val:
                     raise RuntimeError(
                         f"`DBPostProcess.{name}` must be set to {repr(val)}."
                     )
         else:
             raise RuntimeError("Invalid config")
-        postprocessor = model.postprocessor
-        postprocessor.det_db_thresh = changeable_params["thresh"]
-        postprocessor.det_db_box_thresh = changeable_params["box_thresh"]
-        postprocessor.det_db_unclip_ratio = changeable_params["unclip_ratio"]
-        postprocessor.use_dilation = changeable_params["use_dilation"]
-        postprocessor.det_db_score_mode = changeable_params["score_mode"]
-        if self._is_curve_model:
-            if changeable_params["box_type"] not in ("quad", "poly"):
-                raise RuntimeError("Invalid value of `DBPostProcess.box_type`.")
-            if changeable_params["box_type"] == "quad":
-                postprocessor.det_db_box_type = "bbox"
-            else:
-                postprocessor.det_db_box_type = "poly"

+ 27 - 20
libs/paddlex-hpi/src/paddlex_hpi/models/ts_ad.py

@@ -12,20 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List
+from typing import Any, Dict, List, Union
 
 import ultra_infer as ui
 import pandas as pd
+from paddlex.inference.common.batch_sampler import TSBatchSampler
 from paddlex.inference.results import TSAdResult
 from paddlex.modules.ts_anomaly_detection.model_list import MODELS
 
-from paddlex_hpi._utils.typing import BatchData, Data
 from paddlex_hpi.models.base import TSPredictor
 
 
 class TSAdPredictor(TSPredictor):
     entities = MODELS
 
+    def _build_batch_sampler(self) -> TSBatchSampler:
+        return TSBatchSampler()
+
+    def _get_result_class(self) -> type:
+        return TSAdResult
+
     def _build_ui_model(
         self, option: ui.RuntimeOption
     ) -> ui.ts.anomalydetection.PyOnlyAnomalyDetectionModel:
@@ -37,22 +43,23 @@ class TSAdPredictor(TSPredictor):
         )
         return model
 
-    def _predict(self, batch_data: BatchData) -> BatchData:
-        ts_data = [data["ts"] for data in batch_data]
-        ui_results = self._ui_model.batch_predict(ts_data)
-        results: BatchData = []
-        for data, ui_result in zip(batch_data, ui_results):
-            ts_ad_result = self._create_ts_ad_result(data, ui_result)
-            results.append({"result": ts_ad_result})
-        return results
-
-    def _create_ts_ad_result(self, data: Data, ui_result: Any) -> TSAdResult:
-        data_dict = {
-            ui_result.col_names[i]: ui_result.data[i]
-            for i in range(len(ui_result.col_names))
+    def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
+        batch_raw_ts = self._data_reader(ts_list=batch_data)
+        ui_results = self._ui_model.batch_predict(batch_raw_ts)
+
+        anomaly_list = []
+        for ui_result in ui_results:
+            data_dict = {
+                ui_result.col_names[i]: ui_result.data[i]
+                for i in range(len(ui_result.col_names))
+            }
+            anomaly = pd.DataFrame.from_dict(data_dict)
+            anomaly.index = ui_result.dates
+            anomaly.index.name = "timestamp"
+            anomaly_list.append(anomaly)
+
+        return {
+            "input_path": batch_data,
+            "input_ts": batch_raw_ts,
+            "anomaly": anomaly_list,
         }
-        anomaly = pd.DataFrame.from_dict(data_dict)
-        anomaly.index = ui_result.dates
-        anomaly.index.name = "timestamp"
-        dic = {"input_path": data["input_path"], "anomaly": anomaly}
-        return TSAdResult(dic)

+ 25 - 18
libs/paddlex-hpi/src/paddlex_hpi/models/ts_cls.py

@@ -12,20 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List
+from typing import Any, Dict, List, Union
 
 import ultra_infer as ui
 import pandas as pd
+from paddlex.inference.common.batch_sampler import TSBatchSampler
 from paddlex.inference.results import TSClsResult
 from paddlex.modules.ts_classification.model_list import MODELS
 
-from paddlex_hpi._utils.typing import BatchData, Data
 from paddlex_hpi.models.base import TSPredictor
 
 
 class TSClsPredictor(TSPredictor):
     entities = MODELS
 
+    def _build_batch_sampler(self) -> TSBatchSampler:
+        return TSBatchSampler()
+
+    def _get_result_class(self) -> type:
+        return TSClsResult
+
     def _build_ui_model(
         self, option: ui.RuntimeOption
     ) -> ui.ts.classification.PyOnlyClassificationModel:
@@ -37,19 +43,20 @@ class TSClsPredictor(TSPredictor):
         )
         return model
 
-    def _predict(self, batch_data: BatchData) -> BatchData:
-        ts_data = [data["ts"] for data in batch_data]
-        ui_results = self._ui_model.batch_predict(ts_data)
-        results: BatchData = []
-        for data, ui_result in zip(batch_data, ui_results):
-            ts_cls_result = self._create_ts_cls_result(data, ui_result)
-            results.append({"result": ts_cls_result})
-        return results
-
-    def _create_ts_cls_result(self, data: Data, ui_result: Any) -> TSClsResult:
-        classification = pd.DataFrame.from_dict(
-            {"classid": [ui_result.class_id], "score": [ui_result.score]}
-        )
-        classification.index.name = "sample"
-        dic = {"input_path": data["input_path"], "classification": classification}
-        return TSClsResult(dic)
+    def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
+        batch_raw_ts = self._data_reader(ts_list=batch_data)
+        ui_results = self._ui_model.batch_predict(batch_raw_ts)
+
+        classification_list = []
+        for ui_result in ui_results:
+            classification = pd.DataFrame.from_dict(
+                {"classid": [ui_result.class_id], "score": [ui_result.score]}
+            )
+            classification.index.name = "sample"
+            classification_list.append(classification)
+
+        return {
+            "input_path": batch_data,
+            "input_ts": batch_raw_ts,
+            "classification": classification_list,
+        }

+ 27 - 20
libs/paddlex-hpi/src/paddlex_hpi/models/ts_fc.py

@@ -12,20 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List
+from typing import Any, Dict, List, Union
 
 import ultra_infer as ui
 import pandas as pd
+from paddlex.inference.common.batch_sampler import TSBatchSampler
 from paddlex.inference.results import TSFcResult
 from paddlex.modules.ts_forecast.model_list import MODELS
 
-from paddlex_hpi._utils.typing import BatchData, Data
 from paddlex_hpi.models.base import TSPredictor
 
 
 class TSFcPredictor(TSPredictor):
     entities = MODELS
 
+    def _build_batch_sampler(self) -> TSBatchSampler:
+        return TSBatchSampler()
+
+    def _get_result_class(self) -> type:
+        return TSFcResult
+
     def _build_ui_model(
         self, option: ui.RuntimeOption
     ) -> ui.ts.forecasting.PyOnlyForecastingModel:
@@ -37,22 +43,23 @@ class TSFcPredictor(TSPredictor):
         )
         return model
 
-    def _predict(self, batch_data: BatchData) -> BatchData:
-        ts_data = [data["ts"] for data in batch_data]
-        ui_results = self._ui_model.batch_predict(ts_data)
-        results: BatchData = []
-        for data, ui_result in zip(batch_data, ui_results):
-            ts_fc_result = self._create_ts_fc_result(data, ui_result)
-            results.append({"result": ts_fc_result})
-        return results
-
-    def _create_ts_fc_result(self, data: Data, ui_result: Any) -> TSFcResult:
-        data_dict = {
-            ui_result.col_names[i]: ui_result.data[i]
-            for i in range(len(ui_result.col_names))
+    def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
+        batch_raw_ts = self._data_reader(ts_list=batch_data)
+        ui_results = self._ui_model.batch_predict(batch_raw_ts)
+
+        forecast_list = []
+        for ui_result in ui_results:
+            data_dict = {
+                ui_result.col_names[i]: ui_result.data[i]
+                for i in range(len(ui_result.col_names))
+            }
+            forecast = pd.DataFrame.from_dict(data_dict)
+            forecast.index = ui_result.dates
+            forecast.index.name = "date"
+            forecast_list.append(forecast)
+
+        return {
+            "input_path": batch_data,
+            "input_ts": batch_raw_ts,
+            "forecast": forecast_list,
         }
-        forecast = pd.DataFrame.from_dict(data_dict)
-        forecast.index = ui_result.dates
-        forecast.index.name = "date"
-        dic = {"input_path": data["input_path"], "forecast": forecast}
-        return TSFcResult(dic)

+ 68 - 14
libs/paddlex-hpi/tests/models/base.py

@@ -45,9 +45,17 @@ class BaseTestPredictor(object):
         raise NotImplementedError
 
     @property
+    def expected_result_with_args_url(self):
+        raise NotImplementedError
+
+    @property
     def predictor_cls(self):
         raise NotImplementedError
 
+    @property
+    def should_test_with_args(self):
+        return False
+
     @pytest.fixture(scope="class")
     def data_dir(self):
         with tempfile.TemporaryDirectory() as td:
@@ -67,17 +75,6 @@ class BaseTestPredictor(object):
         yield input_data_path
 
     @pytest.fixture(scope="class")
-    def input_data_dir(self, data_dir, input_data_path):
-        input_data_dir = data_dir / "input_data"
-        input_data_dir.mkdir()
-        for i in range(NUM_INPUT_FILES):
-            shutil.copy(
-                input_data_path,
-                (input_data_dir / f"test_{i}").with_suffix(input_data_path.suffix),
-            )
-        yield input_data_dir
-
-    @pytest.fixture(scope="class")
     def expected_result(self, data_dir):
         expected_result_path = data_dir / "expected.json"
         download(self.expected_result_url, expected_result_path)
@@ -85,6 +82,14 @@ class BaseTestPredictor(object):
             expected_result = json.load(f)
         yield expected_result
 
+    @pytest.fixture(scope="class")
+    def expected_result_with_args(self, data_dir):
+        expected_result_with_args_path = data_dir / "expected_with_args.json"
+        download(self.expected_result_with_args_url, expected_result_with_args_path)
+        with open(expected_result_with_args_path, "r", encoding="utf-8") as f:
+            expected_result = json.load(f)
+        yield expected_result
+
     @pytest.mark.parametrize("device", DEVICES)
     def test___call__single_input_data(
         self, model_path, input_data_path, device, expected_result
@@ -97,14 +102,48 @@ class BaseTestPredictor(object):
 
     @pytest.mark.parametrize("device", DEVICES)
     @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-    def test___call__input_data_dir(
-        self, model_path, input_data_dir, device, batch_size, expected_result
+    def test___call__input_batch_data(
+        self, model_path, input_data_path, device, batch_size, expected_result
     ):
         predictor = self.predictor_cls(model_path, device=device)
         predictor.set_predictor(batch_size=batch_size)
-        output = predictor(str(input_data_dir))
+        output = predictor([str(input_data_path)] * NUM_INPUT_FILES)
         self._check_output(output, expected_result, NUM_INPUT_FILES)
 
+    @pytest.mark.parametrize("device", DEVICES)
+    def test__call__with_predictor_args(
+        self, model_path, input_data_path, device, request
+    ):
+        if self.should_test_with_args:
+            self._predict_with_predictor_args(
+                model_path,
+                input_data_path,
+                device,
+                request.getfixturevalue("expected_result_with_args"),
+            )
+        else:
+            pytest.skip("Skipping test__call__with_predictor_args for this predictor")
+
+    @pytest.mark.parametrize("device", DEVICES)
+    def test__call__with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        request,
+    ):
+        if self.should_test_with_args:
+            self._predict_with_predict_args(
+                model_path,
+                input_data_path,
+                device,
+                expected_result,
+                request.getfixturevalue("expected_result_with_args"),
+            )
+        else:
+            pytest.skip("Skipping test__call__with_predict_args for this predictor")
+
     def _check_output(self, output, expected_result, expected_num_results):
         assert isinstance(output, GeneratorType)
         # Note that this exhausts the generator
@@ -115,3 +154,18 @@ class BaseTestPredictor(object):
 
     def _check_result(self, result, expected_result):
         raise NotImplementedError
+
+    def _predict_with_predictor_args(
+        self, model_path, input_data_path, device, expected_result_with_args
+    ):
+        raise NotImplementedError
+
+    def _predict_with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        expected_result_with_args,
+    ):
+        raise NotImplementedError

+ 2 - 0
libs/paddlex-hpi/tests/models/test_anomaly_detection.py

@@ -42,6 +42,8 @@ class TestUadPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, SegResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         pred = result["pred"]
         expected_pred = np.array(expected_result["pred"], dtype=np.int32)

+ 51 - 0
libs/paddlex-hpi/tests/models/test_face_recognition.py

@@ -0,0 +1,51 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex.inference.results import BaseResult
+from tests.models.base import BaseTestPredictor
+
+from paddlex_hpi.models import FaceRecPredictor
+
+MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/face_rec_model.zip"
+INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/face_rec_input.jpg"
+EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/face_rec_result.json"
+
+
+class TestFaceRecPredictor(BaseTestPredictor):
+    @property
+    def model_url(self):
+        return MODEL_URL
+
+    @property
+    def input_data_url(self):
+        return INPUT_DATA_URL
+
+    @property
+    def expected_result_url(self):
+        return EXPECTED_RESULT_URL
+
+    @property
+    def predictor_cls(self):
+        return FaceRecPredictor
+
+    def _check_result(self, result, expected_result):
+        assert isinstance(result, BaseResult)
+        assert "input_img" in result
+        result.pop("input_img")
+        assert set(result) == set(expected_result)
+        expected_result = expected_result["feature"]
+        result = result["feature"].tolist()
+        assert sum([abs(x - y) for x, y in zip(result, expected_result)]) < 0.001 * len(
+            result
+        )

+ 2 - 0
libs/paddlex-hpi/tests/models/test_formula_recognition.py

@@ -41,5 +41,7 @@ class TestLaTeXOCRPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, FormulaRecResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         assert result["rec_text"] == expected_result["rec_text"]

+ 4 - 2
libs/paddlex-hpi/tests/models/test_general_recognition.py

@@ -41,9 +41,11 @@ class TestShiTuRecPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, BaseResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
-        expected_result = expected_result["rec_feature"]
-        result = result["rec_feature"].tolist()
+        expected_result = expected_result["feature"]
+        result = result["feature"].tolist()
         assert sum([abs(x - y) for x, y in zip(result, expected_result)]) < 0.001 * len(
             result
         )

+ 32 - 0
libs/paddlex-hpi/tests/models/test_image_classification.py

@@ -21,6 +21,7 @@ from paddlex_hpi.models import ClasPredictor
 MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_model.zip"
 INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_input.jpg"
 EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_result.json"
+EXPECTED_RESULT_WITH_ARGS_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_result_with_args.json"
 
 
 class TestClasPredictor(BaseTestPredictor):
@@ -37,11 +38,42 @@ class TestClasPredictor(BaseTestPredictor):
         return EXPECTED_RESULT_URL
 
     @property
+    def expected_result_with_args_url(self):
+        return EXPECTED_RESULT_WITH_ARGS_URL
+
+    @property
+    def should_test_with_args(self):
+        return True
+
+    @property
     def predictor_cls(self):
         return ClasPredictor
 
+    def _predict_with_predictor_args(
+        self, model_path, input_data_path, device, expected_result_with_args
+    ):
+        predictor = self.predictor_cls(model_path, device=device, topk=2)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result_with_args, 1)
+
+    def _predict_with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        expected_result_with_args,
+    ):
+        predictor = self.predictor_cls(model_path, device=device)
+        output = predictor(str(input_data_path), topk=2)
+        self._check_output(output, expected_result_with_args, 1)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result, 1)
+
     def _check_result(self, result, expected_result):
         assert isinstance(result, TopkResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         assert result["class_ids"] == expected_result["class_ids"]
         assert np.allclose(

+ 2 - 0
libs/paddlex-hpi/tests/models/test_image_unwarping.py

@@ -42,6 +42,8 @@ class TestWarpPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, DocTrResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         assert np.allclose(
             result["doctr_img"],

+ 32 - 0
libs/paddlex-hpi/tests/models/test_instance_segmentation.py

@@ -21,6 +21,7 @@ from paddlex_hpi.models import InstanceSegPredictor
 MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/instance_seg_model.zip"
 INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/instance_seg_input.jpg"
 EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/instance_seg_result.json"
+EXPECTED_RESULT_WITH_ARGS_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/instance_seg_result_with_args.json"
 
 
 class TestInstanceSegPredictor(BaseTestPredictor):
@@ -37,11 +38,42 @@ class TestInstanceSegPredictor(BaseTestPredictor):
         return EXPECTED_RESULT_URL
 
     @property
+    def expected_result_with_args_url(self):
+        return EXPECTED_RESULT_WITH_ARGS_URL
+
+    @property
+    def should_test_with_args(self):
+        return True
+
+    @property
     def predictor_cls(self):
         return InstanceSegPredictor
 
+    def _predict_with_predictor_args(
+        self, model_path, input_data_path, device, expected_result_with_args
+    ):
+        predictor = self.predictor_cls(model_path, device=device, threshold=0.85)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result_with_args, 1)
+
+    def _predict_with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        expected_result_with_args,
+    ):
+        predictor = self.predictor_cls(model_path, device=device)
+        output = predictor(str(input_data_path), threshold=0.85)
+        self._check_output(output, expected_result_with_args, 1)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result, 1)
+
     def _check_result(self, result, expected_result):
         assert isinstance(result, InstanceSegResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         # TODO: Check masks
         compare_det_results(

+ 32 - 0
libs/paddlex-hpi/tests/models/test_multilabel_classification.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import numpy as np
+import pytest
 from paddlex.inference.results import MLClassResult
 from tests.models.base import BaseTestPredictor
 
@@ -21,6 +22,7 @@ from paddlex_hpi.models import MLClasPredictor
 MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/ml_clas_model.zip"
 INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/ml_clas_input.jpg"
 EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/ml_clas_result.json"
+EXPECTED_RESULT_WITH_ARGS_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/ml_clas_result_with_args.json"
 
 
 class TestMLClasPredictor(BaseTestPredictor):
@@ -37,11 +39,41 @@ class TestMLClasPredictor(BaseTestPredictor):
         return EXPECTED_RESULT_URL
 
     @property
+    def expected_result_with_args_url(self):
+        return EXPECTED_RESULT_WITH_ARGS_URL
+
+    @property
+    def should_test_with_args(self):
+        return True
+
+    @property
     def predictor_cls(self):
         return MLClasPredictor
 
+    def _predict_with_predictor_args(
+        self, model_path, input_data_path, device, expected_result_with_args
+    ):
+        predictor = self.predictor_cls(model_path, device=device, threshold=0.85)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result_with_args, 1)
+
+    def _predict_with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        expected_result_with_args,
+    ):
+        predictor = self.predictor_cls(model_path, device=device)
+        with pytest.raises(TypeError):
+            output = predictor(str(input_data_path), threshold=0.85)
+            output = list(output)
+
     def _check_result(self, result, expected_result):
         assert isinstance(result, MLClassResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         assert result["class_ids"] == expected_result["class_ids"]
         assert np.allclose(

+ 32 - 0
libs/paddlex-hpi/tests/models/test_object_detection.py

@@ -21,6 +21,7 @@ from paddlex_hpi.models import DetPredictor
 MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_model.zip"
 INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_input.jpg"
 EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_result.json"
+EXPECTED_RESULT_WITH_ARGS_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_result_with_args.json"
 
 
 class TestDetPredictor(BaseTestPredictor):
@@ -37,11 +38,42 @@ class TestDetPredictor(BaseTestPredictor):
         return EXPECTED_RESULT_URL
 
     @property
+    def expected_result_with_args_url(self):
+        return EXPECTED_RESULT_WITH_ARGS_URL
+
+    @property
+    def should_test_with_args(self):
+        return True
+
+    @property
     def predictor_cls(self):
         return DetPredictor
 
+    def _predict_with_predictor_args(
+        self, model_path, input_data_path, device, expected_result_with_args
+    ):
+        predictor = self.predictor_cls(model_path, device=device, threshold=0.7)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result_with_args, 1)
+
+    def _predict_with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        expected_result_with_args,
+    ):
+        predictor = self.predictor_cls(model_path, device=device)
+        output = predictor(str(input_data_path), threshold=0.7)
+        self._check_output(output, expected_result_with_args, 1)
+        output = predictor(str(input_data_path))
+        self._check_output(output, expected_result, 1)
+
     def _check_result(self, result, expected_result):
         assert isinstance(result, DetResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         compare_det_results(
             [obj["coordinate"] for obj in result["boxes"]],

+ 30 - 0
libs/paddlex-hpi/tests/models/test_semantic_segmentation.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import numpy as np
+import pytest
 from paddlex.inference.results import SegResult
 from tests.models.base import BaseTestPredictor
 
@@ -37,11 +38,40 @@ class TestSegPredictor(BaseTestPredictor):
         return EXPECTED_RESULT_URL
 
     @property
+    def expected_result_with_args_url(self):
+        return EXPECTED_RESULT_URL
+
+    @property
     def predictor_cls(self):
         return SegPredictor
 
+    @property
+    def should_test_with_args(self):
+        return True
+
+    def _predict_with_predictor_args(
+        self, model_path, input_data_path, device, expected_result_with_args
+    ):
+        with pytest.raises(TypeError):
+            predictor = self.predictor_cls(model_path, device=device, target_size=400)
+
+    def _predict_with_predict_args(
+        self,
+        model_path,
+        input_data_path,
+        device,
+        expected_result,
+        expected_result_with_args,
+    ):
+        predictor = self.predictor_cls(model_path, device=device)
+        with pytest.raises(TypeError):
+            output = predictor(str(input_data_path), target_size=400)
+            output = list(output)
+
     def _check_result(self, result, expected_result):
         assert isinstance(result, SegResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         pred = result["pred"]
         expected_pred = np.array(expected_result["pred"], dtype=np.int32)

+ 2 - 0
libs/paddlex-hpi/tests/models/test_table_recognition.py

@@ -50,6 +50,8 @@ class TestTablePredictor(BaseTestPredictor):
             ]
 
         assert isinstance(result, TableRecResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         compare_det_results(
             [_unflatten_poly(poly) for poly in result["bbox"]],

+ 2 - 0
libs/paddlex-hpi/tests/models/test_text_recognition.py

@@ -42,6 +42,8 @@ class TestTextRecPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, TextRecResult)
+        assert "input_img" in result
+        result.pop("input_img")
         assert set(result) == set(expected_result)
         assert result["rec_text"] == expected_result["rec_text"]
         assert np.allclose(

+ 2 - 0
libs/paddlex-hpi/tests/models/test_ts_ad.py

@@ -43,6 +43,8 @@ class TestTSAdPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, TSAdResult)
+        assert "input_ts" in result
+        result.pop("input_ts")
         assert set(result) == set(expected_result)
         expected_result = json.loads(expected_result["anomaly"])
         result = result["anomaly"].to_dict(orient="records")

+ 2 - 0
libs/paddlex-hpi/tests/models/test_ts_cls.py

@@ -43,6 +43,8 @@ class TestTSClsPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, TSClsResult)
+        assert "input_ts" in result
+        result.pop("input_ts")
         assert set(result) == set(expected_result)
         expected_result = json.loads(expected_result["classification"])
         result = result["classification"].to_dict(orient="records")

+ 2 - 0
libs/paddlex-hpi/tests/models/test_ts_fc.py

@@ -43,6 +43,8 @@ class TestTSFcPredictor(BaseTestPredictor):
 
     def _check_result(self, result, expected_result):
         assert isinstance(result, TSFcResult)
+        assert "input_ts" in result
+        result.pop("input_ts")
         assert set(result) == set(expected_result)
         expected_result = json.loads(expected_result["forecast"])
         expected_result = [{"OT": round(i["OT"], 3)} for i in expected_result]

+ 1 - 1
paddlex/inference/models_new/base/predictor/base_predictor.py

@@ -124,7 +124,7 @@ class BasePredictor(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def set_predictor(self) -> None:
+    def set_predictor(self, batch_size: int = None, device: str = None, *args) -> None:
         """Sets up the predictor."""
         raise NotImplementedError
 

+ 3 - 3
paddlex/inference/models_new/common/static_infer.py

@@ -126,9 +126,9 @@ class StaticInfer:
     def _create(
         self,
     ) -> Tuple[
-        paddle.base.libpaddle.PaddleInferPredictor,
-        paddle.base.libpaddle.PaddleInferTensor,
-        paddle.base.libpaddle.PaddleInferTensor,
+        "paddle.base.libpaddle.PaddleInferPredictor",
+        "paddle.base.libpaddle.PaddleInferTensor",
+        "paddle.base.libpaddle.PaddleInferTensor",
     ]:
         """_create"""
         from lazy_paddle.inference import Config, create_predictor

部分文件因文件數量過多而無法顯示