model.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import abc
  16. import inspect
  17. import functools
  18. import contextlib
  19. import tempfile
  20. import hashlib
  21. import base64
  22. from datetime import datetime, timedelta
  23. from .config import Config
  24. from .register import (get_registered_model_info, build_runner_from_model_info,
  25. build_model_from_model_info)
  26. from ...utils import flags
  27. from ...utils import logging
  28. from ...utils.errors import UnsupportedAPIError, UnsupportedParamError, raise_unsupported_api_error
  29. from ...utils.misc import CachedProperty as cached_property
  30. from ...utils.cache import get_cache_dir
# Public names of this module, i.e. what a star-import of it exposes.
__all__ = ['PaddleModel', 'BaseModel']
  32. def _create_model(model_name=None, config=None):
  33. """ _create_model """
  34. if model_name is None and config is None:
  35. raise ValueError(
  36. "At least one of `model_name` and `config` must be not None.")
  37. elif model_name is not None and config is not None:
  38. if model_name != config.model_name:
  39. raise ValueError(
  40. "If both `model_name` and `config` are not None, `model_name` should be the same as \
  41. `config.model_name`.")
  42. elif model_name is None and config is not None:
  43. model_name = config.model_name
  44. try:
  45. model_info = get_registered_model_info(model_name)
  46. except KeyError as e:
  47. raise UnsupportedParamError(
  48. f"{repr(model_name)} is not a registered model name.") from e
  49. return build_model_from_model_info(model_info=model_info, config=config)
  50. PaddleModel = _create_model
  51. class BaseModel(metaclass=abc.ABCMeta):
  52. """
  53. Abstract base class of Model.
  54. Model defines how Config and Runner interact with each other. In addition,
  55. Model provides users with multiple APIs to perform model training,
  56. prediction, etc.
  57. """
  58. _API_FULL_LIST = ('train', 'evaluate', 'predict', 'export', 'infer',
  59. 'compression')
  60. _API_SUPPORTED_OPTS_KEY_PATTERN = 'supported_{api_name}_opts'
  61. def __init__(self, model_name, config=None):
  62. """
  63. Initialize the instance.
  64. Args:
  65. model_name (str): A registered model name.
  66. config (base.config.BaseConfig|None): Config. Default: None.
  67. """
  68. super().__init__()
  69. self.name = model_name
  70. self.model_info = get_registered_model_info(model_name)
  71. # NOTE: We build runner instance here by extracting runner info from model info
  72. # so that we don't have to overwrite the `__init__` method of each child class.
  73. self.runner = build_runner_from_model_info(self.model_info)
  74. if config is None:
  75. logging.warning(
  76. "We strongly discourage leaving `config` unset or setting it to None. "
  77. "Please note that when `config` is None, default settings will be used for every unspecified \
  78. configuration item, "
  79. "which may lead to unexpected result. Please make sure that this is what you intend to do."
  80. )
  81. config = Config(model_name)
  82. self.config = config
  83. self._patch_apis()
  84. @abc.abstractmethod
  85. def train(self,
  86. batch_size=None,
  87. learning_rate=None,
  88. epochs_iters=None,
  89. ips=None,
  90. device='gpu',
  91. resume_path=None,
  92. dy2st=False,
  93. amp='OFF',
  94. num_workers=None,
  95. use_vdl=True,
  96. save_dir=None,
  97. **kwargs):
  98. """
  99. Train a model.
  100. Args:
  101. batch_size (int|None): Number of samples in each mini-batch. If
  102. multiple devices are used, this is the batch size on each device.
  103. If None, use a default setting. Default: None.
  104. learning_rate (float|None): Learning rate of model training. If
  105. None, use a default setting. Default: None.
  106. epochs_iters (int|None): Total epochs or iterations of model
  107. training. If None, use a default setting. Default: None.
  108. ips (str|None): If not None, enable multi-machine training mode.
  109. `ips` specifies Paddle cluster node ips, e.g.,
  110. '192.168.0.16,192.168.0.17'. Default: None.
  111. device (str): A string that describes the device(s) to use, e.g.,
  112. 'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
  113. resume_path (str|None): If not None, resume training from the model
  114. snapshot corresponding to the weight file `resume_path`. If
  115. None, use a default setting. Default: None.
  116. dy2st (bool): Whether to enable dynamic-to-static training.
  117. Default: False.
  118. amp (str): Optimization level to use in AMP training. Choices are
  119. ['O1', 'O2', 'OFF']. Default: 'OFF'.
  120. num_workers (int|None): Number of subprocesses to use for data
  121. loading. If None, use a default setting. Default: None.
  122. use_vdl (bool): Whether to enable VisualDL during training.
  123. Default: True.
  124. save_dir (str|None): Directory to store model snapshots and logs. If
  125. None, use a default setting. Default: None.
  126. Returns:
  127. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  128. """
  129. raise NotImplementedError
  130. @abc.abstractmethod
  131. def evaluate(self,
  132. weight_path,
  133. batch_size=None,
  134. ips=None,
  135. device='gpu',
  136. amp='OFF',
  137. num_workers=None,
  138. **kwargs):
  139. """
  140. Evaluate a model.
  141. Args:
  142. weight_path (str): Path of the weights to initialize the model.
  143. batch_size (int|None): Number of samples in each mini-batch. If
  144. multiple devices are used, this is the batch size on each device.
  145. If None, use a default setting. Default: None.
  146. ips (str|None): If not None, enable multi-machine evaluation mode.
  147. `ips` specifies Paddle cluster node ips, e.g.,
  148. '192.168.0.16,192.168.0.17'. Default: None.
  149. device (str): A string that describes the device(s) to use, e.g.,
  150. 'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
  151. amp (str): Optimization level to use in AMP training. Choices are
  152. ['O1', 'O2', 'OFF']. Default: 'OFF'.
  153. num_workers (int|None): Number of subprocesses to use for data
  154. loading. If None, use a default setting. Default: None.
  155. Returns:
  156. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  157. """
  158. raise NotImplementedError
  159. @abc.abstractmethod
  160. def predict(self,
  161. weight_path,
  162. input_path,
  163. device='gpu',
  164. save_dir=None,
  165. **kwargs):
  166. """
  167. Make prediction with a pre-trained model.
  168. Args:
  169. weight_path (str): Path of the weights to initialize the model.
  170. input_path (str): Path of the input file, e.g. an image.
  171. device (str): A string that describes the device to use, e.g.,
  172. 'cpu', 'gpu'. Default: 'gpu'.
  173. save_dir (str|None): Directory to store prediction results. If None,
  174. use a default setting. Default: None.
  175. Returns:
  176. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  177. """
  178. raise NotImplementedError
  179. @abc.abstractmethod
  180. def export(self, weight_path, save_dir, **kwargs):
  181. """
  182. Export a pre-trained model.
  183. Args:
  184. weight_path (str): Path of the weights to initialize the model.
  185. save_dir (str): Directory to store the exported model.
  186. Returns:
  187. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  188. """
  189. raise NotImplementedError
  190. @abc.abstractmethod
  191. def infer(self,
  192. model_dir,
  193. input_path,
  194. device='gpu',
  195. save_dir=None,
  196. **kwargs):
  197. """
  198. Make inference with an exported inference model.
  199. Args:
  200. model_dir (str): Path of the exported inference model.
  201. input_path (str): Path of the input file, e.g. an image.
  202. device (str): A string that describes the device(s) to use, e.g.,
  203. 'cpu', 'gpu'. Default: 'gpu'.
  204. save_dir (str|None): Directory to store inference results. If None,
  205. use a default setting. Default: None.
  206. Returns:
  207. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  208. """
  209. raise NotImplementedError
  210. @abc.abstractmethod
  211. def compression(self,
  212. weight_path,
  213. batch_size=None,
  214. learning_rate=None,
  215. epochs_iters=None,
  216. device='gpu',
  217. use_vdl=True,
  218. save_dir=None,
  219. **kwargs):
  220. """
  221. Perform quantization aware training (QAT) and export the quantized
  222. model.
  223. Args:
  224. weight_path (str): Path of the weights to initialize the model.
  225. batch_size (int|None): Number of samples in each mini-batch. If
  226. multiple devices are used, this is the batch size on each
  227. device. If None, use a default setting. Default: None.
  228. learning_rate (float|None): Learning rate of QAT. If None, use a
  229. default setting. Default: None.
  230. epochs_iters (int|None): Total epochs of iterations of model
  231. training. If None, use a default setting. Default: None.
  232. device (str): A string that describes the device(s) to use, e.g.,
  233. 'cpu', 'gpu'. Default: 'gpu'.
  234. use_vdl (bool): Whether to enable VisualDL during training.
  235. Default: True.
  236. save_dir (str|None): Directory to store the results. If None, use a
  237. default setting. Default: None.
  238. Returns:
  239. tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
  240. """
  241. raise NotImplementedError
  242. @contextlib.contextmanager
  243. def _create_new_config_file(self):
  244. cls = self.__class__
  245. model_name = self.model_info['model_name']
  246. tag = '_'.join([cls.__name__.lower(), model_name])
  247. yaml_file_name = tag + '.yml'
  248. if not flags.DEBUG:
  249. with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
  250. path = os.path.join(td, yaml_file_name)
  251. with open(path, 'w', encoding='utf-8'):
  252. pass
  253. yield path
  254. else:
  255. path = os.path.join(get_cache_dir(), yaml_file_name)
  256. with open(path, 'w', encoding='utf-8'):
  257. pass
  258. yield path
  259. @cached_property
  260. def supported_apis(self):
  261. """ supported apis """
  262. return self.model_info.get('supported_apis', None)
  263. @cached_property
  264. def supported_train_opts(self):
  265. """ supported train opts """
  266. return self.model_info.get(
  267. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name='train'), None)
  268. @cached_property
  269. def supported_evaluate_opts(self):
  270. """ supported evaluate opts """
  271. return self.model_info.get(
  272. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name='evaluate'),
  273. None)
  274. @cached_property
  275. def supported_predict_opts(self):
  276. """ supported predcit opts """
  277. return self.model_info.get(
  278. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name='predict'),
  279. None)
  280. @cached_property
  281. def supported_infer_opts(self):
  282. """ supported infer opts """
  283. return self.model_info.get(
  284. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name='infer'), None)
  285. @cached_property
  286. def supported_compression_opts(self):
  287. """ supported copression opts """
  288. return self.model_info.get(
  289. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name='compression'),
  290. None)
  291. @cached_property
  292. def supported_dataset_types(self):
  293. """ supported dataset types """
  294. return self.model_info.get('supported_dataset_types', None)
  295. @staticmethod
  296. def _assert_empty_kwargs(kwargs):
  297. if len(kwargs) > 0:
  298. # For compatibility
  299. logging.warning(f"Unconsumed keyword arguments detected: {kwargs}.")
  300. # raise RuntimeError(
  301. # f"Unconsumed keyword arguments detected: {kwargs}.")
  302. def _patch_apis(self):
  303. def _make_unavailable(bnd_method):
  304. @functools.wraps(bnd_method)
  305. def _unavailable_api(*args, **kwargs):
  306. model_name = self.name
  307. api_name = bnd_method.__name__
  308. raise UnsupportedAPIError(
  309. f"{model_name} does not support `{api_name}`.")
  310. return _unavailable_api
  311. def _add_prechecks(bnd_method):
  312. @functools.wraps(bnd_method)
  313. def _api_with_prechecks(*args, **kwargs):
  314. sig = inspect.Signature.from_callable(bnd_method)
  315. bnd_args = sig.bind(*args, **kwargs)
  316. args_dict = bnd_args.arguments
  317. # Merge default values
  318. for p in sig.parameters.values():
  319. if p.name not in args_dict and p.default is not p.empty:
  320. args_dict[p.name] = p.default
  321. # Rely on nonlocal variable `checks`
  322. for check in checks:
  323. # We throw any unhandled exception
  324. check.check(args_dict)
  325. return bnd_method(*args, **kwargs)
  326. api_name = bnd_method.__name__
  327. checks = []
  328. # We hardcode the prechecks for each API here
  329. if api_name == 'train':
  330. opts = self.supported_train_opts
  331. if opts is not None:
  332. if 'device' in opts:
  333. checks.append(
  334. _CheckDevice(
  335. opts['device'],
  336. self.runner.parse_device,
  337. check_mc=True))
  338. if 'dy2st' in opts:
  339. checks.append(_CheckDy2St(opts['dy2st']))
  340. if 'amp' in opts:
  341. checks.append(_CheckAMP(opts['amp']))
  342. elif api_name == 'evaluate':
  343. opts = self.supported_evaluate_opts
  344. if opts is not None:
  345. if 'device' in opts:
  346. checks.append(
  347. _CheckDevice(
  348. opts['device'],
  349. self.runner.parse_device,
  350. check_mc=True))
  351. if 'amp' in opts:
  352. checks.append(_CheckAMP(opts['amp']))
  353. elif api_name == 'predict':
  354. opts = self.supported_predict_opts
  355. if opts is not None:
  356. if 'device' in opts:
  357. checks.append(
  358. _CheckDevice(
  359. opts['device'],
  360. self.runner.parse_device,
  361. check_mc=False))
  362. elif api_name == 'infer':
  363. opts = self.supported_infer_opts
  364. if opts is not None:
  365. if 'device' in opts:
  366. checks.append(
  367. _CheckDevice(
  368. opts['device'],
  369. self.runner.parse_device,
  370. check_mc=False))
  371. elif api_name == 'compression':
  372. opts = self.supported_compression_opts
  373. if opts is not None:
  374. if 'device' in opts:
  375. checks.append(
  376. _CheckDevice(
  377. opts['device'],
  378. self.runner.parse_device,
  379. check_mc=True))
  380. else:
  381. return bnd_method
  382. return _api_with_prechecks
  383. supported_apis = self.supported_apis
  384. if supported_apis is not None:
  385. avail_api_set = set(self.supported_apis)
  386. else:
  387. avail_api_set = set(self._API_FULL_LIST)
  388. for api_name in self._API_FULL_LIST:
  389. api = getattr(self, api_name)
  390. if api_name not in avail_api_set:
  391. # We decorate real API implementation with `_make_unavailable`
  392. # so that an error is always raised when the API is called.
  393. decorated_api = _make_unavailable(api)
  394. # Monkey-patch
  395. setattr(self, api_name, decorated_api)
  396. else:
  397. if flags.CHECK_OPTS:
  398. # We decorate real API implementation with `_add_prechecks`
  399. # to perform sanity and validity checks before invoking the
  400. # internal API.
  401. decorated_api = _add_prechecks(api)
  402. setattr(self, api_name, decorated_api)
  403. class _CheckFailed(Exception):
  404. """ _CheckFailed """
  405. # Allow `_CheckFailed` class to be recognized using `hasattr(exc, 'check_failed_error')`
  406. check_failed_error = True
  407. def __init__(self, arg_name, arg_val, legal_vals):
  408. self.arg_name = arg_name
  409. self.arg_val = arg_val
  410. self.legal_vals = legal_vals
  411. def __str__(self):
  412. return f"`{self.arg_name}` is expected to be one of or conforms to {self.legal_vals}, but got {self.arg_val}"
  413. class _APICallArgsChecker(object):
  414. """ _APICallArgsChecker """
  415. def __init__(self, legal_vals):
  416. super().__init__()
  417. self.legal_vals = legal_vals
  418. def check(self, args):
  419. """ check """
  420. raise NotImplementedError
  421. class _CheckDevice(_APICallArgsChecker):
  422. """ _CheckDevice """
  423. def __init__(self, legal_vals, parse_device, check_mc=False):
  424. super().__init__(legal_vals)
  425. self.parse_device = parse_device
  426. self.check_mc = check_mc
  427. def check(self, args):
  428. """ check """
  429. assert 'device' in args
  430. device = args['device']
  431. if device is not None:
  432. device_type, dev_ids = self.parse_device(device)
  433. if not self.check_mc:
  434. if device_type not in self.legal_vals:
  435. raise _CheckFailed('device', device, self.legal_vals)
  436. else:
  437. # Currently we only check multi-device settings for GPUs
  438. if device_type != 'gpu':
  439. if device_type not in self.legal_vals:
  440. raise _CheckFailed('device', device, self.legal_vals)
  441. else:
  442. n1c1_desc = f'{device_type}_n1c1'
  443. n1cx_desc = f'{device_type}_n1cx'
  444. nxcx_desc = f'{device_type}_nxcx'
  445. if len(dev_ids) <= 1:
  446. if (n1c1_desc not in self.legal_vals and
  447. n1cx_desc not in self.legal_vals and
  448. nxcx_desc not in self.legal_vals):
  449. raise _CheckFailed('device', device,
  450. self.legal_vals)
  451. else:
  452. assert 'ips' in args
  453. if args['ips'] is not None:
  454. # Multi-machine
  455. if nxcx_desc not in self.legal_vals:
  456. raise _CheckFailed('device', device,
  457. self.legal_vals)
  458. else:
  459. # Single-machine multi-device
  460. if (n1cx_desc not in self.legal_vals and
  461. nxcx_desc not in self.legal_vals):
  462. raise _CheckFailed('device', device,
  463. self.legal_vals)
  464. else:
  465. # When `device` is None, we assume that a default device that the
  466. # current model supports will be used, so we simply do nothing.
  467. pass
  468. class _CheckDy2St(_APICallArgsChecker):
  469. """ _CheckDy2St """
  470. def check(self, args):
  471. """ check """
  472. assert 'dy2st' in args
  473. dy2st = args['dy2st']
  474. if isinstance(self.legal_vals, list):
  475. assert len(self.legal_vals) == 1
  476. support_dy2st = bool(self.legal_vals[0])
  477. else:
  478. support_dy2st = bool(self.legal_vals)
  479. if dy2st is not None:
  480. if dy2st and not support_dy2st:
  481. raise _CheckFailed('dy2st', dy2st, [support_dy2st])
  482. else:
  483. pass
  484. class _CheckAMP(_APICallArgsChecker):
  485. """ _CheckAMP """
  486. def check(self, args):
  487. """ check """
  488. assert 'amp' in args
  489. amp = args['amp']
  490. if amp is not None:
  491. if amp != 'OFF' and amp not in self.legal_vals:
  492. raise _CheckFailed('amp', amp, self.legal_vals)
  493. else:
  494. pass