# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import abc
import inspect
import functools
import contextlib
import tempfile
import hashlib
import base64
from datetime import datetime, timedelta

from .config import Config
from .register import (
    get_registered_model_info,
    build_runner_from_model_info,
    build_model_from_model_info,
)
from ...utils import flags
from ...utils import logging
from ...utils.errors import (
    UnsupportedAPIError,
    UnsupportedParamError,
    raise_unsupported_api_error,
)
from ...utils.misc import CachedProperty as cached_property
from ...utils.cache import get_cache_dir

__all__ = ["PaddleModel", "BaseModel"]


def _create_model(model_name=None, config=None):
    """Create a model instance from a registered model name and/or a config."""
    if model_name is None and config is None:
        raise ValueError("At least one of `model_name` and `config` must not be None.")
    elif model_name is not None and config is not None:
        if model_name != config.model_name:
            raise ValueError(
                "If both `model_name` and `config` are not None, `model_name` "
                "should be the same as `config.model_name`."
            )
    elif model_name is None and config is not None:
        model_name = config.model_name
    try:
        model_info = get_registered_model_info(model_name)
    except KeyError as e:
        raise UnsupportedParamError(
            f"{repr(model_name)} is not a registered model name."
        ) from e
    return build_model_from_model_info(model_info=model_info, config=config)


PaddleModel = _create_model
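

# Illustrative usage only (not executed at import time): the model name below is a
# placeholder, and the import paths are an assumption about the package layout;
# substitute whatever registered name and layout your installation provides.
#
#     from paddlex.repo_apis.base.model import PaddleModel
#     from paddlex.repo_apis.base.config import Config
#
#     config = Config("SomeRegisteredModel")      # optional; silences the default-config warning
#     model = PaddleModel("SomeRegisteredModel", config=config)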


class BaseModel(metaclass=abc.ABCMeta):
    """
    Abstract base class of Model.

    Model defines how Config and Runner interact with each other. In addition,
    Model provides users with multiple APIs to perform model training,
    prediction, etc.
    """

    _API_FULL_LIST = ("train", "evaluate", "predict", "export", "infer", "compression")
    _API_SUPPORTED_OPTS_KEY_PATTERN = "supported_{api_name}_opts"

    def __init__(self, model_name, config=None):
        """
        Initialize the instance.

        Args:
            model_name (str): A registered model name.
            config (base.config.BaseConfig|None): Config. Default: None.
        """
        super().__init__()
        self.name = model_name
        self.model_info = get_registered_model_info(model_name)
        # NOTE: We build the runner instance here by extracting runner info from
        # model info, so that we don't have to overwrite the `__init__` method of
        # each child class.
        self.runner = build_runner_from_model_info(self.model_info)
        if config is None:
            logging.warning(
                "We strongly discourage leaving `config` unset or setting it to None. "
                "Please note that when `config` is None, default settings will be used "
                "for every unspecified configuration item, which may lead to unexpected "
                "results. Please make sure that this is what you intend to do."
            )
            config = Config(model_name)
        self.config = config
        self._patch_apis()

    @abc.abstractmethod
    def train(
        self,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        ips=None,
        device="gpu",
        resume_path=None,
        dy2st=False,
        amp="OFF",
        num_workers=None,
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """
        Train a model.

        Args:
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of model training. If
                None, use a default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine training mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            resume_path (str|None): If not None, resume training from the model
                snapshot corresponding to the weight file `resume_path`. If
                None, use a default setting. Default: None.
            dy2st (bool): Whether to enable dynamic-to-static training.
                Default: False.
            amp (str): Optimization level to use in AMP training. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store model snapshots and logs. If
                None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError
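
    # Illustrative usage only (model name, paths, and hyperparameters are placeholders):
    # a typical training call on a concrete subclass obtained through `PaddleModel`.
    # The returned object is documented as a CompletedProcess; the `returncode` check
    # below assumes it mirrors `subprocess.CompletedProcess`.
    #
    #     model = PaddleModel("SomeRegisteredModel")
    #     cp = model.train(
    #         batch_size=16,
    #         learning_rate=0.01,
    #         epochs_iters=20,
    #         device="gpu:0,1",
    #         save_dir="output/train",
    #     )
    #     assert cp.returncode == 0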

    @abc.abstractmethod
    def evaluate(
        self,
        weight_path,
        batch_size=None,
        ips=None,
        device="gpu",
        amp="OFF",
        num_workers=None,
        **kwargs,
    ):
        """
        Evaluate a model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each device.
                If None, use a default setting. Default: None.
            ips (str|None): If not None, enable multi-machine evaluation mode.
                `ips` specifies Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
            amp (str): Optimization level to use in AMP evaluation. Choices are
                ['O1', 'O2', 'OFF']. Default: 'OFF'.
            num_workers (int|None): Number of subprocesses to use for data
                loading. If None, use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError
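
    # Illustrative usage only (paths are placeholders): evaluation reuses the trained
    # weights produced by `train`.
    #
    #     cp = model.evaluate(
    #         weight_path="output/train/best_model.pdparams",
    #         batch_size=16,
    #         device="gpu:0",
    #     )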

    @abc.abstractmethod
    def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
        """
        Make prediction with a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store prediction results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError
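
    # Illustrative usage only (paths are placeholders): single-file prediction with
    # trained weights.
    #
    #     model.predict(
    #         weight_path="output/train/best_model.pdparams",
    #         input_path="demo.jpg",
    #         device="gpu",
    #         save_dir="output/predict",
    #     )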

    @abc.abstractmethod
    def export(self, weight_path, save_dir, **kwargs):
        """
        Export a pre-trained model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            save_dir (str): Directory to store the exported model.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def infer(self, model_dir, input_path, device="gpu", save_dir=None, **kwargs):
        """
        Make inference with an exported inference model.

        Args:
            model_dir (str): Path of the exported inference model.
            input_path (str): Path of the input file, e.g. an image.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            save_dir (str|None): Directory to store inference results. If None,
                use a default setting. Default: None.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError
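
    # Illustrative usage only (paths are placeholders): the usual export-then-infer
    # workflow, where `export` writes an inference model that `infer` then consumes.
    #
    #     model.export(
    #         weight_path="output/train/best_model.pdparams",
    #         save_dir="output/inference",
    #     )
    #     model.infer(
    #         model_dir="output/inference",
    #         input_path="demo.jpg",
    #         device="gpu",
    #         save_dir="output/infer",
    #     )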

    @abc.abstractmethod
    def compression(
        self,
        weight_path,
        batch_size=None,
        learning_rate=None,
        epochs_iters=None,
        device="gpu",
        use_vdl=True,
        save_dir=None,
        **kwargs,
    ):
        """
        Perform quantization aware training (QAT) and export the quantized
        model.

        Args:
            weight_path (str): Path of the weights to initialize the model.
            batch_size (int|None): Number of samples in each mini-batch. If
                multiple devices are used, this is the batch size on each
                device. If None, use a default setting. Default: None.
            learning_rate (float|None): Learning rate of QAT. If None, use a
                default setting. Default: None.
            epochs_iters (int|None): Total epochs or iterations of model
                training. If None, use a default setting. Default: None.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'gpu'. Default: 'gpu'.
            use_vdl (bool): Whether to enable VisualDL during training.
                Default: True.
            save_dir (str|None): Directory to store the results. If None, use a
                default setting. Default: None.

        Returns:
            tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
        """
        raise NotImplementedError
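
    # Illustrative usage only (paths are placeholders): quantization-aware training
    # starting from trained weights. The return value is documented as a tuple of
    # CompletedProcess objects (assumed to cover the QAT run and the export step).
    #
    #     cps = model.compression(
    #         weight_path="output/train/best_model.pdparams",
    #         epochs_iters=5,
    #         device="gpu:0",
    #         save_dir="output/compression",
    #     )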

    @contextlib.contextmanager
    def _create_new_config_file(self):
        cls = self.__class__
        model_name = self.model_info["model_name"]
        tag = "_".join([cls.__name__.lower(), model_name])
        yaml_file_name = tag + ".yml"
        if not flags.DEBUG:
            with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
                path = os.path.join(td, yaml_file_name)
                with open(path, "w", encoding="utf-8"):
                    pass
                yield path
        else:
            path = os.path.join(get_cache_dir(), yaml_file_name)
            with open(path, "w", encoding="utf-8"):
                pass
            yield path
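
    # Illustrative usage only, intended for subclass implementations: yield a scratch
    # config file, dump the effective config into it, and hand the path to the runner.
    # `self.config.dump(...)` is an assumed Config method; substitute whatever
    # serialization API the concrete Config class provides.
    #
    #     with self._create_new_config_file() as config_path:
    #         self.config.dump(config_path)
    #         # ... pass `config_path` to the runner command ...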

    @cached_property
    def supported_apis(self):
        """supported apis"""
        return self.model_info.get("supported_apis", None)

    @cached_property
    def supported_train_opts(self):
        """supported train opts"""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="train"), None
        )

    @cached_property
    def supported_evaluate_opts(self):
        """supported evaluate opts"""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="evaluate"), None
        )

    @cached_property
    def supported_predict_opts(self):
        """supported predict opts"""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="predict"), None
        )

    @cached_property
    def supported_infer_opts(self):
        """supported infer opts"""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="infer"), None
        )

    @cached_property
    def supported_compression_opts(self):
        """supported compression opts"""
        return self.model_info.get(
            self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="compression"), None
        )

    @cached_property
    def supported_dataset_types(self):
        """supported dataset types"""
        return self.model_info.get("supported_dataset_types", None)
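
    # Illustrative only (the exact structure is model-specific): the `supported_*`
    # properties simply surface entries from the registered model info, e.g. something
    # roughly like
    #
    #     model.supported_train_opts
    #     # -> {"device": ["cpu", "gpu_n1c1", "gpu_n1cx"], "dy2st": False, "amp": ["OFF"]}
    #
    # and they return None when the registry does not constrain that API.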

    @staticmethod
    def _assert_empty_kwargs(kwargs):
        if len(kwargs) > 0:
            # For compatibility
            logging.warning(f"Unconsumed keyword arguments detected: {kwargs}.")
            # raise RuntimeError(
            #     f"Unconsumed keyword arguments detected: {kwargs}.")

    def _patch_apis(self):
        def _make_unavailable(bnd_method):
            @functools.wraps(bnd_method)
            def _unavailable_api(*args, **kwargs):
                model_name = self.name
                api_name = bnd_method.__name__
                raise UnsupportedAPIError(
                    f"{model_name} does not support `{api_name}`."
                )

            return _unavailable_api

        def _add_prechecks(bnd_method):
            @functools.wraps(bnd_method)
            def _api_with_prechecks(*args, **kwargs):
                sig = inspect.Signature.from_callable(bnd_method)
                bnd_args = sig.bind(*args, **kwargs)
                args_dict = bnd_args.arguments
                # Merge default values
                for p in sig.parameters.values():
                    if p.name not in args_dict and p.default is not p.empty:
                        args_dict[p.name] = p.default
                # Rely on nonlocal variable `checks`
                for check in checks:
                    # We throw any unhandled exception
                    check.check(args_dict)
                return bnd_method(*args, **kwargs)

            api_name = bnd_method.__name__
            checks = []
            # We hardcode the prechecks for each API here
            if api_name == "train":
                opts = self.supported_train_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(
                            _CheckDevice(
                                opts["device"], self.runner.parse_device, check_mc=True
                            )
                        )
                    if "dy2st" in opts:
                        checks.append(_CheckDy2St(opts["dy2st"]))
                    if "amp" in opts:
                        checks.append(_CheckAMP(opts["amp"]))
            elif api_name == "evaluate":
                opts = self.supported_evaluate_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(
                            _CheckDevice(
                                opts["device"], self.runner.parse_device, check_mc=True
                            )
                        )
                    if "amp" in opts:
                        checks.append(_CheckAMP(opts["amp"]))
            elif api_name == "predict":
                opts = self.supported_predict_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(
                            _CheckDevice(
                                opts["device"], self.runner.parse_device, check_mc=False
                            )
                        )
            elif api_name == "infer":
                opts = self.supported_infer_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(
                            _CheckDevice(
                                opts["device"], self.runner.parse_device, check_mc=False
                            )
                        )
            elif api_name == "compression":
                opts = self.supported_compression_opts
                if opts is not None:
                    if "device" in opts:
                        checks.append(
                            _CheckDevice(
                                opts["device"], self.runner.parse_device, check_mc=True
                            )
                        )
            else:
                return bnd_method

            return _api_with_prechecks

        supported_apis = self.supported_apis
        if supported_apis is not None:
            avail_api_set = set(self.supported_apis)
        else:
            avail_api_set = set(self._API_FULL_LIST)
        for api_name in self._API_FULL_LIST:
            api = getattr(self, api_name)
            if api_name not in avail_api_set:
                # We decorate real API implementation with `_make_unavailable`
                # so that an error is always raised when the API is called.
                decorated_api = _make_unavailable(api)
                # Monkey-patch
                setattr(self, api_name, decorated_api)
            else:
                if flags.CHECK_OPTS:
                    # We decorate real API implementation with `_add_prechecks`
                    # to perform sanity and validity checks before invoking the
                    # internal API.
                    decorated_api = _add_prechecks(api)
                    setattr(self, api_name, decorated_api)
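

# Illustrative only ("SomeRegisteredModel" is a placeholder): `_patch_apis` rebinds
# unavailable APIs on the instance, so calling one of them fails fast instead of
# reaching the runner; with the CHECK_OPTS flag enabled, the remaining APIs are
# wrapped with prechecks that validate arguments such as `device`, `dy2st`, and
# `amp` against the registered `supported_*_opts`.
#
#     model = PaddleModel("SomeRegisteredModel")   # suppose "compression" is not in its supported APIs
#     try:
#         model.compression("weights.pdparams")
#     except UnsupportedAPIError:
#         pass                                     # raised by the patched bound method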


class _CheckFailed(Exception):
    """_CheckFailed"""

    # Allow `_CheckFailed` instances to be recognized via `hasattr(exc, 'check_failed_error')`
    check_failed_error = True

    def __init__(self, arg_name, arg_val, legal_vals):
        self.arg_name = arg_name
        self.arg_val = arg_val
        self.legal_vals = legal_vals

    def __str__(self):
        return (
            f"`{self.arg_name}` is expected to be one of, or conform to, "
            f"{self.legal_vals}, but got {self.arg_val}"
        )


class _APICallArgsChecker(object):
    """_APICallArgsChecker"""

    def __init__(self, legal_vals):
        super().__init__()
        self.legal_vals = legal_vals

    def check(self, args):
        """check"""
        raise NotImplementedError


class _CheckDevice(_APICallArgsChecker):
    """_CheckDevice"""

    def __init__(self, legal_vals, parse_device, check_mc=False):
        super().__init__(legal_vals)
        self.parse_device = parse_device
        self.check_mc = check_mc

    def check(self, args):
        """check"""
        assert "device" in args
        device = args["device"]
        if device is not None:
            device_type, dev_ids = self.parse_device(device)
            if not self.check_mc:
                if device_type not in self.legal_vals:
                    raise _CheckFailed("device", device, self.legal_vals)
            else:
                # Currently we only check multi-device settings for GPUs
                if device_type != "gpu":
                    if device_type not in self.legal_vals:
                        raise _CheckFailed("device", device, self.legal_vals)
                else:
                    n1c1_desc = f"{device_type}_n1c1"
                    n1cx_desc = f"{device_type}_n1cx"
                    nxcx_desc = f"{device_type}_nxcx"
                    if len(dev_ids) <= 1:
                        if (
                            n1c1_desc not in self.legal_vals
                            and n1cx_desc not in self.legal_vals
                            and nxcx_desc not in self.legal_vals
                        ):
                            raise _CheckFailed("device", device, self.legal_vals)
                    else:
                        assert "ips" in args
                        if args["ips"] is not None:
                            # Multi-machine
                            if nxcx_desc not in self.legal_vals:
                                raise _CheckFailed("device", device, self.legal_vals)
                        else:
                            # Single-machine multi-device
                            if (
                                n1cx_desc not in self.legal_vals
                                and nxcx_desc not in self.legal_vals
                            ):
                                raise _CheckFailed("device", device, self.legal_vals)
        else:
            # When `device` is None, we assume that a default device that the
            # current model supports will be used, so we simply do nothing.
            pass
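

# Illustrative only (the legal values below are made up): with `check_mc=True`, GPU
# devices are validated against "<type>_n1c1" / "<type>_n1cx" / "<type>_nxcx"
# descriptors, where `parse_device` is assumed to return `(device_type, device_ids)`.
#
#     checker = _CheckDevice(["cpu", "gpu_n1c1", "gpu_n1cx"], runner.parse_device, check_mc=True)
#     checker.check({"device": "gpu:0,1", "ips": None})                  # OK: single machine, multiple cards
#     checker.check({"device": "gpu:0,1", "ips": "10.0.0.1,10.0.0.2"})   # raises _CheckFailed: needs "gpu_nxcx"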


class _CheckDy2St(_APICallArgsChecker):
    """_CheckDy2St"""

    def check(self, args):
        """check"""
        assert "dy2st" in args
        dy2st = args["dy2st"]
        if isinstance(self.legal_vals, list):
            assert len(self.legal_vals) == 1
            support_dy2st = bool(self.legal_vals[0])
        else:
            support_dy2st = bool(self.legal_vals)
        if dy2st is not None:
            if dy2st and not support_dy2st:
                raise _CheckFailed("dy2st", dy2st, [support_dy2st])
        else:
            pass


class _CheckAMP(_APICallArgsChecker):
    """_CheckAMP"""

    def check(self, args):
        """check"""
        assert "amp" in args
        amp = args["amp"]
        if amp is not None:
            if amp != "OFF" and amp not in self.legal_vals:
                raise _CheckFailed("amp", amp, self.legal_vals)
        else:
            pass