Browse Source

integrate device settings functions

gaotingquan 1 năm trước cách đây
mục cha
commit
da8f8cce1d

+ 0 - 35
paddlex/inference/utils/device.py

@@ -1,35 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-def constr_device(device_type, device_ids):
-    device_ids = ",".join(map(str, device_ids))
-    return f"{device_type}:{device_ids}"
-
-
-def parse_device(device):
-    parts = device.split(":")
-    if len(parts) > 2:
-        raise ValueError(f"Invalid device: {device}")
-    if len(parts) == 1:
-        device_type, device_ids = parts[0], None
-    else:
-        device_type, device_ids = parts
-        device_ids = device_ids.split(",")
-        for device_id in device_ids:
-            if not device_id.isdigit():
-                raise ValueError(f"Invalid device ID: {device_id}")
-        device_ids = list(map(int, device_ids))
-    device_type = device_type.lower()
-    return device_type, device_ids

+ 3 - 2
paddlex/inference/utils/pp_option.py

@@ -13,8 +13,8 @@
 # limitations under the License.
 
 from ...utils.func_register import FuncRegister
+from ...utils.device import parse_device, set_env_for_device, get_default_device
 from ...utils import logging
-from .device import parse_device
 from .new_ir_blacklist import NEWIR_BLOCKLIST
 
 
@@ -55,7 +55,7 @@ class PaddlePredictorOption(object):
         """get default config"""
         return {
             "run_mode": "paddle",
-            "device": "gpu",
+            "device": get_default_device(),
             "device_id": 0,
             "min_subgraph_size": 3,
             "shape_info_filename": None,
@@ -90,6 +90,7 @@ class PaddlePredictorOption(object):
             )
         device_id = device_ids[0] if device_ids is not None else 0
         self._cfg["device_id"] = device_id
+        set_env_for_device(device)
         if device_type not in ("cpu"):
             if device_ids is None or len(device_ids) > 1:
                 logging.warning(f"The device ID has been set to {device_id}.")

+ 1 - 5
paddlex/modules/base/build_model.py

@@ -15,10 +15,9 @@
 
 import os
 from ...repo_apis.base import Config, PaddleModel
-from ...utils.device import get_device
 
 
-def build_model(model_name: str, device: str = None, config_path: str = None) -> tuple:
+def build_model(model_name: str, config_path: str = None) -> tuple:
     """build Config and PaddleModel
 
     Args:
@@ -31,8 +30,5 @@ def build_model(model_name: str, device: str = None, config_path: str = None) ->
         tuple(Config, PaddleModel): the Config and PaddleModel
     """
     config = Config(model_name, config_path)
-
-    if device:
-        config.update_device(get_device(device))
     model = PaddleModel(config=config)
     return config, model

+ 4 - 4
paddlex/modules/base/evaluator.py

@@ -17,7 +17,7 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 
 from .build_model import build_model
-from ...utils.device import get_device
+from ...utils.device import update_device_num
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils.logging import *
@@ -138,9 +138,9 @@ evaling!"
         Returns:
             str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
         """
-        return get_device(
-            self.global_config.device, using_device_number=using_device_number
-        )
+        if using_device_number:
+            return update_device_num(self.global_config.device, using_device_number)
+        return self.global_config.device
 
     @abstractmethod
     def update_config(self):

+ 4 - 4
paddlex/modules/base/exportor.py

@@ -17,7 +17,7 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 
 from .build_model import build_model
-from ...utils.device import get_device
+from ...utils.device import update_device_num
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils.logging import *
@@ -103,9 +103,9 @@ exporting!"
         Returns:
             str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
         """
-        # return get_device(
-        #     self.global_config.device, using_device_number=using_device_number)
-        return get_device("cpu")
+        if using_device_number:
+            return update_device_num(self.global_config.device, using_device_number)
+        return self.global_config.device
 
     def update_config(self):
         """update export config"""

+ 7 - 7
paddlex/modules/base/trainer/trainer.py → paddlex/modules/base/trainer.py

@@ -15,10 +15,10 @@
 import os
 from abc import ABC, abstractmethod
 from pathlib import Path
-from ..build_model import build_model
-from ....utils.device import get_device
-from ....utils.misc import AutoRegisterABCMetaClass
-from ....utils.config import AttrDict
+from .build_model import build_model
+from ...utils.device import update_device_num
+from ...utils.misc import AutoRegisterABCMetaClass
+from ...utils.config import AttrDict
 
 
 def build_trainer(config: AttrDict) -> "BaseTrainer":
