model.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import abc
  16. import inspect
  17. import functools
  18. import contextlib
  19. import tempfile
  20. import hashlib
  21. import base64
  22. from datetime import datetime, timedelta
  23. from .config import Config
  24. from .register import (
  25. get_registered_model_info,
  26. build_runner_from_model_info,
  27. build_model_from_model_info,
  28. )
  29. from ...utils import flags
  30. from ...utils import logging
  31. from ...utils.errors import (
  32. UnsupportedAPIError,
  33. UnsupportedParamError,
  34. raise_unsupported_api_error,
  35. )
  36. from ...utils.device import parse_device
  37. from ...utils.misc import CachedProperty as cached_property
  38. from ...utils.cache import get_cache_dir
# Public names exported by `from <this module> import *`: the model factory
# alias and the abstract model base class.
__all__ = ["PaddleModel", "BaseModel"]
  40. def _create_model(model_name=None, config=None):
  41. """_create_model"""
  42. if model_name is None and config is None:
  43. raise ValueError("At least one of `model_name` and `config` must be not None.")
  44. elif model_name is not None and config is not None:
  45. if model_name != config.model_name:
  46. raise ValueError(
  47. "If both `model_name` and `config` are not None, `model_name` should be the same as \
  48. `config.model_name`."
  49. )
  50. elif model_name is None and config is not None:
  51. model_name = config.model_name
  52. try:
  53. model_info = get_registered_model_info(model_name)
  54. except KeyError as e:
  55. raise UnsupportedParamError(
  56. f"{repr(model_name)} is not a registered model name."
  57. ) from e
  58. return build_model_from_model_info(model_info=model_info, config=config)
  59. PaddleModel = _create_model
  60. class BaseModel(metaclass=abc.ABCMeta):
  61. """
  62. Abstract base class of Model.
  63. Model defines how Config and Runner interact with each other. In addition,
  64. Model provides users with multiple APIs to perform model training,
  65. prediction, etc.
  66. """
  67. _API_FULL_LIST = ("train", "evaluate", "predict", "export", "infer", "compression")
  68. _API_SUPPORTED_OPTS_KEY_PATTERN = "supported_{api_name}_opts"
  69. def __init__(self, model_name, config=None):
  70. """
  71. Initialize the instance.
  72. Args:
  73. model_name (str): A registered model name.
  74. config (base.config.BaseConfig|None): Config. Default: None.
  75. """
  76. super().__init__()
  77. self.name = model_name
  78. self.model_info = get_registered_model_info(model_name)
  79. # NOTE: We build runner instance here by extracting runner info from model info
  80. # so that we don't have to overwrite the `__init__` method of each child class.
  81. self.runner = build_runner_from_model_info(self.model_info)
  82. if config is None:
  83. logging.warning(
  84. "We strongly discourage leaving `config` unset or setting it to None. "
  85. "Please note that when `config` is None, default settings will be used for every unspecified \
  86. configuration item, "
  87. "which may lead to unexpected result. Please make sure that this is what you intend to do."
  88. )
  89. config = Config(model_name)
  90. self.config = config
  91. self._patch_apis()
  92. @abc.abstractmethod
  93. def train(
  94. self,
  95. batch_size=None,
  96. learning_rate=None,
  97. epochs_iters=None,
  98. ips=None,
  99. device="gpu",
  100. resume_path=None,
  101. dy2st=False,
  102. amp="OFF",
  103. num_workers=None,
  104. use_vdl=True,
  105. save_dir=None,
  106. **kwargs,
  107. ):
  108. """
  109. Train a model.
  110. Args:
  111. batch_size (int|None): Number of samples in each mini-batch. If
  112. multiple devices are used, this is the batch size on each device.
  113. If None, use a default setting. Default: None.
  114. learning_rate (float|None): Learning rate of model training. If
  115. None, use a default setting. Default: None.
  116. epochs_iters (int|None): Total epochs or iterations of model
  117. training. If None, use a default setting. Default: None.
  118. ips (str|None): If not None, enable multi-machine training mode.
  119. `ips` specifies Paddle cluster node ips, e.g.,
  120. '192.168.0.16,192.168.0.17'. Default: None.
  121. device (str): A string that describes the device(s) to use, e.g.,
  122. 'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
  123. resume_path (str|None): If not None, resume training from the model
  124. snapshot corresponding to the weight file `resume_path`. If
  125. None, use a default setting. Default: None.
  126. dy2st (bool): Whether to enable dynamic-to-static training.
  127. Default: False.
  128. amp (str): Optimization level to use in AMP training. Choices are
  129. ['O1', 'O2', 'OFF']. Default: 'OFF'.
  130. num_workers (int|None): Number of subprocesses to use for data
  131. loading. If None, use a default setting. Default: None.
  132. use_vdl (bool): Whether to enable VisualDL during training.
  133. Default: True.
  134. save_dir (str|None): Directory to store model snapshots and logs. If
  135. None, use a default setting. Default: None.
  136. Returns:
  137. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  138. """
  139. raise NotImplementedError
  140. @abc.abstractmethod
  141. def evaluate(
  142. self,
  143. weight_path,
  144. batch_size=None,
  145. ips=None,
  146. device="gpu",
  147. amp="OFF",
  148. num_workers=None,
  149. **kwargs,
  150. ):
  151. """
  152. Evaluate a model.
  153. Args:
  154. weight_path (str): Path of the weights to initialize the model.
  155. batch_size (int|None): Number of samples in each mini-batch. If
  156. multiple devices are used, this is the batch size on each device.
  157. If None, use a default setting. Default: None.
  158. ips (str|None): If not None, enable multi-machine evaluation mode.
  159. `ips` specifies Paddle cluster node ips, e.g.,
  160. '192.168.0.16,192.168.0.17'. Default: None.
  161. device (str): A string that describes the device(s) to use, e.g.,
  162. 'cpu', 'gpu', 'gpu:1,2'. Default: 'gpu'.
  163. amp (str): Optimization level to use in AMP training. Choices are
  164. ['O1', 'O2', 'OFF']. Default: 'OFF'.
  165. num_workers (int|None): Number of subprocesses to use for data
  166. loading. If None, use a default setting. Default: None.
  167. Returns:
  168. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  169. """
  170. raise NotImplementedError
  171. @abc.abstractmethod
  172. def predict(self, weight_path, input_path, device="gpu", save_dir=None, **kwargs):
  173. """
  174. Make prediction with a pre-trained model.
  175. Args:
  176. weight_path (str): Path of the weights to initialize the model.
  177. input_path (str): Path of the input file, e.g. an image.
  178. device (str): A string that describes the device to use, e.g.,
  179. 'cpu', 'gpu'. Default: 'gpu'.
  180. save_dir (str|None): Directory to store prediction results. If None,
  181. use a default setting. Default: None.
  182. Returns:
  183. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  184. """
  185. raise NotImplementedError
  186. @abc.abstractmethod
  187. def export(self, weight_path, save_dir, **kwargs):
  188. """
  189. Export a pre-trained model.
  190. Args:
  191. weight_path (str): Path of the weights to initialize the model.
  192. save_dir (str): Directory to store the exported model.
  193. Returns:
  194. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  195. """
  196. raise NotImplementedError
  197. @abc.abstractmethod
  198. def infer(self, model_dir, input_path, device="gpu", save_dir=None, **kwargs):
  199. """
  200. Make inference with an exported inference model.
  201. Args:
  202. model_dir (str): Path of the exported inference model.
  203. input_path (str): Path of the input file, e.g. an image.
  204. device (str): A string that describes the device(s) to use, e.g.,
  205. 'cpu', 'gpu'. Default: 'gpu'.
  206. save_dir (str|None): Directory to store inference results. If None,
  207. use a default setting. Default: None.
  208. Returns:
  209. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  210. """
  211. raise NotImplementedError
  212. @abc.abstractmethod
  213. def compression(
  214. self,
  215. weight_path,
  216. batch_size=None,
  217. learning_rate=None,
  218. epochs_iters=None,
  219. device="gpu",
  220. use_vdl=True,
  221. save_dir=None,
  222. **kwargs,
  223. ):
  224. """
  225. Perform quantization aware training (QAT) and export the quantized
  226. model.
  227. Args:
  228. weight_path (str): Path of the weights to initialize the model.
  229. batch_size (int|None): Number of samples in each mini-batch. If
  230. multiple devices are used, this is the batch size on each
  231. device. If None, use a default setting. Default: None.
  232. learning_rate (float|None): Learning rate of QAT. If None, use a
  233. default setting. Default: None.
  234. epochs_iters (int|None): Total epochs of iterations of model
  235. training. If None, use a default setting. Default: None.
  236. device (str): A string that describes the device(s) to use, e.g.,
  237. 'cpu', 'gpu'. Default: 'gpu'.
  238. use_vdl (bool): Whether to enable VisualDL during training.
  239. Default: True.
  240. save_dir (str|None): Directory to store the results. If None, use a
  241. default setting. Default: None.
  242. Returns:
  243. tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
  244. """
  245. raise NotImplementedError
  246. @contextlib.contextmanager
  247. def _create_new_config_file(self):
  248. cls = self.__class__
  249. model_name = self.model_info["model_name"]
  250. tag = "_".join([cls.__name__.lower(), model_name])
  251. yaml_file_name = tag + ".yml"
  252. if not flags.DEBUG:
  253. with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
  254. path = os.path.join(td, yaml_file_name)
  255. with open(path, "w", encoding="utf-8"):
  256. pass
  257. yield path
  258. else:
  259. path = os.path.join(get_cache_dir(), yaml_file_name)
  260. with open(path, "w", encoding="utf-8"):
  261. pass
  262. yield path
  263. @contextlib.contextmanager
  264. def _create_new_val_json_file(self):
  265. cls = self.__class__
  266. model_name = self.model_info["model_name"]
  267. tag = "_".join([cls.__name__.lower(), model_name])
  268. json_file_name = tag + "_test.json"
  269. if not flags.DEBUG:
  270. with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
  271. path = os.path.join(td, json_file_name)
  272. with open(path, "w", encoding="utf-8"):
  273. pass
  274. yield path
  275. else:
  276. path = os.path.join(get_cache_dir(), json_file_name)
  277. with open(path, "w", encoding="utf-8"):
  278. pass
  279. yield path
  280. @cached_property
  281. def supported_apis(self):
  282. """supported apis"""
  283. return self.model_info.get("supported_apis", None)
  284. @cached_property
  285. def supported_train_opts(self):
  286. """supported train opts"""
  287. return self.model_info.get(
  288. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="train"), None
  289. )
  290. @cached_property
  291. def supported_evaluate_opts(self):
  292. """supported evaluate opts"""
  293. return self.model_info.get(
  294. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="evaluate"), None
  295. )
  296. @cached_property
  297. def supported_predict_opts(self):
  298. """supported predcit opts"""
  299. return self.model_info.get(
  300. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="predict"), None
  301. )
  302. @cached_property
  303. def supported_infer_opts(self):
  304. """supported infer opts"""
  305. return self.model_info.get(
  306. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="infer"), None
  307. )
  308. @cached_property
  309. def supported_compression_opts(self):
  310. """supported copression opts"""
  311. return self.model_info.get(
  312. self._API_SUPPORTED_OPTS_KEY_PATTERN.format(api_name="compression"), None
  313. )
  314. @cached_property
  315. def supported_dataset_types(self):
  316. """supported dataset types"""
  317. return self.model_info.get("supported_dataset_types", None)
  318. @staticmethod
  319. def _assert_empty_kwargs(kwargs):
  320. if len(kwargs) > 0:
  321. # For compatibility
  322. logging.warning(f"Unconsumed keyword arguments detected: {kwargs}.")
  323. # raise RuntimeError(
  324. # f"Unconsumed keyword arguments detected: {kwargs}.")
  325. def _patch_apis(self):
  326. def _make_unavailable(bnd_method):
  327. @functools.wraps(bnd_method)
  328. def _unavailable_api(*args, **kwargs):
  329. model_name = self.name
  330. api_name = bnd_method.__name__
  331. raise UnsupportedAPIError(
  332. f"{model_name} does not support `{api_name}`."
  333. )
  334. return _unavailable_api
  335. def _add_prechecks(bnd_method):
  336. @functools.wraps(bnd_method)
  337. def _api_with_prechecks(*args, **kwargs):
  338. sig = inspect.Signature.from_callable(bnd_method)
  339. bnd_args = sig.bind(*args, **kwargs)
  340. args_dict = bnd_args.arguments
  341. # Merge default values
  342. for p in sig.parameters.values():
  343. if p.name not in args_dict and p.default is not p.empty:
  344. args_dict[p.name] = p.default
  345. # Rely on nonlocal variable `checks`
  346. for check in checks:
  347. # We throw any unhandled exception
  348. check.check(args_dict)
  349. return bnd_method(*args, **kwargs)
  350. api_name = bnd_method.__name__
  351. checks = []
  352. # We hardcode the prechecks for each API here
  353. if api_name == "train":
  354. opts = self.supported_train_opts
  355. if opts is not None:
  356. if "device" in opts:
  357. checks.append(_CheckDevice(opts["device"], check_mc=True))
  358. if "dy2st" in opts:
  359. checks.append(_CheckDy2St(opts["dy2st"]))
  360. if "amp" in opts:
  361. checks.append(_CheckAMP(opts["amp"]))
  362. elif api_name == "evaluate":
  363. opts = self.supported_evaluate_opts
  364. if opts is not None:
  365. if "device" in opts:
  366. checks.append(_CheckDevice(opts["device"], check_mc=True))
  367. if "amp" in opts:
  368. checks.append(_CheckAMP(opts["amp"]))
  369. elif api_name == "predict":
  370. opts = self.supported_predict_opts
  371. if opts is not None:
  372. if "device" in opts:
  373. checks.append(_CheckDevice(opts["device"], check_mc=False))
  374. elif api_name == "infer":
  375. opts = self.supported_infer_opts
  376. if opts is not None:
  377. if "device" in opts:
  378. checks.append(_CheckDevice(opts["device"], check_mc=False))
  379. elif api_name == "compression":
  380. opts = self.supported_compression_opts
  381. if opts is not None:
  382. if "device" in opts:
  383. checks.append(_CheckDevice(opts["device"], check_mc=True))
  384. else:
  385. return bnd_method
  386. return _api_with_prechecks
  387. supported_apis = self.supported_apis
  388. if supported_apis is not None:
  389. avail_api_set = set(self.supported_apis)
  390. else:
  391. avail_api_set = set(self._API_FULL_LIST)
  392. for api_name in self._API_FULL_LIST:
  393. api = getattr(self, api_name)
  394. if api_name not in avail_api_set:
  395. # We decorate real API implementation with `_make_unavailable`
  396. # so that an error is always raised when the API is called.
  397. decorated_api = _make_unavailable(api)
  398. # Monkey-patch
  399. setattr(self, api_name, decorated_api)
  400. else:
  401. if flags.CHECK_OPTS:
  402. # We decorate real API implementation with `_add_prechecks`
  403. # to perform sanity and validity checks before invoking the
  404. # internal API.
  405. decorated_api = _add_prechecks(api)
  406. setattr(self, api_name, decorated_api)
  407. class _CheckFailed(Exception):
  408. """_CheckFailed"""
  409. # Allow `_CheckFailed` class to be recognized using `hasattr(exc, 'check_failed_error')`
  410. check_failed_error = True
  411. def __init__(self, arg_name, arg_val, legal_vals):
  412. self.arg_name = arg_name
  413. self.arg_val = arg_val
  414. self.legal_vals = legal_vals
  415. def __str__(self):
  416. return f"`{self.arg_name}` is expected to be one of or conforms to {self.legal_vals}, but got {self.arg_val}"
  417. class _APICallArgsChecker(object):
  418. """_APICallArgsChecker"""
  419. def __init__(self, legal_vals):
  420. super().__init__()
  421. self.legal_vals = legal_vals
  422. def check(self, args):
  423. """check"""
  424. raise NotImplementedError
  425. class _CheckDevice(_APICallArgsChecker):
  426. """_CheckDevice"""
  427. def __init__(self, legal_vals, check_mc=False):
  428. super().__init__(legal_vals)
  429. self.check_mc = check_mc
  430. def check(self, args):
  431. """check"""
  432. assert "device" in args
  433. device = args["device"]
  434. if device is not None:
  435. device_type, dev_ids = parse_device(device)
  436. if not self.check_mc:
  437. if device_type not in self.legal_vals:
  438. raise _CheckFailed("device", device, self.legal_vals)
  439. else:
  440. # Currently we only check multi-device settings for GPUs
  441. if device_type != "gpu":
  442. if device_type not in self.legal_vals:
  443. raise _CheckFailed("device", device, self.legal_vals)
  444. else:
  445. n1c1_desc = f"{device_type}_n1c1"
  446. n1cx_desc = f"{device_type}_n1cx"
  447. nxcx_desc = f"{device_type}_nxcx"
  448. if len(dev_ids) <= 1:
  449. if (
  450. n1c1_desc not in self.legal_vals
  451. and n1cx_desc not in self.legal_vals
  452. and nxcx_desc not in self.legal_vals
  453. ):
  454. raise _CheckFailed("device", device, self.legal_vals)
  455. else:
  456. assert "ips" in args
  457. if args["ips"] is not None:
  458. # Multi-machine
  459. if nxcx_desc not in self.legal_vals:
  460. raise _CheckFailed("device", device, self.legal_vals)
  461. else:
  462. # Single-machine multi-device
  463. if (
  464. n1cx_desc not in self.legal_vals
  465. and nxcx_desc not in self.legal_vals
  466. ):
  467. raise _CheckFailed("device", device, self.legal_vals)
  468. else:
  469. # When `device` is None, we assume that a default device that the
  470. # current model supports will be used, so we simply do nothing.
  471. pass
  472. class _CheckDy2St(_APICallArgsChecker):
  473. """_CheckDy2St"""
  474. def check(self, args):
  475. """check"""
  476. assert "dy2st" in args
  477. dy2st = args["dy2st"]
  478. if isinstance(self.legal_vals, list):
  479. assert len(self.legal_vals) == 1
  480. support_dy2st = bool(self.legal_vals[0])
  481. else:
  482. support_dy2st = bool(self.legal_vals)
  483. if dy2st is not None:
  484. if dy2st and not support_dy2st:
  485. raise _CheckFailed("dy2st", dy2st, [support_dy2st])
  486. else:
  487. pass
  488. class _CheckAMP(_APICallArgsChecker):
  489. """_CheckAMP"""
  490. def check(self, args):
  491. """check"""
  492. assert "amp" in args
  493. amp = args["amp"]
  494. if amp is not None:
  495. if amp != "OFF" and amp not in self.legal_vals:
  496. raise _CheckFailed("amp", amp, self.legal_vals)
  497. else:
  498. pass