Selaa lähdekoodia

[Refactor] New abstraction for predictors (#1990)

* Refactor base predictor

* Fix bug

* Add config_path attr

* Change create_predictor signature

* Fix class
Lin Manhui 1 vuosi sitten
vanhempi
commit
20eff3ae8a

+ 0 - 1
.check_license.py

@@ -28,7 +28,6 @@ LICENSE_TEXT = """# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 
 

+ 10 - 20
paddlex/inference/components/paddle_predictor/option.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ...utils.device import parse_device
 from ....utils.func_register import FuncRegister
 from ....utils import logging
 
@@ -81,29 +82,18 @@ class PaddlePredictorOption(object):
         self._cfg["batch_size"] = batch_size
 
     @register("device")
-    def set_device(self, device_setting: str):
+    def set_device(self, device: str):
         """set device"""
-        self._cfg["device"], self._cfg["device_id"] = self.parse_device_setting(
-            device_setting
-        )
-
-    @classmethod
-    def parse_device_setting(cls, device_setting):
-        if len(device_setting.split(":")) == 1:
-            device = device_setting.split(":")[0]
-            device_id = 0
-        else:
-            assert len(device_setting.split(":")) == 2
-            device = device_setting.split(":")[0]
-            device_id = device_setting.split(":")[1].split(",")[0]
-            logging.warning(f"The device id has been set to {device_id}.")
-
-        if device.lower() not in cls.SUPPORT_DEVICE:
-            support_run_mode_str = ", ".join(cls.SUPPORT_DEVICE)
+        device_type, device_ids = parse_device(device)
+        self._cfg["device"] = device_type
+        if device_type not in self.SUPPORT_DEVICE:
+            support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
             raise ValueError(
-                f"`device` must be {support_run_mode_str}, but received {repr(device)}."
+                f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
             )
-        return device.lower(), int(device_id)
+        device_id = device_ids[0] if device_ids is not None else 0
+        self._cfg["device_id"] = device_id
+        logging.warning(f"The device ID has been set to {device_id}.")
 
     @register("min_subgraph_size")
     def set_min_subgraph_size(self, min_subgraph_size: int):

+ 3 - 6
paddlex/inference/predictors/__init__.py

@@ -15,7 +15,7 @@
 
 from pathlib import Path
 
-from .base import BasePredictor
+from .base import BasePredictor, BasicPredictor
 from .image_classification import ClasPredictor
 from .text_detection import TextDetPredictor
 from .text_recognition import TextRecPredictor
@@ -23,17 +23,14 @@ from .table_recognition import TablePredictor
 from .official_models import official_models
 
 
def create_predictor(model: str, device: str = None, *args, **kwargs) -> BasePredictor:
    """Instantiate the registered predictor class for *model*.

    Resolves the model directory, loads its inference config to discover the
    model name, then dispatches to the matching ``BasicPredictor`` subclass.
    """
    model_dir = check_model(model)
    config = BasePredictor.load_config(model_dir)
    predictor_cls = BasicPredictor.get(config["Global"]["model_name"])
    return predictor_cls(
        model_dir=model_dir,
        config=config,
        device=device,
        *args,
        **kwargs,
    )

+ 45 - 17
paddlex/inference/predictors/base.py

@@ -17,16 +17,25 @@ import codecs
 from pathlib import Path
 from abc import abstractmethod
 
+import GPUtil
+
 from ...utils.subclass_register import AutoRegisterABCMetaClass
+from ..utils.device import constr_device
 from ...utils import logging
 from ..components.base import BaseComponent, ComponentsEngine
 from ..components.paddle_predictor.option import PaddlePredictorOption
 from ..utils.process_hook import generatorable_method
 
 
-class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
-    __is_base = True
def _get_default_device():
    """Pick a default device: the first available GPU, else the CPU."""
    gpus = GPUtil.getAvailable()
    if gpus:
        return constr_device("gpu", [gpus[0]])
    return "cpu"
+
 
+class BasePredictor(BaseComponent):
     INPUT_KEYS = "x"
     DEAULT_INPUTS = {"x": "x"}
     OUTPUT_KEYS = "result"
@@ -36,33 +45,55 @@ class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
 
     MODEL_FILE_PREFIX = "inference"
 
-    def __init__(self, model_dir, config=None, device=None, pp_option=None, **kwargs):
+    def __init__(self, model_dir, config=None, device=None, **kwargs):
         super().__init__()
         self.model_dir = Path(model_dir)
         self.config = config if config else self.load_config(self.model_dir)
+        self.device = device if device else _get_default_device()
         self.kwargs = self._check_args(kwargs)
+        # alias predict() to the __call__()
+        self.predict = self.__call__
 
-        self.pp_option = PaddlePredictorOption() if pp_option is None else pp_option
-        if device is not None:
-            self.pp_option.set_device(device)
+    @property
+    def config_path(self):
+        return self.get_config_path(self.model_dir)
 
-        self.components = self._build_components()
-        self.engine = ComponentsEngine(self.components)
+    @property
+    def model_name(self) -> str:
+        return self.config["Global"]["model_name"]
 
-        # alias predict() to the __call__()
-        self.predict = self.__call__
+    @abstractmethod
+    def apply(self, x):
+        raise NotImplementedError
 
-        logging.debug(
-            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}\nEnv: {self.pp_option}"
-        )
+    @classmethod
+    def get_config_path(cls, model_dir):
+        return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
 
     @classmethod
     def load_config(cls, model_dir):
-        config_path = model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
+        config_path = cls.get_config_path(model_dir)
         with codecs.open(config_path, "r", "utf-8") as file:
             dic = yaml.load(file, Loader=yaml.FullLoader)
         return dic
 
+    def _check_args(self, kwargs):
+        return kwargs
+
+
+class BasicPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
+    __is_base = True
+
+    def __init__(self, model_dir, config=None, device=None, pp_option=None, **kwargs):
+        super().__init__(model_dir=model_dir, config=config, device=device, **kwargs)
+        self.pp_option = PaddlePredictorOption() if pp_option is None else pp_option
+        self.pp_option.set_device(self.device)
+        self.components = self._build_components()
+        self.engine = ComponentsEngine(self.components)
+        logging.debug(
+            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}\nEnv: {self.pp_option}"
+        )
+
     def apply(self, x):
         """predict"""
         yield from self._generate_res(self.engine(x))
@@ -71,9 +102,6 @@ class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
     def _generate_res(self, data):
         return self._pack_res(data)
 
-    def _check_args(self, kwargs):
-        return kwargs
-
     @abstractmethod
     def _build_components(self):
         raise NotImplementedError

+ 2 - 2
paddlex/inference/predictors/image_classification.py

@@ -19,10 +19,10 @@ from ...modules.image_classification.model_list import MODELS
 from ..components import *
 from ..results import TopkResult
 from ..utils.process_hook import batchable_method
-from .base import BasePredictor
+from .base import BasicPredictor
 
 
-class ClasPredictor(BasePredictor):
+class ClasPredictor(BasicPredictor):
 
     entities = MODELS
 

+ 2 - 2
paddlex/inference/predictors/text_detection.py

@@ -19,10 +19,10 @@ from ...modules.text_detection.model_list import MODELS
 from ..components import *
 from ..results import TextDetResult
 from ..utils.process_hook import batchable_method
-from .base import BasePredictor
+from .base import BasicPredictor
 
 
-class TextDetPredictor(BasePredictor):
+class TextDetPredictor(BasicPredictor):
 
     entities = MODELS
 

+ 2 - 2
paddlex/inference/predictors/text_recognition.py

@@ -19,10 +19,10 @@ from ...modules.text_recognition.model_list import MODELS
 from ..components import *
 from ..results import TextRecResult
 from ..utils.process_hook import batchable_method
-from .base import BasePredictor
+from .base import BasicPredictor
 
 
-class TextRecPredictor(BasePredictor):
+class TextRecPredictor(BasicPredictor):
 
     entities = MODELS
 

+ 13 - 0
paddlex/inference/utils/__init__.py

@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 35 - 0
paddlex/inference/utils/device.py

@@ -0,0 +1,35 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
def constr_device(device_type, device_ids):
    """Build a device string such as ``"gpu:0,1"`` from a type and IDs."""
    joined_ids = ",".join(str(device_id) for device_id in device_ids)
    return f"{device_type}:{joined_ids}"
+
+
def parse_device(device):
    """Parse a device string into ``(device_type, device_ids)``.

    Args:
        device: A string such as ``"cpu"``, ``"gpu"``, or ``"gpu:0,1"``. The
            part after the colon, when present, is a comma-separated list of
            non-negative integer device IDs.

    Returns:
        A tuple ``(device_type, device_ids)`` where ``device_type`` is the
        lower-cased type and ``device_ids`` is a list of ints, or ``None``
        when no IDs were given.

    Raises:
        ValueError: If the string is empty, has more than one colon, has a
            trailing colon with no IDs, or contains a non-integer ID.
    """
    parts = device.split(":")
    if len(parts) > 2:
        raise ValueError(f"Invalid device: {device}")
    device_type = parts[0].lower()
    # Reject "" and ":0" up front instead of returning an empty type.
    if not device_type:
        raise ValueError(f"Invalid device: {device}")
    if len(parts) == 1:
        return device_type, None
    # A bare trailing colon ("gpu:") previously produced a confusing
    # 'Invalid device ID: ' error with an empty ID; report it clearly.
    if not parts[1]:
        raise ValueError(f"Invalid device: {device}")
    id_strs = parts[1].split(",")
    for id_str in id_strs:
        # str.isdigit rejects signs and empty strings, so "gpu:-1" and
        # "gpu:0,," are both caught here.
        if not id_str.isdigit():
            raise ValueError(f"Invalid device ID: {id_str}")
    return device_type, [int(id_str) for id_str in id_strs]

+ 1 - 0
requirements.txt

@@ -16,3 +16,4 @@ pandas
 parsley
 requests
 tokenizers==0.19.1
+GPUtil>=1.4.0