@@ -88,9 +88,9 @@ training!"
         Returns:
             str: device setting, such as: `gpu:0,1`, `npu:0,1` `cpu`.
         """
-        return get_device(
-            self.global_config.device, using_device_number=using_device_number
-        )
+        if using_device_number:
+            return update_device_num(self.global_config.device, using_device_number)
+        return self.global_config.device
 
     @abstractmethod
     def update_config(self):

+ 0 - 16
paddlex/modules/base/trainer/__init__.py

@@ -1,16 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from .trainer import build_trainer, BaseTrainer

+ 0 - 430
paddlex/modules/base/trainer/train_deamon.py

@@ -1,430 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-import time
-import json
-import traceback
-import threading
-from abc import ABC, abstractmethod
-from pathlib import Path
-import lazy_paddle as paddle
-
-from ..build_model import build_model
-from ....utils.file_interface import write_json_file
-from ....utils import logging
-
-
-def try_except_decorator(func):
-    """try-except"""
-
-    def wrap(self, *args, **kwargs):
-        try:
-            func(self, *args, **kwargs)
-        except Exception as e:
-            exc_type, exc_value, exc_tb = sys.exc_info()
-            self.save_json()
-            traceback.print_exception(exc_type, exc_value, exc_tb)
-        finally:
-            self.processing = False
-
-    return wrap
-
-
-class BaseTrainDeamon(ABC):
-    """BaseTrainResultDemon"""
-
-    update_interval = 600
-    last_k = 5
-
-    def __init__(self, config):
-        """init"""
-        self.global_config = config.Global
-        self.disable_deamon = config.get("Benchmark", {}).get("disable_deamon", False)
-        self.init_pre_hook()
-        self.output = self.global_config.output
-        self.train_outputs = self.get_train_outputs()
-        self.save_paths = self.get_save_paths()
-        self.results = self.init_train_result()
-        self.save_json()
-        self.models = {}
-        self.init_post_hook()
-
-        self.config_recorder = {}
-        self.model_recorder = {}
-        self.processing = False
-        self.start()
-
-    def init_train_result(self):
-        """init train result structure"""
-        model_names = self.init_model_names()
-        configs = self.init_configs()
-        train_log = self.init_train_log()
-        vdl = self.init_vdl_log()
-
-        results = []
-        for i, model_name in enumerate(model_names):
-            results.append(
-                {
-                    "model_name": model_name,
-                    "done_flag": False,
-                    "config": configs[i],
-                    "label_dict": "",
-                    "train_log": train_log,
-                    "visualdl_log": vdl,
-                    "models": self.init_model_pkg(),
-                }
-            )
-        return results
-
-    def get_save_names(self):
-        """get names to save"""
-        return ["train_result.json"]
-
-    def get_train_outputs(self):
-        """get training outputs dir"""
-        return [Path(self.output)]
-
-    def init_model_names(self):
-        """get models name"""
-        return [self.global_config.model]
-
-    def get_save_paths(self):
-        """get the path to save train_result.json"""
-        return [Path(self.output, save_name) for save_name in self.get_save_names()]
-
-    def init_configs(self):
-        """get the init value of config field in result"""
-        return [""] * len(self.init_model_names())
-
-    def init_train_log(self):
-        """get train log"""
-        return ""
-
-    def init_vdl_log(self):
-        """get visualdl log"""
-        return ""
-
-    def init_model_pkg(self):
-        """get model package"""
-        init_content = self.init_model_content()
-        model_pkg = {}
-
-        for pkg in self.get_watched_model():
-            model_pkg[pkg] = init_content
-        return model_pkg
-
-    def normlize_path(self, dict_obj, relative_to):
-        """normlize path to string type path relative to the output"""
-        for key in dict_obj:
-            if isinstance(dict_obj[key], dict):
-                self.normlize_path(dict_obj[key], relative_to)
-            if isinstance(dict_obj[key], Path):
-                dict_obj[key] = (
-                    dict_obj[key]
-                    .resolve()
-                    .relative_to(relative_to.resolve())
-                    .as_posix()
-                )
-
-    def save_json(self):
-        """save result to json"""
-        for i, result in enumerate(self.results):
-            self.save_paths[i].parent.mkdir(parents=True, exist_ok=True)
-            self.normlize_path(result, relative_to=self.save_paths[i].parent)
-            write_json_file(result, self.save_paths[i], indent=2)
-
-    def start(self):
-        """start deamon thread"""
-        self.exit = False
-        self.thread = threading.Thread(target=self.run)
-        self.thread.daemon = True
-        if not self.disable_deamon:
-            self.thread.start()
-
-    def stop_hook(self):
-        """hook befor stop"""
-        for result in self.results:
-            result["done_flag"] = True
-            self.update()
-
-    def stop(self):
-        """stop self"""
-        self.exit = True
-        while True:
-            if not self.processing:
-                self.stop_hook()
-                break
-            time.sleep(60)
-
-    def run(self):
-        """main function"""
-        while not self.exit:
-            self.update()
-            if self.exit:
-                break
-            time.sleep(self.update_interval)
-
-    def update_train_log(self, train_output):
-        """update train log"""
-        train_log_path = train_output / "train.log"
-        if train_log_path.exists():
-            return train_log_path
-
-    def update_vdl_log(self, train_output):
-        """update visualdl log"""
-        vdl_path = list(train_output.glob("vdlrecords*log"))
-        if len(vdl_path) >= 1:
-            return vdl_path[0]
-
-    def update_label_dict(self, train_output):
-        """update label dict"""
-        dict_path = train_output.joinpath("label_dict.txt")
-        if not dict_path.exists():
-            return ""
-        return dict_path
-
-    @try_except_decorator
-    def update(self):
-        """update train result json"""
-        self.processing = True
-        for i in range(len(self.results)):
-            self.results[i] = self.update_result(self.results[i], self.train_outputs[i])
-        self.save_json()
-        self.processing = False
-
-    def get_model(self, model_name, config_path):
-        """initialize the model"""
-        if model_name not in self.models:
-            config, model = build_model(
-                model_name,
-                # using CPU to export model
-                device="cpu",
-                config_path=config_path,
-            )
-            self.models[model_name] = model
-        return self.models[model_name]
-
-    def get_watched_model(self):
-        """get the models needed to be watched"""
-        watched_models = [f"last_{i}" for i in range(1, self.last_k + 1)]
-        watched_models.append("best")
-        return watched_models
-
-    def init_model_content(self):
-        """get model content structure"""
-        return {
-            "score": "",
-            "pdparams": "",
-            "pdema": "",
-            "pdopt": "",
-            "pdstates": "",
-            "inference_config": "",
-            "pdmodel": "",
-            "pdiparams": "",
-            "pdiparams.info": "",
-        }
-
-    def update_result(self, result, train_output):
-        """update every result"""
-        train_output = Path(train_output).resolve()
-        config_path = train_output.joinpath("config.yaml").resolve()
-        if not config_path.exists():
-            return result
-
-        model_name = result["model_name"]
-        if (
-            model_name in self.config_recorder
-            and self.config_recorder[model_name] != config_path
-        ):
-            result["models"] = self.init_model_pkg()
-        result["config"] = config_path
-        self.config_recorder[model_name] = config_path
-
-        result["train_log"] = self.update_train_log(train_output)
-        result["visualdl_log"] = self.update_vdl_log(train_output)
-        result["label_dict"] = self.update_label_dict(train_output)
-
-        model = self.get_model(result["model_name"], config_path)
-
-        params_path_list = list(
-            train_output.glob(
-                ".".join(
-                    [self.get_ith_ckp_prefix("[0-9]*"), self.get_the_pdparams_suffix()]
-                )
-            )
-        )
-        epoch_ids = []
-        for params_path in params_path_list:
-            epoch_id = self.get_epoch_id_by_pdparams_prefix(params_path.stem)
-            epoch_ids.append(epoch_id)
-        epoch_ids.sort()
-        # TODO(gaotingquan): how to avoid that the latest ckp files is being saved
-        # epoch_ids = epoch_ids[:-1]
-        for i in range(1, self.last_k + 1):
-            if len(epoch_ids) < i:
-                break
-            self.update_models(
-                result,
-                model,
-                train_output,
-                f"last_{i}",
-                self.get_ith_ckp_prefix(epoch_ids[-i]),
-            )
-        self.update_models(
-            result, model, train_output, "best", self.get_best_ckp_prefix()
-        )
-        return result
-
-    def update_models(self, result, model, train_output, model_key, ckp_prefix):
-        """update info of the models to be saved"""
-        pdparams = train_output.joinpath(
-            ".".join([ckp_prefix, self.get_the_pdparams_suffix()])
-        )
-        if pdparams.exists():
-            recorder_key = f"{train_output.name}_{model_key}"
-            if (
-                model_key != "best"
-                and recorder_key in self.model_recorder
-                and self.model_recorder[recorder_key] == pdparams
-            ):
-                return
-
-            self.model_recorder[recorder_key] = pdparams
-
-            pdema = ""
-            pdema_suffix = self.get_the_pdema_suffix()
-            if pdema_suffix:
-                pdema = pdparams.parent.joinpath(".".join([ckp_prefix, pdema_suffix]))
-                if not pdema.exists():
-                    pdema = ""
-
-            pdopt = ""
-            pdopt_suffix = self.get_the_pdopt_suffix()
-            if pdopt_suffix:
-                pdopt = pdparams.parent.joinpath(".".join([ckp_prefix, pdopt_suffix]))
-                if not pdopt.exists():
-                    pdopt = ""
-
-            pdstates = ""
-            pdstates_suffix = self.get_the_pdstates_suffix()
-            if pdstates_suffix:
-                pdstates = pdparams.parent.joinpath(
-                    ".".join([ckp_prefix, pdstates_suffix])
-                )
-                if not pdstates.exists():
-                    pdstates = ""
-
-            score = self.get_score(Path(pdstates).resolve().as_posix())
-
-            result["models"][model_key] = {
-                "score": score,
-                "pdparams": pdparams,
-                "pdema": pdema,
-                "pdopt": pdopt,
-                "pdstates": pdstates,
-            }
-
-            self.update_inference_model(
-                model,
-                pdparams,
-                train_output.joinpath(f"{ckp_prefix}"),
-                result["models"][model_key],
-            )
-
-    def update_inference_model(
-        self, model, weight_path, export_save_dir, result_the_model
-    ):
-        """update inference model"""
-        export_save_dir.mkdir(parents=True, exist_ok=True)
-        export_result = model.export(
-            weight_path=str(weight_path), save_dir=export_save_dir
-        )
-
-        if export_result.returncode == 0:
-            inference_config = export_save_dir.joinpath("inference.yml")
-            if not inference_config.exists():
-                inference_config = ""
-            use_pir = (
-                hasattr(paddle.framework, "use_pir_api")
-                and paddle.framework.use_pir_api()
-            )
-            pdmodel = (
-                export_save_dir.joinpath("inference.json")
-                if use_pir
-                else export_save_dir.joinpath("inference.pdmodel")
-            )
-            pdiparams = export_save_dir.joinpath("inference.pdiparams")
-            pdiparams_info = (
-                "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
-            )
-        else:
-            inference_config = ""
-            pdmodel = ""
-            pdiparams = ""
-            pdiparams_info = ""
-
-        result_the_model["inference_config"] = inference_config
-        result_the_model["pdmodel"] = pdmodel
-        result_the_model["pdiparams"] = pdiparams
-        result_the_model["pdiparams.info"] = pdiparams_info
-
-    def init_pre_hook(self):
-        """hook func that would be called befor init"""
-        pass
-
-    def init_post_hook(self):
-        """hook func that would be called after init"""
-        pass
-
-    @abstractmethod
-    def get_the_pdparams_suffix(self):
-        """get the suffix of pdparams file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_the_pdema_suffix(self):
-        """get the suffix of pdema file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_the_pdopt_suffix(self):
-        """get the suffix of pdopt file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_the_pdstates_suffix(self):
-        """get the suffix of pdstates file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_ith_ckp_prefix(self, epoch_id):
-        """get the prefix of the epoch_id checkpoint file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_best_ckp_prefix(self):
-        """get the prefix of the best checkpoint file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_score(self, pdstates_path):
-        """get the score by pdstates file"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_epoch_id_by_pdparams_prefix(self, pdparams_prefix):
-        """get the epoch_id by pdparams file"""
-        raise NotImplementedError

