Browse Source

support set device in create_predictor & bugfix

gaotingquan committed 1 year ago
parent
commit
96e3c004da

+ 2 - 1
paddlex/__init__.py

@@ -28,7 +28,8 @@ from .modules import (
 )
 
 
-from .inference import create_model, create_pipeline
+from .model import create_model
+from .inference import create_predictor, create_pipeline
 
 
 def _initialize():

+ 20 - 30
paddlex/engine.py

@@ -14,46 +14,36 @@
 
 import os
 
-from .modules import (
-    build_dataset_checker,
-    build_trainer,
-    build_evaluater,
-    build_exportor,
-    build_predictor,
-)
+
 from .utils.result_saver import try_except_decorator
-from .utils import config
+from .utils.config import parse_args, get_config
 from .utils.errors import raise_unsupported_api_error
+from .model import _ModelBasedConfig
 
 
 class Engine(object):
     """Engine"""
 
     def __init__(self):
-        args = config.parse_args()
-        self.config = config.get_config(
-            args.config, overrides=args.override, show=False
-        )
-        self.mode = self.config.Global.mode
-        self.output = self.config.Global.output
+        args = parse_args()
+        config = get_config(args.config, overrides=args.override, show=False)
+        self._mode = config.Global.mode
+        self._output = config.Global.output
+        self._model = _ModelBasedConfig(config)
 
     @try_except_decorator
     def run(self):
         """the main function"""
-        if self.config.Global.mode == "check_dataset":
-            dataset_checker = build_dataset_checker(self.config)
-            return dataset_checker.check()
-        elif self.config.Global.mode == "train":
-            trainer = build_trainer(self.config)
-            trainer.train()
-        elif self.config.Global.mode == "evaluate":
-            evaluator = build_evaluater(self.config)
-            return evaluator.evaluate()
-        elif self.config.Global.mode == "export":
-            exportor = build_exportor(self.config)
-            return exportor.export()
-        elif self.config.Global.mode == "predict":
-            predictor = build_predictor(self.config)
-            return predictor.predict()
+        if self._mode == "check_dataset":
+            return self._model.check_dataset()
+        elif self._mode == "train":
+            self._model.train()
+        elif self._mode == "evaluate":
+            return self._model.evaluate()
+        elif self._mode == "export":
+            return self._model.export()
+        elif self._mode == "predict":
+            for res in self._model.predict():
+                res.print(json_format=False)
         else:
-            raise_unsupported_api_error(f"{self.config.Global.mode}", self.__class__)
+            raise_unsupported_api_error(f"{self._mode}", self.__class__)

+ 1 - 1
paddlex/inference/__init__.py

@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .models import create_model
+from .models import create_predictor
 from .pipelines import create_pipeline
 from .utils.pp_option import PaddlePredictorOption

+ 11 - 2
paddlex/inference/models/__init__.py

@@ -57,8 +57,14 @@ def _create_hp_predictor(
     )
 
 
