# model.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import contextlib
  16. import functools
  17. import inspect
  18. import os
  19. import tempfile
  20. from ...utils import flags, logging
  21. from ...utils.cache import get_cache_dir
  22. from ...utils.device import parse_device
  23. from ...utils.errors import UnsupportedAPIError, UnsupportedParamError
  24. from ...utils.misc import CachedProperty as cached_property
  25. from .config import Config
  26. from .register import (
  27. build_model_from_model_info,
  28. build_runner_from_model_info,
  29. get_registered_model_info,
  30. )
  31. __all__ = ["PaddleModel", "BaseModel"]
  32. def _create_model(model_name=None, config=None):
  33. """_create_model"""
  34. if model_name is None and config is None:
  35. raise ValueError("At least one of `model_name` and `config` must be not None.")
  36. elif model_name is not None and config is not None:
  37. if model_name != config.model_name:
  38. raise ValueError(
  39. "If both `model_name` and `config` are not None, `model_name` should be the same as \
  40. `config.model_name`."
  41. )
  42. elif model_name is None and config is not None:
  43. model_name = config.model_name
  44. try:
  45. model_info = get_registered_model_info(model_name)
  46. except KeyError as e:
  47. raise UnsupportedParamError(
  48. f"{repr(model_name)} is not a registered model name."
  49. ) from e
  50. return build_model_from_model_info(model_info=model_info, config=config)
  51. PaddleModel = _create_model