+ 6 - 5
paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py

@@ -17,6 +17,7 @@ import os
 from ...base import BaseModel
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from .config import InstanceSegConfig
 
@@ -68,7 +69,7 @@ class InstanceSegModel(BaseModel):
         if epochs_iters is not None:
             config.update_epochs(epochs_iters)
             config.update_cossch_epoch(epochs_iters)
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         config.update_device(device_type)
         if resume_path is not None:
             assert resume_path.endswith(
@@ -170,7 +171,7 @@ class InstanceSegModel(BaseModel):
         config.update_weights(weight_path)
         if batch_size is not None:
             config.update_batch_size(batch_size, "eval")
-        device_type, device_ids = self.runner.parse_device(device)
+        device_type, device_ids = parse_device(device)
         if len(device_ids) > 1:
             raise ValueError(
                 f"multi-{device_type} evaluation is not supported. Please use a single {device_type}."
@@ -229,7 +230,7 @@ class InstanceSegModel(BaseModel):
             cli_args.append(CLIArgument("--rtn_im_file", kwargs["rtn_im_file"]))
         weight_path = abspath(weight_path)
         config.update_weights(weight_path)
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         config.update_device(device_type)
         if save_dir is not None:
             save_dir = abspath(save_dir)
@@ -315,7 +316,7 @@ class InstanceSegModel(BaseModel):
         cli_args.append(CLIArgument("--image_file", input_path))
         if save_dir is not None:
             cli_args.append(CLIArgument("--output_dir", save_dir))
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         cli_args.append(CLIArgument("--device", device_type))
 
         self._assert_empty_kwargs(kwargs)
@@ -367,7 +368,7 @@ class InstanceSegModel(BaseModel):
         if epochs_iters is not None:
             cps_config.update_epochs(epochs_iters)
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             config.update_device(device_type)
         if save_dir is not None:
             save_dir = abspath(config.get_train_save_dir())

+ 20 - 11
paddlex/repo_apis/PaddleDetection_api/object_det/model.py

@@ -18,6 +18,7 @@ import json
 from ...base import BaseModel
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from ....utils import logging
 
@@ -72,7 +73,7 @@ class DetModel(BaseModel):
         if epochs_iters is not None:
             config.update_epochs(epochs_iters)
             config.update_cossch_epoch(epochs_iters)
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         config.update_device(device_type)
         if resume_path is not None:
             assert resume_path.endswith(
@@ -174,7 +175,7 @@ class DetModel(BaseModel):
         config.update_weights(weight_path)
         if batch_size is not None:
             config.update_batch_size(batch_size, "eval")
-        device_type, device_ids = self.runner.parse_device(device)
+        device_type, device_ids = parse_device(device)
         if len(device_ids) > 1:
             raise ValueError(
                 f"multi-{device_type} evaluation is not supported. Please use a single {device_type}."
@@ -233,7 +234,7 @@ class DetModel(BaseModel):
             cli_args.append(CLIArgument("--rtn_im_file", kwargs["rtn_im_file"]))
         weight_path = abspath(weight_path)
         config.update_weights(weight_path)
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         config.update_device(device_type)
         if save_dir is not None:
             save_dir = abspath(save_dir)
@@ -285,19 +286,27 @@ class DetModel(BaseModel):
         config.update({"hpi_config_path": hpi_config_path})
 
         if self.name in official_categories.keys():
-            anno_val_file = abspath(os.path.join(config.TestDataset['dataset_dir'], config.TestDataset['anno_path']))
+            anno_val_file = abspath(
+                os.path.join(
+                    config.TestDataset["dataset_dir"], config.TestDataset["anno_path"]
+                )
+            )
             if anno_val_file == None or (not os.path.isfile(anno_val_file)):
                 categories = official_categories[self.name]
-                temp_anno = {'images': [], 'annotations': [], 'categories': categories}
+                temp_anno = {"images": [], "annotations": [], "categories": categories}
                 with self._create_new_val_json_file() as anno_file:
-                    json.dump(temp_anno, open(anno_file, 'w'))
-                    config.update({"TestDataset": {"dataset_dir": '', "anno_path": anno_file}})
-                    logging.warning(f"{self.name} does not have validate annotations, use {anno_file} default instead.")
+                    json.dump(temp_anno, open(anno_file, "w"))
+                    config.update(
+                        {"TestDataset": {"dataset_dir": "", "anno_path": anno_file}}
+                    )
+                    logging.warning(
+                        f"{self.name} does not have validate annotations, use {anno_file} default instead."
+                    )
                     self._assert_empty_kwargs(kwargs)
                     with self._create_new_config_file() as config_path:
                         config.dump(config_path)
                         return self.runner.export(config_path, cli_args, None)
-                
+
         self._assert_empty_kwargs(kwargs)
 
         with self._create_new_config_file() as config_path:
@@ -333,7 +342,7 @@ class DetModel(BaseModel):
         cli_args.append(CLIArgument("--image_file", input_path))
         if save_dir is not None:
             cli_args.append(CLIArgument("--output_dir", save_dir))
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         cli_args.append(CLIArgument("--device", device_type))
 
         self._assert_empty_kwargs(kwargs)
@@ -385,7 +394,7 @@ class DetModel(BaseModel):
         if epochs_iters is not None:
             cps_config.update_epochs(epochs_iters)
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             config.update_device(device_type)
         if save_dir is not None:
             save_dir = abspath(config.get_train_save_dir())

+ 2 - 1
paddlex/repo_apis/PaddleOCR_api/table_rec/model.py

@@ -17,6 +17,7 @@ import os
 from ....utils import logging
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from ..text_rec.model import TextRecModel
 
@@ -95,7 +96,7 @@ class TableRecModel(TextRecModel):
         input_path = abspath(input_path)
         cli_args.append(CLIArgument("--image_dir", input_path))
 
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         cli_args.append(CLIArgument("--use_gpu", str(device_type == "gpu")))
 
         if save_dir is not None:

+ 2 - 1
paddlex/repo_apis/PaddleOCR_api/text_det/model.py

@@ -16,6 +16,7 @@ import os
 
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from ..text_rec.model import TextRecModel
 
@@ -51,7 +52,7 @@ class TextDetModel(TextRecModel):
         input_path = abspath(input_path)
         cli_args.append(CLIArgument("--image_dir", input_path))
 
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         cli_args.append(CLIArgument("--use_gpu", str(device_type == "gpu")))
 
         if save_dir is not None:

+ 2 - 1
paddlex/repo_apis/PaddleOCR_api/text_rec/model.py

@@ -17,6 +17,7 @@ import os
 from ...base import BaseModel
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from ....utils import logging
 
@@ -301,7 +302,7 @@ class TextRecModel(BaseModel):
         input_path = abspath(input_path)
         cli_args.append(CLIArgument("--image_dir", input_path))
 
-        device_type, _ = self.runner.parse_device(device)
+        device_type, _ = parse_device(device)
         cli_args.append(CLIArgument("--use_gpu", str(device_type == "gpu")))
 
         if save_dir is not None:

+ 8 - 7
paddlex/repo_apis/PaddleSeg_api/seg/model.py

@@ -17,6 +17,7 @@ import os
 from ...base import BaseModel
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from ....utils.download import download
 from ....utils.cache import DEFAULT_CACHE_DIR
@@ -74,7 +75,7 @@ class SegModel(BaseModel):
         # No need to handle `ips`
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         # For compatibility
@@ -217,7 +218,7 @@ class SegModel(BaseModel):
         # No need to handle `ips`
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if amp is not None:
@@ -264,7 +265,7 @@ class SegModel(BaseModel):
         cli_args.append(CLIArgument("--image_path", input_path))
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if save_dir is not None:
@@ -289,7 +290,7 @@ class SegModel(BaseModel):
         cli_args.append(CLIArgument("--model_path", weight_path))
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if save_dir is not None:
@@ -341,7 +342,7 @@ class SegModel(BaseModel):
             cli_args.append(CLIArgument("--input_shape", *input_shape))
 
         try:
-            output_op = config['output_op']
+            output_op = config["output_op"]
         except:
             output_op = kwargs.pop("output_op", None)
         if output_op is not None:
@@ -392,7 +393,7 @@ class SegModel(BaseModel):
         cli_args.append(CLIArgument("--image_path", input_path))
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if save_dir is not None:
@@ -455,7 +456,7 @@ class SegModel(BaseModel):
             train_cli_args.append(CLIArgument("--iters", epochs_iters))
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             train_cli_args.append(CLIArgument("--device", device_type))
 
         if use_vdl:

+ 5 - 4
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -17,6 +17,7 @@ import os
 from ...base import BaseModel
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
 from ....utils.misc import abspath
 from ....utils.errors import raise_unsupported_api_error
 
@@ -82,7 +83,7 @@ class TSModel(BaseModel):
             raise ValueError(f"`use_vdl`={use_vdl} is not supported.")
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if save_dir is not None:
@@ -152,7 +153,7 @@ class TSModel(BaseModel):
 
         # No need to handle `ips`
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if amp is not None:
@@ -198,7 +199,7 @@ class TSModel(BaseModel):
         cli_args.append(CLIArgument("--csv_path", input_path))
 
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         if save_dir is not None:
@@ -230,7 +231,7 @@ class TSModel(BaseModel):
             save_dir = abspath(os.path.join("output", "inference"))
         cli_args.append(CLIArgument("--save_dir", save_dir))
         if device is not None:
-            device_type, _ = self.runner.parse_device(device)
+            device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
 
         self._assert_empty_kwargs(kwargs)

+ 8 - 28
paddlex/repo_apis/base/model.py

@@ -36,6 +36,7 @@ from ...utils.errors import (
     UnsupportedParamError,
     raise_unsupported_api_error,
 )
+from ...utils.device import parse_device
 from ...utils.misc import CachedProperty as cached_property
 from ...utils.cache import get_cache_dir
 
@@ -404,11 +405,7 @@ configuration item, "
                 opts = self.supported_train_opts
                 if opts is not None:
                     if "device" in opts:
-                        checks.append(
-                            _CheckDevice(
-                                opts["device"], self.runner.parse_device, check_mc=True
-                            )
-                        )
+                        checks.append(_CheckDevice(opts["device"], check_mc=True))
                     if "dy2st" in opts:
                         checks.append(_CheckDy2St(opts["dy2st"]))
                     if "amp" in opts:
@@ -417,40 +414,24 @@ configuration item, "
                 opts = self.supported_evaluate_opts
                 if opts is not None:
                     if "device" in opts:
-                        checks.append(
-                            _CheckDevice(
-                                opts["device"], self.runner.parse_device, check_mc=True
-                            )
-                        )
+                        checks.append(_CheckDevice(opts["device"], check_mc=True))
                     if "amp" in opts:
                         checks.append(_CheckAMP(opts["amp"]))
             elif api_name == "predict":
                 opts = self.supported_predict_opts
                 if opts is not None:
                     if "device" in opts:
-                        checks.append(
-                            _CheckDevice(
-                                opts["device"], self.runner.parse_device, check_mc=False
-                            )
-                        )
+                        checks.append(_CheckDevice(opts["device"], check_mc=False))
             elif api_name == "infer":
                 opts = self.supported_infer_opts
                 if opts is not None:
                     if "device" in opts:
-                        checks.append(
-                            _CheckDevice(
-                                opts["device"], self.runner.parse_device, check_mc=False
-                            )
-                        )
+                        checks.append(_CheckDevice(opts["device"], check_mc=False))
             elif api_name == "compression":
                 opts = self.supported_compression_opts
                 if opts is not None:
                     if "device" in opts:
-                        checks.append(
-                            _CheckDevice(
-                                opts["device"], self.runner.parse_device, check_mc=True
-                            )
-                        )
+                        checks.append(_CheckDevice(opts["device"], check_mc=True))
             else:
                 return bnd_method
 
@@ -508,9 +489,8 @@ class _APICallArgsChecker(object):
 class _CheckDevice(_APICallArgsChecker):
     """_CheckDevice"""
 
-    def __init__(self, legal_vals, parse_device, check_mc=False):
+    def __init__(self, legal_vals, check_mc=False):
         super().__init__(legal_vals)
-        self.parse_device = parse_device
         self.check_mc = check_mc
 
     def check(self, args):
@@ -518,7 +498,7 @@ class _CheckDevice(_APICallArgsChecker):
         assert "device" in args
         device = args["device"]
         if device is not None:
-            device_type, dev_ids = self.parse_device(device)
+            device_type, dev_ids = parse_device(device)
             if not self.check_mc:
                 if device_type not in self.legal_vals:
                     raise _CheckFailed("device", device, self.legal_vals)

+ 3 - 17
paddlex/repo_apis/base/runner.py

@@ -26,6 +26,7 @@ from .utils.subprocess import run_cmd as _run_cmd, CompletedProcess
 
 from ...utils import logging
 from ...utils.misc import abspath
+from ...utils.device import parse_device
 from ...utils.flags import DRY_RUN
 from ...utils.errors import raise_unsupported_api_error, CalledProcessError
 
@@ -182,12 +183,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
         args = [self.python]
         if device is None:
             return args, None
-        device, dev_ids = self.parse_device(device)
+        device, dev_ids = parse_device(device)
         if len(dev_ids) == 0:
             return args, None
         else:
             num_devices = len(dev_ids)
-            dev_ids = ",".join(dev_ids)
+            dev_ids = ",".join([str(n) for n in dev_ids])
         if num_devices > 1:
             args.extend(["-m", "paddle.distributed.launch"])
             args.extend(["--devices", dev_ids])
@@ -209,21 +210,6 @@ class BaseRunner(metaclass=abc.ABCMeta):
             return args, new_env
         return args, None
 
-    def parse_device(self, device):
-        """parse_device"""
-        # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
-        if ":" not in device:
-            device_type, dev_ids = device, []
-        else:
-            device_type, dev_ids = device.split(":")
-            dev_ids = dev_ids.split(",")
-        if device_type not in ("cpu", "gpu", "xpu", "npu", "mlu"):
-            raise ValueError("Unsupported device type.")
-        for dev_id in dev_ids:
-            if not dev_id.isdigit():
-                raise ValueError("Device ID must be an integer.")
-        return device_type, dev_ids
-
     def run_cmd(
         self,
         cmd,

+ 56 - 22
paddlex/utils/device.py

@@ -13,20 +13,67 @@
 # limitations under the License.
 
 import os
+import GPUtil
 import lazy_paddle as paddle
 from .errors import raise_unsupported_device_error
 
 SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
 
 
-def get_device(device_cfg, using_device_number=None):
-    """get running device setting"""
-    device = device_cfg.split(":")[0]
-    assert device.lower() in SUPPORTED_DEVICE_TYPE
-    if device.lower() in ["gpu", "xpu", "npu", "mlu"]:
-        if device.lower() == "gpu" and paddle.is_compiled_with_rocm():
+def _constr_device(device_type, device_ids):
+    if device_ids:
+        device_ids = ",".join(map(str, device_ids))
+        return f"{device_type}:{device_ids}"
+    else:
+        return f"{device_type}"
+
+
+def get_default_device():
+    avail_gpus = GPUtil.getAvailable()
+    if not avail_gpus:
+        return "cpu"
+    else:
+        return _constr_device("gpu", [avail_gpus[0]])
+
+
+def parse_device(device):
+    """parse_device"""
+    # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
+    parts = device.split(":")
+    if len(parts) > 2:
+        raise ValueError(f"Invalid device: {device}")
+    if len(parts) == 1:
+        device_type, device_ids = parts[0], None
+    else:
+        device_type, device_ids = parts
+        device_ids = device_ids.split(",")
+        for device_id in device_ids:
+            if not device_id.isdigit():
+                raise ValueError(
+                    f"Device ID must be an integer. Invalid device ID: {device_id}"
+                )
+        device_ids = list(map(int, device_ids))
+    device_type = device_type.lower()
+    # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
+    assert device_type.lower() in SUPPORTED_DEVICE_TYPE
+    return device_type, device_ids
+
+
+def update_device_num(device, num):
+    device_type, device_ids = parse_device(device)
+    if device_ids:
+        assert len(device_ids) >= num
+        return _constr_device(device_type, device_ids[:num])
+    else:
+        return _constr_device(device_type, device_ids)
+
+
+def set_env_for_device(device):
+    device_type, device_ids = parse_device(device)
+    if device_type.lower() in ["gpu", "xpu", "npu", "mlu"]:
+        if device_type.lower() == "gpu" and paddle.is_compiled_with_rocm():
             os.environ["FLAGS_conv_workspace_size_limit"] = "2000"
-        if device.lower() == "npu":
+        if device_type.lower() == "npu":
             os.environ["FLAGS_npu_jit_compile"] = "0"
             os.environ["FLAGS_use_stride_kernel"] = "0"
             os.environ["FLAGS_allocator_strategy"] = "auto_growth"
@@ -35,22 +82,9 @@ def get_device(device_cfg, using_device_number=None):
             )
             os.environ["FLAGS_npu_scale_aclnn"] = "True"
             os.environ["FLAGS_npu_split_aclnn"] = "True"
-        if device.lower() == "xpu":
+        if device_type.lower() == "xpu":
             os.environ["BKCL_FORCE_SYNC"] = "1"
             os.environ["BKCL_TIMEOUT"] = "1800"
             os.environ["FLAGS_use_stride_kernel"] = "0"
-        if device.lower() == "mlu":
+        if device_type.lower() == "mlu":
             os.environ["FLAGS_use_stride_kernel"] = "0"
-
-        if len(device_cfg.split(":")) == 2:
-            device_ids = device_cfg.split(":")[1]
-        else:
-            device_ids = 0
-
-        if using_device_number:
-            device_ids = f"{device_ids[:using_device_number]}"
-        return "{}:{}".format(device.lower(), device_ids)
-    if device.lower() == "cpu":
-        return "cpu"
-    else:
-        raise_unsupported_device_error(device, SUPPORTED_DEVICE_TYPE)