-def create_model(
-    model: str, *args, use_hpip=False, hpi_params=None, **kwargs
+def create_predictor(
+    model: str,
+    *args,
+    device=None,
+    pp_option=None,
+    use_hpip=False,
+    hpi_params=None,
+    **kwargs,
 ) -> BasePredictor:
     model_dir = check_model(model)
     config = BasePredictor.load_config(model_dir)
@@ -69,6 +75,7 @@ def create_model(
             model_dir=model_dir,
             config=config,
             hpi_params=hpi_params,
+            device=device,
             *args,
             **kwargs,
         )
@@ -76,6 +83,8 @@ def create_model(
         return BasicPredictor.get(model_name)(
             model_dir=model_dir,
             config=config,
+            device=device,
+            pp_option=pp_option,
             *args,
             **kwargs,
         )

+ 14 - 10
paddlex/inference/models/base/base_predictor.py

@@ -44,14 +44,11 @@ class BasePredictor(BaseComponent):
         self.model_dir = Path(model_dir)
         self.config = config if config else self.load_config(self.model_dir)
 
-        self._pred_set_func_map = {}
-        self._pred_set_register = FuncRegister(self._pred_set_func_map)
-
         # alias predict() to the __call__()
         self.predict = self.__call__
 
     def __call__(self, input, **kwargs):
-        self._set_predict(**kwargs)
+        self.set_predict(**kwargs)
         for res in super().__call__(input):
             yield res["result"]
 
@@ -67,6 +64,10 @@ class BasePredictor(BaseComponent):
     def apply(self, x):
         raise NotImplementedError
 
+    @abstractmethod
+    def set_predict(self):
+        raise NotImplementedError
+
     @classmethod
     def get_config_path(cls, model_dir):
         return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
@@ -78,10 +79,6 @@ class BasePredictor(BaseComponent):
             dic = yaml.load(file, Loader=yaml.FullLoader)
         return dic
 
-    def _set_predict(self, **kwargs):
-        for k in kwargs:
-            self._pred_set_func_map[k](kwargs[k])
-
 
 class BasicPredictor(
     BasePredictor, DeviceSetMixin, PPOptionSetMixin, metaclass=AutoRegisterABCMetaClass
@@ -89,12 +86,15 @@ class BasicPredictor(
 
     __is_base = True
 
-    def __init__(self, model_dir, config=None):
+    def __init__(self, model_dir, config=None, device=None, pp_option=None):
         super().__init__(model_dir=model_dir, config=config)
+        self._pred_set_func_map = {}
+        self._pred_set_register = FuncRegister(self._pred_set_func_map)
         self._pred_set_register("device")(self.set_device)
         self._pred_set_register("pp_option")(self.set_pp_option)
 
-        self.pp_option = PaddlePredictorOption()
+        self.pp_option = pp_option if pp_option else PaddlePredictorOption()
+        self.pp_option.set_device(device)
         self.components = {}
         self._build_components()
         self.engine = ComponentsEngine(self.components)
@@ -128,6 +128,10 @@ class BasicPredictor(
             ), f"The key ({key}) has been used: {self.components}!"
             self.components[key] = cmp
 
+    def set_predict(self, **kwargs):
+        for k in kwargs:
+            self._pred_set_func_map[k](kwargs[k])
+
     @abstractmethod
     def _build_components(self):
         raise NotImplementedError

+ 1 - 1
paddlex/inference/models/general_recognition.py

@@ -98,4 +98,4 @@ class ShiTuRecPredictor(CVPredictor):
     @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "rec_feature"]
-        return {"result": BaseResult({key: data[key] for key in keys})}
+        return BaseResult({key: data[key] for key in keys})

+ 5 - 1
paddlex/inference/pipelines/__init__.py

@@ -46,10 +46,14 @@ def create_pipeline(
             raise Exception(f"The pipeline don't exist! ({pipeline})")
     config = parse_config(pipeline)
     pipeline_name = config["Global"]["pipeline_name"]
+    pipeline_setting = config["Pipeline"]
+    pipeline_setting.update(kwargs)
+
     predictor_kwargs = {"use_hpip": use_hpip}
     if hpi_params is not None:
         predictor_kwargs["hpi_params"] = hpi_params
+
     pipeline = BasePipeline.get(pipeline_name)(
-        predictor_kwargs=predictor_kwargs, *args, **config["Pipeline"], **kwargs
+        predictor_kwargs=predictor_kwargs, *args, **pipeline_setting
     )
     return pipeline

+ 2 - 2
paddlex/inference/pipelines/base.py

@@ -16,7 +16,7 @@ from abc import ABC
 from typing import Any, Dict, Optional
 
 from ...utils.subclass_register import AutoRegisterABCMetaClass
-from ..models import create_model
+from ..models import create_predictor
 
 
 class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
@@ -35,4 +35,4 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         yield from self.predict(*args, **kwargs)
 
     def _create_model(self, *args, **kwargs):
-        return create_model(*args, **kwargs, **self._predictor_kwargs)
+        return create_predictor(*args, **kwargs, **self._predictor_kwargs)

+ 5 - 5
paddlex/inference/pipelines/ocr.py

@@ -29,9 +29,9 @@ class OCRPipeline(BasePipeline):
         predictor_kwargs=None,
     ):
         super().__init__(predictor_kwargs)
-        self._det_predict = self._create_model(det_model)
-        self._rec_predict = self._create_model(rec_model)
-        self.is_curve = self._det_predict.model_name in [
+        self.det_model = self._create_model(det_model)
+        self.rec_model = self._create_model(rec_model)
+        self.is_curve = self.det_model.model_name in [
             "PP-OCRv4_mobile_seal_det",
             "PP-OCRv4_server_seal_det",
         ]
@@ -42,7 +42,7 @@ class OCRPipeline(BasePipeline):
 
     def predict(self, input, **kwargs):
         device = kwargs.get("device", "gpu")
-        for det_res in self._det_predict(
+        for det_res in self.det_model(
             input, batch_size=kwargs.get("det_batch_size", 1), device=device
         ):
             single_img_res = (
@@ -52,7 +52,7 @@ class OCRPipeline(BasePipeline):
             single_img_res["rec_score"] = []
             if len(single_img_res["dt_polys"]) > 0:
                 all_subs_of_img = list(self._crop_by_polys(single_img_res))
-                for rec_res in self._rec_predict(
+                for rec_res in self.rec_model(
                     all_subs_of_img,
                     batch_size=kwargs.get("rec_batch_size", 1),
                     device=device,

+ 2 - 2
paddlex/inference/pipelines/single_model_pipeline.py

@@ -31,7 +31,7 @@ class SingleModelPipeline(BasePipeline):
 
     def __init__(self, model, predictor_kwargs=None):
         super().__init__(predictor_kwargs)
-        self._predict = self._create_model(model)
+        self.model = self._create_model(model)
 
     def predict(self, input, **kwargs):
-        yield from self._predict(input, **kwargs)
+        yield from self.model(input, **kwargs)

+ 7 - 0
paddlex/inference/utils/pp_option.py

@@ -77,6 +77,8 @@ class PaddlePredictorOption(object):
     @register("device")
     def set_device(self, device: str):
         """set device"""
+        if not device:
+            return
         device_type, device_ids = parse_device(device)
         self._cfg["device"] = device_type
         if device_type not in self.SUPPORT_DEVICE:
@@ -147,3 +149,8 @@ class PaddlePredictorOption(object):
         if key not in self._cfg:
             raise Exception(f"The key ({key}) is not found in cfg: \n {self._cfg}")
         return self._cfg.get(key)
+
+    def __eq__(self, obj):
+        if isinstance(obj, PaddlePredictorOption):
+            return obj._cfg == self._cfg
+        return False

+ 112 - 0
paddlex/model.py

@@ -0,0 +1,112 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from copy import deepcopy
+
+from .inference import create_predictor, PaddlePredictorOption
+from .modules import (
+    build_dataset_checker,
+    build_trainer,
+    build_evaluater,
+    build_exportor,
+)
+
+
+# TODO(gaotingquan): support _ModelBasedConfig
+def create_model(model=None, **kwargs):
+    return _ModelBasedInference(model, **kwargs)
+
+
+class _BaseModel:
+    @abstractmethod
+    def check_dataset(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def train(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def export(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def predict(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_predict(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        yield from self.predict(*args, **kwargs)
+
+
+class _ModelBasedInference(_BaseModel):
+    def __init__(self, model, device=None, **kwargs):
+        self._predictor = create_predictor(model, device=device, **kwargs)
+
+    def predict(self, *args, **kwargs):
+        yield from self._predictor(*args, **kwargs)
+
+    def set_predict(self, **kwargs):
+        self._predictor.set_predict(**kwargs)
+
+
+class _ModelBasedConfig(_BaseModel):
+    def __init__(self, config=None, *args, **kwargs):
+        super().__init__()
+        self._config = config
+        self._model_name = config.Global.model
+
+    def _build_predictor(self):
+        predict_kwargs = deepcopy(self._config.Predict)
+
+        model_dir = predict_kwargs.pop("model_dir", None)
+        # if model_dir is None, using official
+        model = self._model_name if model_dir is None else model_dir
+
+        device = self._config.Global.get("device")
+        kernel_option = predict_kwargs.pop("kernel_option", {})
+        kernel_option.update({"device": device})
+
+        pp_option = PaddlePredictorOption(**kernel_option)
+        predictor = create_predictor(model, pp_option=pp_option)
+        assert "input" in predict_kwargs
+        return predict_kwargs, predictor
+
+    def check_dataset(self):
+        dataset_checker = build_dataset_checker(self._config)
+        return dataset_checker.check()
+
+    def train(self):
+        trainer = build_trainer(self._config)
+        trainer.train()
+
+    def evaluate(self):
+        evaluator = build_evaluater(self._config)
+        return evaluator.evaluate()
+
+    def export(self):
+        exportor = build_exportor(self._config)
+        return exportor.export()
+
+    def predict(self):
+        _predict_kwargs, _predictor = self._build_predictor()
+        yield from _predictor(**_predict_kwargs)

+ 0 - 2
paddlex/modules/__init__.py

@@ -20,8 +20,6 @@ from .base import (
     build_exportor,
 )
 
-from .predictor import build_predictor
-
 from .image_classification import (
     ClsDatasetChecker,
     ClsTrainer,

+ 0 - 44
paddlex/modules/predictor.py

@@ -1,44 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from copy import deepcopy
-from ..inference.models import create_model
-from ..inference.utils.pp_option import PaddlePredictorOption
-from ..utils.config import AttrDict
-
-
-class Predictor(object):
-    def __init__(self, config):
-        model_name = config.Global.model
-        self.predict_config = deepcopy(config.Predict)
-
-        model_dir = self.predict_config.pop("model_dir", None)
-        # if model_dir is None, using official
-        model = model_name if model_dir is None else model_dir
-        self.input_path = self.predict_config.pop("input_path")
-        self.pp_option = PaddlePredictorOption(
-            **self.predict_config.pop("kernel_option", {})
-        )
-        self.model = create_model(model)
-
-    def predict(self):
-        for res in self.model(
-            input=self.input_path, pp_option=self.pp_option, **self.predict_config
-        ):
-            res.print(json_format=False)
-
-
-def build_predictor(config: AttrDict):
-    """build predictor by config for dev"""
-    return Predictor(config)

+ 4 - 7
paddlex/utils/result_saver.py

@@ -28,13 +28,13 @@ def try_except_decorator(func):
         try:
             result = func(self, *args, **kwargs)
             if result:
-                save_result(True, self.mode, self.output, result_dict=result)
+                save_result(True, self._mode, self._output, result_dict=result)
         except Exception as e:
             exc_type, exc_value, exc_tb = sys.exc_info()
             save_result(
                 False,
-                self.mode,
-                self.output,
+                self._mode,
+                self._output,
                 err_type=str(exc_type),
                 err_msg=str(exc_value),
             )
@@ -46,10 +46,7 @@ def try_except_decorator(func):
 
 def save_result(run_pass, mode, output, result_dict=None, err_type=None, err_msg=None):
     """format, build and save result"""
-    json_data = {
-        # "model_name": self.args.model_name,
-        "done_flag": run_pass
-    }
+    json_data = {"done_flag": run_pass}
     if not run_pass:
         assert result_dict is None and err_type is not None and err_msg is not None
         json_data.update({"err_type": err_type, "err_msg": err_msg})