| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import abc
- import inspect
- import functools
- import contextlib
- import tempfile
- import hashlib
- import base64
- from datetime import datetime, timedelta
- from .config import Config
- from .register import (
- get_registered_model_info,
- build_runner_from_model_info,
- build_model_from_model_info,
- )
- from ...utils import flags
- from ...utils import logging
- from ...utils.errors import (
- UnsupportedAPIError,
- UnsupportedParamError,
- raise_unsupported_api_error,
- )
- from ...utils.device import parse_device
- from ...utils.misc import CachedProperty as cached_property
- from ...utils.cache import get_cache_dir
- __all__ = ["PaddleModel", "BaseModel"]
- def _create_model(model_name=None, config=None):
- """_create_model"""
- if model_name is None and config is None:
- raise ValueError("At least one of `model_name` and `config` must be not None.")
- elif model_name is not None and config is not None:
- if model_name != config.model_name:
- raise ValueError(
- "If both `model_name` and `config` are not None, `model_name` should be the same as \
- `config.model_name`."
- )
- elif model_name is None and config is not None:
- model_name = config.model_name
- try:
- model_info = get_registered_model_info(model_name)
- except KeyError as e:
- raise UnsupportedParamError(
- f"{repr(model_name)} is not a registered model name."
- ) from e
- return build_model_from_model_info(model_info=model_info, config=config)
- PaddleModel = _create_model
class BaseModel(metaclass=abc.ABCMeta):
    """
    Abstract base class of Model.

    Model defines how Config and Runner interact with each other. In addition,
    Model provides users with multiple APIs to perform model training,
    prediction, etc.

    On construction, every API in `_API_FULL_LIST` that the registered model
    info does not list as supported is replaced on the instance by a stub that
    raises `UnsupportedAPIError`. When `flags.CHECK_OPTS` is set, the supported
    APIs are wrapped so that their arguments are validated before the call.
    """

    _API_FULL_LIST = ("train", "evaluate", "predict", "export", "infer", "compression")
    _API_SUPPORTED_OPTS_KEY_PATTERN = "supported_{api_name}_opts"
    # Options that are pre-checked for each API, in the order the checks run.
    # APIs missing from this table (e.g. `export`) are never wrapped.
    _API_PRECHECKED_OPTS = {
        "train": ("device", "dy2st", "amp"),
        "evaluate": ("device", "amp"),
        "predict": ("device",),
        "infer": ("device",),
        "compression": ("device",),
    }
    # APIs whose `device` check also validates multi-device settings.
    _API_MULTI_DEVICE_CHECKED = frozenset(("train", "evaluate", "compression"))

    def __init__(self, model_name, config=None):
        """
        Initialize the instance.

        Args:
            model_name (str): A registered model name.
            config (base.config.BaseConfig|None): Config. If None, a default
                `Config(model_name)` is created and a warning is logged.
                Default: None.
        """
        super().__init__()

        self.name = model_name
        self.model_info = get_registered_model_info(model_name)
        # NOTE: We build runner instance here by extracting runner info from model info
        # so that we don't have to overwrite the `__init__` method of each child class.
        self.runner = build_runner_from_model_info(self.model_info)
        if config is None:
            logging.warning(
                "We strongly discourage leaving `config` unset or setting it to None. "
                "Please note that when `config` is None, default settings will be used for every unspecified "
                "configuration item, "
                "which may lead to unexpected result. Please make sure that this is what you intend to do."
            )
            config = Config(model_name)
        self.config = config

        self._patch_apis()

    @abc.abstractmethod
    def train(
        self,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        ips=None,
        device="gpu",
        resume_path=None,
        dy2st=False,
        amp="OFF",
        num_workers=None,
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """
        Train a model.

        Args:
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of model training. If
                None, use a default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine training mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            resume_path (str|None): If not None, resume training from the model
                snapshot corresponding to the weight file `resume_path`. If
                None, use a default setting. Default: None.
            dy2st (bool): Whether to enable dynamic-to-static training.
                Default: False.
            amp (str): Optimization level to use in AMP training. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store model snapshots and logs. If
                None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(
        self,
        weight_path,
        batch_size=None,
        ips=None,
        device="gpu",
        amp="OFF",
        num_workers=None,
        **kwargs,
    ):
        """
        Evaluate a model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine evaluation mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            amp (str): Optimization level to use in AMP. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
        """
        Make prediction with a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store prediction results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def export(self, weight_path, save_dir, **kwargs):
        """
        Export a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            save_dir (str): Directory to store the exported model.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def infer(self, model_dir, input_path, device="gpu", save_dir=None, **kwargs):
        """
        Make inference with an exported inference model.

        Args:
            model_dir (str): Path of the exported inference model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store inference results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compression(
        self,
        weight_path,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        device="gpu",
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """
        Perform quantization aware training (QAT) and export the quantized
        model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each
                device. If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of QAT. If None, use a
                default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store the results. If None, use a
                default setting. Default: None.

        Returns:
            tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
        """
        raise NotImplementedError

    @contextlib.contextmanager
    def _create_empty_cache_file(self, suffix):
        """
        Create an empty file named `<class name, lowercased>_<model name><suffix>`
        and yield its path.

        Unless `flags.DEBUG` is set, the file lives in a temporary directory
        created under the cache directory, and that directory is removed when
        the context exits. In debug mode the file is created directly in the
        cache directory and is left in place.

        Args:
            suffix (str): Text appended to the file name tag, e.g. '.yml'.

        Yields:
            str: Path of the newly created (empty) file.
        """
        tag = "_".join([self.__class__.__name__.lower(), self.model_info["model_name"]])
        file_name = tag + suffix
        with contextlib.ExitStack() as stack:
            if not flags.DEBUG:
                parent_dir = stack.enter_context(
                    tempfile.TemporaryDirectory(dir=get_cache_dir())
                )
            else:
                parent_dir = get_cache_dir()
            path = os.path.join(parent_dir, file_name)
            # Create (or truncate) the file so that callers get an empty file.
            with open(path, "w", encoding="utf-8"):
                pass
            yield path

    @contextlib.contextmanager
    def _create_new_config_file(self):
        """Yield the path of a new, empty '.yml' file (see `_create_empty_cache_file`)."""
        with self._create_empty_cache_file(".yml") as path:
            yield path

    @contextlib.contextmanager
    def _create_new_val_json_file(self):
        """Yield the path of a new, empty '_test.json' file (see `_create_empty_cache_file`)."""
        with self._create_empty_cache_file("_test.json") as path:
            yield path

    @cached_property
    def supported_apis(self):
        """Value of 'supported_apis' in the model info, or None if absent."""
        return self.model_info.get("supported_apis", None)

    @cached_property
    def supported_train_opts(self):
        """Value of 'supported_train_opts' in the model info, or None if absent."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="train"), None
        )

    @cached_property
    def supported_evaluate_opts(self):
        """Value of 'supported_evaluate_opts' in the model info, or None if absent."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="evaluate"), None
        )

    @cached_property
    def supported_predict_opts(self):
        """Value of 'supported_predict_opts' in the model info, or None if absent."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="predict"), None
        )

    @cached_property
    def supported_infer_opts(self):
        """Value of 'supported_infer_opts' in the model info, or None if absent."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="infer"), None
        )

    @cached_property
    def supported_compression_opts(self):
        """Value of 'supported_compression_opts' in the model info, or None if absent."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="compression"), None
        )

    @cached_property
    def supported_dataset_types(self):
        """Value of 'supported_dataset_types' in the model info, or None if absent."""
        return self.model_info.get("supported_dataset_types", None)

    @staticmethod
    def _assert_empty_kwargs(kwargs):
        """
        Log a warning if `kwargs` is not empty.

        Args:
            kwargs (dict): Keyword arguments left over after an API consumed
                the ones it knows about.
        """
        if kwargs:
            # For compatibility, unconsumed arguments only trigger a warning
            # rather than an error.
            logging.warning(f"Unconsumed keyword arguments detected: {kwargs}.")

    def _build_prechecks(self, api_name):
        """
        Build the argument checkers to run before the API `api_name`.

        Args:
            api_name (str): Name of the API.

        Returns:
            list|None: Checker objects, in the order they should be run. The
                list is empty when the model info declares no supported
                options for the API. None when no pre-checks are defined for
                the API at all.
        """
        opt_names = self._API_PRECHECKED_OPTS.get(api_name)
        if opt_names is None:
            return None

        # The `supported_<api>_opts` properties are named after the same
        # pattern as the model info keys they read.
        opts = getattr(
            self, self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name=api_name)
        )
        checks = []
        if opts is None:
            return checks

        for opt_name in opt_names:
            if opt_name not in opts:
                continue
            if opt_name == "device":
                check_mc = api_name in self._API_MULTI_DEVICE_CHECKED
                checks.append(_CheckDevice(opts["device"], check_mc=check_mc))
            elif opt_name == "dy2st":
                checks.append(_CheckDy2St(opts["dy2st"]))
            elif opt_name == "amp":
                checks.append(_CheckAMP(opts["amp"]))
        return checks

    def _patch_apis(self):
        """
        Monkey-patch the public APIs of this instance.

        APIs that are not in `supported_apis` are replaced by stubs that raise
        `UnsupportedAPIError`. If `supported_apis` is None, all APIs in
        `_API_FULL_LIST` are considered available. Available APIs are wrapped
        with argument pre-checks when `flags.CHECK_OPTS` is set.
        """

        def _make_unavailable(bnd_method):
            @functools.wraps(bnd_method)
            def _unavailable_api(*args, **kwargs):
                model_name = self.name
                api_name = bnd_method.__name__
                raise UnsupportedAPIError(
                    f"{model_name} does not support `{api_name}`."
                )

            return _unavailable_api

        def _add_prechecks(bnd_method):
            checks = self._build_prechecks(bnd_method.__name__)
            if checks is None:
                # No pre-checks are defined for this API; leave it untouched.
                return bnd_method

            # The signature does not change between calls, so inspect it once
            # here instead of on every invocation.
            sig = inspect.Signature.from_callable(bnd_method)

            @functools.wraps(bnd_method)
            def _api_with_prechecks(*args, **kwargs):
                args_dict = sig.bind(*args, **kwargs).arguments
                # Merge default values
                for param in sig.parameters.values():
                    if param.name not in args_dict and param.default is not param.empty:
                        args_dict[param.name] = param.default
                for check in checks:
                    # We throw any unhandled exception
                    check.check(args_dict)
                return bnd_method(*args, **kwargs)

            return _api_with_prechecks

        supported_apis = self.supported_apis
        if supported_apis is None:
            avail_api_set = set(self._API_FULL_LIST)
        else:
            avail_api_set = set(supported_apis)

        for api_name in self._API_FULL_LIST:
            api = getattr(self, api_name)
            if api_name not in avail_api_set:
                # We decorate real API implementation with `_make_unavailable`
                # so that an error is always raised when the API is called.
                setattr(self, api_name, _make_unavailable(api))
            elif flags.CHECK_OPTS:
                # We decorate real API implementation with `_add_prechecks`
                # to perform sanity and validity checks before invoking the
                # internal API.
                setattr(self, api_name, _add_prechecks(api))