class BaseModel(metaclass=abc.ABCMeta):
    """
    Abstract base class of Model.

    Model defines how Config and Runner interact with each other. In addition,
    Model provides users with multiple APIs to perform model training,
    prediction, etc.
    """

    # All APIs a model may expose. Those not listed in the model info's
    # `supported_apis` are monkey-patched in `_patch_apis` to raise
    # `UnsupportedAPIError` when called.
    _API_FULL_LIST = ("train", "evaluate", "predict", "export", "infer", "compression")
    # Format string for the `model_info` key holding per-API supported options.
    _API_SUPPORTED_OPTS_KEY_PATTERN = "supported_{api_name}_opts"

    def __init__(self, model_name, config=None):
        """
        Initialize the instance.

        Args:
            model_name (str): A registered model name.
            config (base.config.BaseConfig|None): Config. Default: None.
        """
        super().__init__()
        self.name = model_name
        self.model_info = get_registered_model_info(model_name)
        # NOTE: We build runner instance here by extracting runner info from model info
        # so that we don't have to overwrite the `__init__` method of each child class.
        self.runner = build_runner_from_model_info(self.model_info)
        if config is None:
            logging.warning(
                "We strongly discourage leaving `config` unset or setting it to None. "
                "Please note that when `config` is None, default settings will be used for every unspecified \
configuration item, "
                "which may lead to unexpected result. Please make sure that this is what you intend to do."
            )
            config = Config(model_name)
        self.config = config
        # Replace unsupported APIs with error-raising stubs and, when
        # `flags.CHECK_OPTS` is set, wrap supported APIs with prechecks.
        self._patch_apis()

    @abc.abstractmethod
    def train(
        self,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        ips=None,
        device="gpu",
        resume_path=None,
        dy2st=False,
        amp="OFF",
        num_workers=None,
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """
        Train a model.

        Args:
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of model training. If
                None, use a default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine training mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            resume_path (str|None): If not None, resume training from the model
                snapshot corresponding to the weight file `resume_path`. If
                None, use a default setting. Default: None.
            dy2st (bool): Whether to enable dynamic-to-static training.
                Default: False.
            amp (str): Optimization level to use in AMP training. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store model snapshots and logs. If
                None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(
        self,
        weight_path,
        batch_size=None,
        ips=None,
        device="gpu",
        amp="OFF",
        num_workers=None,
        **kwargs,
    ):
        """
        Evaluate a model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine evaluation mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            amp (str): Optimization level to use in AMP training. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
        """
        Make prediction with a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store prediction results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def export(self, weight_path, save_dir, **kwargs):
        """
        Export a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            save_dir (str): Directory to store the exported model.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def infer(self, model_dir, input_path, device="gpu", save_dir=None, **kwargs):
        """
        Make inference with an exported inference model.

        Args:
            model_dir (str): Path of the exported inference model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store inference results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compression(
        self,
        weight_path,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        device="gpu",
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """
        Perform quantization aware training (QAT) and export the quantized
        model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each
                device. If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of QAT. If None, use a
                default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store the results. If None, use a
                default setting. Default: None.

        Returns:
            tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
        """
        raise NotImplementedError

    @contextlib.contextmanager
    def _create_new_config_file(self):
        # Yield the path of a fresh, empty YAML config file named after the
        # class and model. In normal mode the file lives inside a temporary
        # directory that is removed when the context exits; in DEBUG mode it is
        # placed directly in the cache directory and kept for inspection.
        cls = self.__class__
        model_name = self.model_info["model_name"]
        tag = "_".join([cls.__name__.lower(), model_name])
        yaml_file_name = tag + ".yml"
        if not flags.DEBUG:
            with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
                path = os.path.join(td, yaml_file_name)
                # Touch the file so downstream code can open it.
                with open(path, "w", encoding="utf-8"):
                    pass
                yield path
        else:
            path = os.path.join(get_cache_dir(), yaml_file_name)
            with open(path, "w", encoding="utf-8"):
                pass
            yield path

    @contextlib.contextmanager
    def _create_new_val_json_file(self):
        # Same as `_create_new_config_file`, but for a `<tag>_test.json` file.
        cls = self.__class__
        model_name = self.model_info["model_name"]
        tag = "_".join([cls.__name__.lower(), model_name])
        json_file_name = tag + "_test.json"
        if not flags.DEBUG:
            with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
                path = os.path.join(td, json_file_name)
                with open(path, "w", encoding="utf-8"):
                    pass
                yield path
        else:
            path = os.path.join(get_cache_dir(), json_file_name)
            with open(path, "w", encoding="utf-8"):
                pass
            yield path

    @cached_property
    def supported_apis(self):
        """Supported APIs, or None when the model info does not restrict them."""
        return self.model_info.get("supported_apis", None)

    @cached_property
    def supported_train_opts(self):
        """Supported train options, or None when unspecified."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="train"), None
        )

    @cached_property
    def supported_evaluate_opts(self):
        """Supported evaluate options, or None when unspecified."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="evaluate"), None
        )

    @cached_property
    def supported_predict_opts(self):
        """Supported predict options, or None when unspecified."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="predict"), None
        )

    @cached_property
    def supported_infer_opts(self):
        """Supported infer options, or None when unspecified."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="infer"), None
        )

    @cached_property
    def supported_compression_opts(self):
        """Supported compression options, or None when unspecified."""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="compression"), None
        )

    @cached_property
    def supported_dataset_types(self):
        """Supported dataset types, or None when unspecified."""
        return self.model_info.get("supported_dataset_types", None)

    @staticmethod
    def _assert_empty_kwargs(kwargs):
        # Warn (rather than fail) on unconsumed keyword arguments.
        if len(kwargs) > 0:
            # For compatibility
            logging.warning(f"Unconsumed keyword arguments detected: {kwargs}.")
            # raise RuntimeError(
            #     f"Unconsumed keyword arguments detected: {kwargs}.")

    def _patch_apis(self):
        # Monkey-patch the instance's API methods:
        #   * APIs absent from `supported_apis` are replaced by stubs that
        #     raise `UnsupportedAPIError`;
        #   * when `flags.CHECK_OPTS` is set, supported APIs are wrapped so
        #     that their call arguments are validated before invocation.
        def _make_unavailable(bnd_method):
            # Wrap `bnd_method` so any call reports the API as unsupported.
            @functools.wraps(bnd_method)
            def _unavailable_api(*args, **kwargs):
                model_name = self.name
                api_name = bnd_method.__name__
                raise UnsupportedAPIError(
                    f"{model_name} does not support `{api_name}`."
                )

            return _unavailable_api

        def _add_prechecks(bnd_method):
            # Wrap `bnd_method` so the checks collected below run on the
            # bound call arguments (defaults included) before the real call.
            @functools.wraps(bnd_method)
            def _api_with_prechecks(*args, **kwargs):
                sig = inspect.Signature.from_callable(bnd_method)
                bnd_args = sig.bind(*args, **kwargs)
                args_dict = bnd_args.arguments
                # Merge default values
                for p in sig.parameters.values():
                    if p.name not in args_dict and p.default is not p.empty:
                        args_dict[p.name] = p.default
                # Rely on nonlocal variable `checks`
                for check in checks:
                    # We throw any unhandled exception
                    check.check(args_dict)
                return bnd_method(*args, **kwargs)

            api_name = bnd_method.__name__
            checks = []
            # We hardcode the prechecks for each API here
            if api_name == "train":
                opts = self.supported_train_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(_CheckDevice(opts["device"], check_mc=True))
                    if "dy2st" in opts:
                        checks.append(_CheckDy2St(opts["dy2st"]))
                    if "amp" in opts:
                        checks.append(_CheckAMP(opts["amp"]))
            elif api_name == "evaluate":
                opts = self.supported_evaluate_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(_CheckDevice(opts["device"], check_mc=True))
                    if "amp" in opts:
                        checks.append(_CheckAMP(opts["amp"]))
            elif api_name == "predict":
                opts = self.supported_predict_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(_CheckDevice(opts["device"], check_mc=False))
            elif api_name == "infer":
                opts = self.supported_infer_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(_CheckDevice(opts["device"], check_mc=False))
            elif api_name == "compression":
                opts = self.supported_compression_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(_CheckDevice(opts["device"], check_mc=True))
            else:
                # APIs without precheck rules (e.g. `export`) are left as-is.
                return bnd_method
            return _api_with_prechecks

        supported_apis = self.supported_apis
        if supported_apis is not None:
            avail_api_set = set(self.supported_apis)
        else:
            # No restriction declared: treat every API as available.
            avail_api_set = set(self._API_FULL_LIST)
        for api_name in self._API_FULL_LIST:
            api = getattr(self, api_name)
            if api_name not in avail_api_set:
                # We decorate real API implementation with `_make_unavailable`
                # so that an error is always raised when the API is called.
                decorated_api = _make_unavailable(api)
                # Monkey-patch
                setattr(self, api_name, decorated_api)
            else:
                if flags.CHECK_OPTS:
                    # We decorate real API implementation with `_add_prechecks`
                    # to perform sanity and validity checks before invoking the
                    # internal API.
                    decorated_api = _add_prechecks(api)
                    setattr(self, api_name, decorated_api)
  399. class _CheckFailed(Exception):
  400. """_CheckFailed"""
  401. # Allow `_CheckFailed` class to be recognized using `hasattr(exc, 'check_failed_error')`
  402. check_failed_error = True
  403. def __init__(self, arg_name, arg_val, legal_vals):
  404. self.arg_name = arg_name
  405. self.arg_val = arg_val
  406. self.legal_vals = legal_vals
  407. def __str__(self):
  408. return f"`{self.arg_name}` is expected to be one of or conforms to {self.legal_vals}, but got {self.arg_val}"
  409. class _APICallArgsChecker(object):
  410. """_APICallArgsChecker"""
  411. def __init__(self, legal_vals):
  412. super().__init__()
  413. self.legal_vals = legal_vals
  414. def check(self, args):
  415. """check"""
  416. raise NotImplementedError
class _CheckDevice(_APICallArgsChecker):
    """Validator for the `device` argument of an API call."""

    def __init__(self, legal_vals, check_mc=False):
        super().__init__(legal_vals)
        # Whether to also validate multi-card / multi-machine setups against
        # descriptors such as '<type>_n1c1', '<type>_n1cx', '<type>_nxcx'.
        self.check_mc = check_mc

    def check(self, args):
        """Raise `_CheckFailed` if `args['device']` is not supported."""
        assert "device" in args
        device = args["device"]
        if device is not None:
            device_type, dev_ids = parse_device(device)
            if not self.check_mc:
                # Only the device type matters.
                if device_type not in self.legal_vals:
                    raise _CheckFailed("device", device, self.legal_vals)
            else:
                # Currently we only check multi-device settings for GPUs
                if device_type != "gpu":
                    if device_type not in self.legal_vals:
                        raise _CheckFailed("device", device, self.legal_vals)
                else:
                    # Descriptors for single-card, single-machine multi-card,
                    # and multi-machine multi-card configurations.
                    n1c1_desc = f"{device_type}_n1c1"
                    n1cx_desc = f"{device_type}_n1cx"
                    nxcx_desc = f"{device_type}_nxcx"
                    if len(dev_ids) <= 1:
                        # A single card is acceptable if any configuration
                        # (it can be seen as a degenerate case) is supported.
                        if (
                            n1c1_desc not in self.legal_vals
                            and n1cx_desc not in self.legal_vals
                            and nxcx_desc not in self.legal_vals
                        ):
                            raise _CheckFailed("device", device, self.legal_vals)
                    else:
                        assert "ips" in args
                        if args["ips"] is not None:
                            # Multi-machine
                            if nxcx_desc not in self.legal_vals:
                                raise _CheckFailed("device", device, self.legal_vals)
                        else:
                            # Single-machine multi-device
                            if (
                                n1cx_desc not in self.legal_vals
                                and nxcx_desc not in self.legal_vals
                            ):
                                raise _CheckFailed("device", device, self.legal_vals)
        else:
            # When `device` is None, we assume that a default device that the
            # current model supports will be used, so we simply do nothing.
            pass
  464. class _CheckDy2St(_APICallArgsChecker):
  465. """_CheckDy2St"""
  466. def check(self, args):
  467. """check"""
  468. assert "dy2st" in args
  469. dy2st = args["dy2st"]
  470. if isinstance(self.legal_vals, list):
  471. assert len(self.legal_vals) == 1
  472. support_dy2st = bool(self.legal_vals[0])
  473. else:
  474. support_dy2st = bool(self.legal_vals)
  475. if dy2st is not None:
  476. if dy2st and not support_dy2st:
  477. raise _CheckFailed("dy2st", dy2st, [support_dy2st])
  478. else:
  479. pass
  480. class _CheckAMP(_APICallArgsChecker):
  481. """_CheckAMP"""
  482. def check(self, args):
  483. """check"""
  484. assert "amp" in args
  485. amp = args["amp"]
  486. if amp is not None:
  487. if amp != "OFF" and amp not in self.legal_vals:
  488. raise _CheckFailed("amp", amp, self.legal_vals)
  489. else:
  490. pass