# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import abc
import inspect
import functools
import contextlib
import tempfile
import hashlib
import base64
from datetime import datetime, timedelta

from .config import Config
from .register import (get_registered_model_info, build_runner_from_model_info,
                       build_model_from_model_info)
from ...utils import flags
from ...utils import logging
from ...utils.errors import UnsupportedAPIError, UnsupportedParamError, raise_unsupported_api_error
from ...utils.misc import CachedProperty as cached_property
from ...utils.cache import get_cache_dir

__all__ = ['PaddleModel', 'BaseModel']


def _create_model(model_name=None, config=None):
    """
    Build a model object from a registered model name and/or a config.

    Args:
        model_name (str|None): A registered model name. If None, the name is
            taken from `config.model_name`. Default: None.
        config (object|None): Config object exposing a `model_name` attribute.
            Default: None.

    Returns:
        The object returned by `build_model_from_model_info` for the resolved
        model.

    Raises:
        ValueError: If both arguments are None, or if both are given but
            `model_name` differs from `config.model_name`.
        UnsupportedParamError: If the resolved model name is not registered
            (the registry lookup raised `KeyError`).
    """
    if model_name is None and config is None:
        raise ValueError(
            "At least one of `model_name` and `config` must be not None.")
    if config is not None:
        if model_name is None:
            model_name = config.model_name
        elif model_name != config.model_name:
            # NOTE: Implicit string concatenation is used instead of a
            # backslash continuation inside the literal, which would embed
            # the source indentation in the message.
            raise ValueError(
                "If both `model_name` and `config` are not None, "
                "`model_name` should be the same as `config.model_name`.")

    try:
        model_info = get_registered_model_info(model_name)
    except KeyError as e:
        raise UnsupportedParamError(
            f"{repr(model_name)} is not a registered model name.") from e

    return build_model_from_model_info(model_info=model_info, config=config)


PaddleModel = _create_model