- class _CheckFailed(Exception):
- """_CheckFailed"""
- # Allow `_CheckFailed` class to be recognized using `hasattr(exc, 'check_failed_error')`
- check_failed_error = True
- def __init__(self, arg_name, arg_val, legal_vals):
- self.arg_name = arg_name
- self.arg_val = arg_val
- self.legal_vals = legal_vals
- def __str__(self):
- return f"`{self.arg_name}` is expected to be one of or conforms to {self.legal_vals}, but got {self.arg_val}"
- class _APICallArgsChecker(object):
- """_APICallArgsChecker"""
- def __init__(self, legal_vals):
- super().__init__()
- self.legal_vals = legal_vals
- def check(self, args):
- """check"""
- raise NotImplementedError
class _CheckDevice(_APICallArgsChecker):
    """
    Checker of the `device` argument of an API call.

    `legal_vals` holds the accepted device descriptors. Without multi-device
    checking these are compared against the device type (e.g. 'cpu'). With
    multi-device checking, GPU settings are compared against the descriptors
    'gpu_n1c1', 'gpu_n1cx' and 'gpu_nxcx'.
    """

    def __init__(self, legal_vals, check_mc=False):
        """
        Args:
            legal_vals: Accepted device descriptors.
            check_mc (bool): Whether to validate multi-device settings.
                Default: False.
        """
        super().__init__(legal_vals)
        self.check_mc = check_mc

    def check(self, args):
        """
        Validate `args["device"]`.

        Args:
            args (dict): Arguments of the API call, keyed by parameter name.
                Must contain 'device'. 'ips' is read if present; a missing
                'ips' is treated the same as `ips=None`.

        Raises:
            _CheckFailed: If the device setting is not accepted.
        """
        assert "device" in args
        device = args["device"]
        if device is None:
            # When `device` is None, we assume that a default device that the
            # current model supports will be used, so we simply do nothing.
            return

        device_type, dev_ids = parse_device(device)
        legal_vals = self.legal_vals

        # Currently we only check multi-device settings for GPUs
        if not self.check_mc or device_type != "gpu":
            if device_type not in legal_vals:
                raise _CheckFailed("device", device, legal_vals)
            return

        n1c1_desc = f"{device_type}_n1c1"
        n1cx_desc = f"{device_type}_n1cx"
        nxcx_desc = f"{device_type}_nxcx"

        # NOTE(review): a missing id list is tolerated here in case
        # `parse_device` returns None when no ids are given - confirm
        # against its contract.
        if dev_ids is None or len(dev_ids) <= 1:
            accepted_descs = (n1c1_desc, n1cx_desc, nxcx_desc)
        elif args.get("ips") is not None:
            # Multi-machine
            accepted_descs = (nxcx_desc,)
        else:
            # Single-machine multi-device. This branch is also taken by APIs
            # that have no `ips` parameter (e.g. `compression`), which must
            # not fail just because 'ips' is absent from `args`.
            accepted_descs = (n1cx_desc, nxcx_desc)

        if not any(desc in legal_vals for desc in accepted_descs):
            raise _CheckFailed("device", device, legal_vals)