class BaseModel(metaclass=abc.ABCMeta):
    """
    Abstract base class of Model.

    Model defines how Config and Runner interact with each other. In addition,
    Model provides users with multiple APIs to perform model training,
    prediction, etc.
    """

    _API_FULL_LIST = ('train', 'evaluate', 'predict', 'export', 'infer',
                      'compression')
    _API_SUPPORTED_OPTS_KEY_PATTERN = 'supported_{api_name}_opts'

    def __init__(self, model_name, config=None):
        """
        Initialize the instance.

        Args:
            model_name (str): A registered model name.
            config (base.config.BaseConfig|None): Config. Default: None.
        """
        super().__init__()

        self.name = model_name
        self.model_info = get_registered_model_info(model_name)
        # NOTE: We build runner instance here by extracting runner info from model info
        # so that we don't have to overwrite the `__init__` method of each child class.
        self.runner = build_runner_from_model_info(self.model_info)
        if config is None:
            logging.warning(
                "We strongly discourage leaving `config` unset or setting it to None. "
                "Please note that when `config` is None, default settings will be used "
                "for every unspecified configuration item, "
                "which may lead to unexpected result. Please make sure that this is what you intend to do."
            )
            config = Config(model_name)
        self.config = config

        # Must run last: patching inspects `self.model_info` and `self.runner`.
        self._patch_apis()

    @abc.abstractmethod
    def train(self,
              batch_size=None,
              learning_rate=None,
              epochs_iters=None,
              ips=None,
              device='gpu',
              resume_path=None,
              dy2st=False,
              amp='OFF',
              num_workers=None,
              use_vdl=True,
              save_dir=None,
              **kwargs):
        """
        Train a model.

        Args:
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of model training. If
                None, use a default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine training mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            resume_path (str|None): If not None, resume training from the model
                snapshot corresponding to the weight file `resume_path`. If
                None, use a default setting. Default: None.
            dy2st (bool): Whether to enable dynamic-to-static training.
                Default: False.
            amp (str): Optimization level to use in AMP training. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store model snapshots and logs. If
                None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(self,
                 weight_path,
                 batch_size=None,
                 ips=None,
                 device='gpu',
                 amp='OFF',
                 num_workers=None,
                 **kwargs):
        """
        Evaluate a model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine evaluation mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            amp (str): Optimization level to use in AMP training. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self,
                weight_path,
                input_path,
                device='gpu',
                save_dir=None,
                **kwargs):
        """
        Make prediction with a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store prediction results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def export(self, weight_path, save_dir, **kwargs):
        """
        Export a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            save_dir (str): Directory to store the exported model.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def infer(self,
              model_dir,
              input_path,
              device='gpu',
              save_dir=None,
              **kwargs):
        """
        Make inference with an exported inference model.

        Args:
            model_dir (str): Path of the exported inference model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store inference results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compression(self,
                    weight_path,
                    batch_size=None,
                    learning_rate=None,
                    epochs_iters=None,
                    device='gpu',
                    use_vdl=True,
                    save_dir=None,
                    **kwargs):
        """
        Perform quantization aware training (QAT) and export the quantized
        model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of QAT. If None, use a
                default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store the results. If None, use a
                default setting. Default: None.

        Returns:
            tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
        """
        raise NotImplementedError

    @contextlib.contextmanager
    def _create_new_config_file(self):
        """
        Context manager that creates an empty YAML file and yields its path.

        The file is named `<lowercased class name>_<model name>.yml`. Unless
        `flags.DEBUG` is set, it is created inside a temporary directory under
        the cache directory, and that directory is removed when the context
        exits. With `flags.DEBUG` set, the file is created directly in the
        cache directory and is not removed here.

        Yields:
            str: Path of the newly created (empty) file.
        """
        cls = self.__class__
        model_name = self.model_info['model_name']
        tag = '_'.join([cls.__name__.lower(), model_name])
        yaml_file_name = tag + '.yml'
        if not flags.DEBUG:
            with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
                path = os.path.join(td, yaml_file_name)
                # Create (or truncate) the file so that the path exists.
                with open(path, 'w', encoding='utf-8'):
                    pass
                yield path
        else:
            path = os.path.join(get_cache_dir(), yaml_file_name)
            with open(path, 'w', encoding='utf-8'):
                pass
            yield path

    def _get_supported_opts(self, api_name):
        """
        Look up the supported options of API `api_name` in the model info.

        Returns:
            The value stored under `supported_<api_name>_opts` in
            `self.model_info`, or None if there is no such entry.
        """
        key = self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name=api_name)
        return self.model_info.get(key, None)

    @cached_property
    def supported_apis(self):
        """ APIs listed as supported in the model info, or None if unspecified. """
        return self.model_info.get('supported_apis', None)

    @cached_property
    def supported_train_opts(self):
        """ Supported `train` options from the model info, or None if unspecified. """
        return self._get_supported_opts('train')

    @cached_property
    def supported_evaluate_opts(self):
        """ Supported `evaluate` options from the model info, or None if unspecified. """
        return self._get_supported_opts('evaluate')

    @cached_property
    def supported_predict_opts(self):
        """ Supported `predict` options from the model info, or None if unspecified. """
        return self._get_supported_opts('predict')

    @cached_property
    def supported_infer_opts(self):
        """ Supported `infer` options from the model info, or None if unspecified. """
        return self._get_supported_opts('infer')

    @cached_property
    def supported_compression_opts(self):
        """ Supported `compression` options from the model info, or None if unspecified. """
        return self._get_supported_opts('compression')

    @cached_property
    def supported_dataset_types(self):
        """ Dataset types listed as supported in the model info, or None if unspecified. """
        return self.model_info.get('supported_dataset_types', None)

    @staticmethod
    def _assert_empty_kwargs(kwargs):
        """
        Warn about keyword arguments that were not consumed.

        Args:
            kwargs (dict): Leftover keyword arguments.
        """
        if len(kwargs) > 0:
            # For compatibility, unconsumed arguments only trigger a warning
            # instead of raising an error.
            logging.warning(f"Unconsumed keyword arguments detected: {kwargs}.")

    def _patch_apis(self):
        """
        Monkey-patch the public APIs of this instance.

        APIs that are not listed in `self.supported_apis` are replaced by stubs
        that raise `UnsupportedAPIError`. When `flags.CHECK_OPTS` is set, the
        remaining APIs are wrapped so that their call arguments are validated
        against the supported options before the real implementation runs.
        """

        def _make_unavailable(bnd_method):
            @functools.wraps(bnd_method)
            def _unavailable_api(*args, **kwargs):
                model_name = self.name
                api_name = bnd_method.__name__
                raise UnsupportedAPIError(
                    f"{model_name} does not support `{api_name}`.")

            return _unavailable_api

        def _add_prechecks(bnd_method):
            api_name = bnd_method.__name__
            # We hardcode the prechecks for each API here:
            # api name -> (check multi-device settings, check dy2st, check amp)
            precheck_specs = {
                'train': (True, True, True),
                'evaluate': (True, False, True),
                'predict': (False, False, False),
                'infer': (False, False, False),
                'compression': (True, False, False),
            }
            if api_name not in precheck_specs:
                # No prechecks are defined for this API (e.g. `export`).
                return bnd_method
            check_mc, check_dy2st, check_amp = precheck_specs[api_name]

            # Resolves to the `supported_<api_name>_opts` property.
            opts = getattr(self, f'supported_{api_name}_opts')
            checks = []
            if opts is not None:
                if 'device' in opts:
                    checks.append(
                        _CheckDevice(
                            opts['device'],
                            self.runner.parse_device,
                            check_mc=check_mc))
                if check_dy2st and 'dy2st' in opts:
                    checks.append(_CheckDy2St(opts['dy2st']))
                if check_amp and 'amp' in opts:
                    checks.append(_CheckAMP(opts['amp']))

            # The signature does not change between calls, so compute it once.
            sig = inspect.Signature.from_callable(bnd_method)

            @functools.wraps(bnd_method)
            def _api_with_prechecks(*args, **kwargs):
                bnd_args = sig.bind(*args, **kwargs)
                args_dict = bnd_args.arguments
                # Merge default values
                for p in sig.parameters.values():
                    if p.name not in args_dict and p.default is not p.empty:
                        args_dict[p.name] = p.default

                for check in checks:
                    # We throw any unhandled exception
                    check.check(args_dict)

                return bnd_method(*args, **kwargs)

            return _api_with_prechecks

        supported_apis = self.supported_apis
        if supported_apis is not None:
            avail_api_set = set(supported_apis)
        else:
            avail_api_set = set(self._API_FULL_LIST)

        for api_name in self._API_FULL_LIST:
            api = getattr(self, api_name)
            if api_name not in avail_api_set:
                # We decorate real API implementation with `_make_unavailable`
                # so that an error is always raised when the API is called.
                setattr(self, api_name, _make_unavailable(api))
            elif flags.CHECK_OPTS:
                # We decorate real API implementation with `_add_prechecks`
                # to perform sanity and validity checks before invoking the
                # internal API.
                setattr(self, api_name, _add_prechecks(api))


class _CheckFailed(Exception):
    """
    Raised when an API call argument fails a precheck.

    Attributes are documented inline in `__init__`.
    """

    # Allow `_CheckFailed` class to be recognized using `hasattr(exc, 'check_failed_error')`
    check_failed_error = True

    def __init__(self, arg_name, arg_val, legal_vals):
        """
        Args:
            arg_name (str): Name of the offending argument.
            arg_val: Value that was passed for the argument.
            legal_vals: Description of the accepted values.
        """
        # Populate `self.args` so that the exception can be pickled/copied
        # and has an informative `repr`.
        super().__init__(arg_name, arg_val, legal_vals)
        self.arg_name = arg_name
        self.arg_val = arg_val
        self.legal_vals = legal_vals

    def __str__(self):
        return f"`{self.arg_name}` is expected to be one of or conforms to {self.legal_vals}, but got {self.arg_val}"


class _APICallArgsChecker(object):
    """
    Base class of checkers that validate the arguments of an API call.
    """

    def __init__(self, legal_vals):
        """
        Args:
            legal_vals: Accepted values for the argument being checked.
        """
        super().__init__()
        self.legal_vals = legal_vals

    def check(self, args):
        """
        Validate `args`, a mapping from parameter names to call values.

        Subclasses must override this method and raise `_CheckFailed` when the
        check fails.
        """
        raise NotImplementedError


class _CheckDevice(_APICallArgsChecker):
    """
    Check that the `device` argument is supported.
    """

    def __init__(self, legal_vals, parse_device, check_mc=False):
        """
        Args:
            legal_vals: Accepted device descriptors.
            parse_device (callable): Function that takes the `device` string
                and returns a `(device_type, dev_ids)` pair.
            check_mc (bool): Whether to check multi-device settings (GPU
                only). Default: False.
        """
        super().__init__(legal_vals)
        self.parse_device = parse_device
        self.check_mc = check_mc

    def check(self, args):
        """
        Check `args['device']` against the legal values.

        Args:
            args (dict): Mapping from parameter names to call values. Must
                contain 'device'. 'ips' is consulted when several GPUs are
                requested; if it is absent, single-machine mode is assumed.

        Raises:
            _CheckFailed: If the device setting is not supported.
        """
        assert 'device' in args
        device = args['device']
        if device is None:
            # When `device` is None, we assume that a default device that the
            # current model supports will be used, so we simply do nothing.
            return

        device_type, dev_ids = self.parse_device(device)

        # Currently we only check multi-device settings for GPUs
        if not self.check_mc or device_type != 'gpu':
            if device_type not in self.legal_vals:
                raise _CheckFailed('device', device, self.legal_vals)
            return

        n1c1_desc = f'{device_type}_n1c1'
        n1cx_desc = f'{device_type}_n1cx'
        nxcx_desc = f'{device_type}_nxcx'

        # Be tolerant of parsers that return None when no device ids are
        # given explicitly.
        num_devices = len(dev_ids) if dev_ids is not None else 0
        if num_devices <= 1:
            accepted_descs = (n1c1_desc, n1cx_desc, nxcx_desc)
        elif args.get('ips') is not None:
            # Multi-machine
            accepted_descs = (nxcx_desc, )
        else:
            # Single-machine multi-device. This branch also covers APIs that
            # have no `ips` parameter at all (e.g. `compression`).
            accepted_descs = (n1cx_desc, nxcx_desc)

        if not any(desc in self.legal_vals for desc in accepted_descs):
            raise _CheckFailed('device', device, self.legal_vals)


class _CheckDy2St(_APICallArgsChecker):
    """
    Check that dynamic-to-static training is only requested when supported.
    """

    def check(self, args):
        """
        Check `args['dy2st']` against the legal values.

        `self.legal_vals` is either a single truthy/falsy value or a
        one-element list holding such a value.

        Raises:
            _CheckFailed: If `dy2st` is requested but not supported.
        """
        assert 'dy2st' in args
        dy2st = args['dy2st']
        if isinstance(self.legal_vals, list):
            assert len(self.legal_vals) == 1
            support_dy2st = bool(self.legal_vals[0])
        else:
            support_dy2st = bool(self.legal_vals)
        # A None or falsy `dy2st` never requests the feature, so it always
        # passes.
        if dy2st and not support_dy2st:
            raise _CheckFailed('dy2st', dy2st, [support_dy2st])


class _CheckAMP(_APICallArgsChecker):
    """
    Check that the requested AMP level is supported.
    """

    def check(self, args):
        """
        Check `args['amp']` against the legal values.

        None and 'OFF' are always accepted.

        Raises:
            _CheckFailed: If `amp` is neither None, 'OFF', nor a legal value.
        """
        assert 'amp' in args
        amp = args['amp']
        if amp is None or amp == 'OFF':
            return
        if amp not in self.legal_vals:
            raise _CheckFailed('amp', amp, self.legal_vals)