class _CheckDy2St(_APICallArgsChecker):
    """
    Checker of the `dy2st` argument of an API call.

    `legal_vals` is either a single value or a one-element list; its
    truthiness tells whether dynamic-to-static training is supported.
    """

    def check(self, args):
        """
        Validate `args["dy2st"]`.

        Args:
            args (dict): Arguments of the API call, keyed by parameter name.
                Must contain 'dy2st'.

        Raises:
            _CheckFailed: If `dy2st` is truthy but not supported.
        """
        assert "dy2st" in args
        requested = args["dy2st"]

        declared = self.legal_vals
        if isinstance(declared, list):
            assert len(declared) == 1
            declared = declared[0]
        support_dy2st = bool(declared)

        # `None` means "not specified", which is always accepted.
        if requested is not None and requested and not support_dy2st:
            raise _CheckFailed("dy2st", requested, [support_dy2st])
class _CheckAMP(_APICallArgsChecker):
    """
    Checker of the `amp` argument of an API call.

    `legal_vals` holds the accepted AMP levels; 'OFF' is always accepted.
    """

    def check(self, args):
        """
        Validate `args["amp"]`.

        Args:
            args (dict): Arguments of the API call, keyed by parameter name.
                Must contain 'amp'.

        Raises:
            _CheckFailed: If `amp` is neither None, 'OFF', nor one of
                `legal_vals`.
        """
        assert "amp" in args
        level = args["amp"]
        if level is None:
            # Not specified: nothing to validate.
            return
        is_rejected = level != "OFF" and level not in self.legal_vals
        if is_rejected:
            raise _CheckFailed("amp", level, self.legal_vals)
|