config.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import yaml
  15. from typing import Union
  16. from paddleclas.ppcls.utils.config import get_config, override_config
  17. from ...base import BaseConfig
  18. from ....utils.misc import abspath
  19. class ClsConfig(BaseConfig):
  20. """Image Classification Task Config"""
  21. def update(self, list_like_obj: list):
  22. """update self
  23. Args:
  24. list_like_obj (list): list of pairs(key0.key1.idx.key2=value), such as:
  25. [
  26. 'topk=2',
  27. 'VALID.transforms.1.ResizeImage.resize_short=300'
  28. ]
  29. """
  30. dict_ = override_config(self.dict, list_like_obj)
  31. self.reset_from_dict(dict_)
  32. def load(self, config_file_path: str):
  33. """load config from yaml file
  34. Args:
  35. config_file_path (str): the path of yaml file.
  36. Raises:
  37. TypeError: the content of yaml file `config_file_path` error.
  38. """
  39. dict_ = yaml.load(open(config_file_path, 'rb'), Loader=yaml.Loader)
  40. if not isinstance(dict_, dict):
  41. raise TypeError
  42. self.reset_from_dict(dict_)
  43. def dump(self, config_file_path: str):
  44. """dump self to yaml file
  45. Args:
  46. config_file_path (str): the path to save self as yaml file.
  47. """
  48. with open(config_file_path, 'w', encoding='utf-8') as f:
  49. yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)
  50. def update_dataset(
  51. self,
  52. dataset_path: str,
  53. dataset_type: str=None,
  54. *,
  55. train_list_path: str=None, ):
  56. """update dataset settings
  57. Args:
  58. dataset_path (str): the root path of dataset.
  59. dataset_type (str, optional): dataset type. Defaults to None.
  60. train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
  61. Raises:
  62. ValueError: the dataset_type error.
  63. """
  64. dataset_path = abspath(dataset_path)
  65. if dataset_type is None:
  66. dataset_type = 'ClsDataset'
  67. if train_list_path:
  68. train_list_path = f"{train_list_path}"
  69. else:
  70. train_list_path = f"{dataset_path}/train.txt"
  71. if dataset_type in ['ClsDataset']:
  72. ds_cfg = [
  73. f'DataLoader.Train.dataset.name={dataset_type}',
  74. f'DataLoader.Train.dataset.image_root={dataset_path}',
  75. f'DataLoader.Train.dataset.cls_label_path={train_list_path}',
  76. f'DataLoader.Eval.dataset.name={dataset_type}',
  77. f'DataLoader.Eval.dataset.image_root={dataset_path}',
  78. f'DataLoader.Eval.dataset.cls_label_path={dataset_path}/val.txt',
  79. f'Infer.PostProcess.class_id_map_file={dataset_path}/label.txt'
  80. ]
  81. else:
  82. raise ValueError(f"{repr(dataset_type)} is not supported.")
  83. self.update(ds_cfg)
  84. def update_batch_size(self, batch_size: int, mode: str='train'):
  85. """update batch size setting
  86. Args:
  87. batch_size (int): the batch size number to set.
  88. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  89. Defaults to 'train'.
  90. Raises:
  91. ValueError: `mode` error.
  92. """
  93. if mode == 'train':
  94. if self.DataLoader["Train"]["sampler"].get("batch_size", False):
  95. _cfg = [f'DataLoader.Train.sampler.batch_size={batch_size}']
  96. else:
  97. _cfg = [f'DataLoader.Train.sampler.first_bs={batch_size}']
  98. _cfg = [f'DataLoader.Train.dataset.name=MultiScaleDataset']
  99. elif mode == 'eval':
  100. _cfg = [f'DataLoader.Eval.sampler.batch_size={batch_size}']
  101. elif mode == 'test':
  102. _cfg = [f'DataLoader.Infer.batch_size={batch_size}']
  103. else:
  104. raise ValueError("The input `mode` should be train, eval or test.")
  105. self.update(_cfg)
  106. def update_learning_rate(self, learning_rate: float):
  107. """update learning rate
  108. Args:
  109. learning_rate (float): the learning rate value to set.
  110. """
  111. _cfg = [f'Optimizer.lr.learning_rate={learning_rate}']
  112. self.update(_cfg)
  113. def update_warmup_epochs(self, warmup_epochs: int):
  114. """update warmup epochs
  115. Args:
  116. warmup_epochs (int): the warmup epochs value to set.
  117. """
  118. _cfg = [f'Optimizer.lr.warmup_epoch={warmup_epochs}']
  119. self.update(_cfg)
  120. def update_pretrained_weights(self, pretrained_model: str):
  121. """update pretrained weight path
  122. Args:
  123. pretrained_model (str): the local path or url of pretrained weight file to set.
  124. """
  125. assert isinstance(
  126. pretrained_model, (str, None)
  127. ), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \
  128. indicating that no pretrained model to be used."
  129. if pretrained_model and not pretrained_model.startswith(
  130. ('http://', 'https://')):
  131. pretrained_model = abspath(
  132. pretrained_model.replace(".pdparams", ""))
  133. self.update([f'Global.pretrained_model={pretrained_model}'])
  134. def update_num_classes(self, num_classes: int):
  135. """update classes number
  136. Args:
  137. num_classes (int): the classes number value to set.
  138. """
  139. update_str_list = [f'Arch.class_num={num_classes}']
  140. if self._get_arch_name() == "DistillationModel":
  141. update_str_list.append(
  142. f"Arch.models.0.Teacher.class_num={num_classes}")
  143. update_str_list.append(
  144. f"Arch.models.1.Student.class_num={num_classes}")
  145. self.update(update_str_list)
  146. def _update_slim_config(self, slim_config_path: str):
  147. """update slim settings
  148. Args:
  149. slim_config_path (str): the path to slim config yaml file.
  150. """
  151. slim_config = yaml.load(
  152. open(slim_config_path, 'rb'), Loader=yaml.Loader)['Slim']
  153. self.update([f'Slim={slim_config}'])
  154. def _update_amp(self, amp: Union[None, str]):
  155. """update AMP settings
  156. Args:
  157. amp (None | str): the AMP settings.
  158. Raises:
  159. ValueError: AMP setting `amp` error, missing field `AMP`.
  160. """
  161. if amp is None or amp == 'OFF':
  162. if 'AMP' in self.dict:
  163. self._dict.pop('AMP')
  164. else:
  165. if 'AMP' not in self.dict:
  166. raise ValueError("Config must have AMP information.")
  167. _cfg = ['AMP.use_amp=True', f'AMP.level={amp}']
  168. self.update(_cfg)
  169. def update_num_workers(self, num_workers: int):
  170. """update workers number of train and eval dataloader
  171. Args:
  172. num_workers (int): the value of train and eval dataloader workers number to set.
  173. """
  174. _cfg = [
  175. f'DataLoader.Train.loader.num_workers={num_workers}',
  176. f'DataLoader.Eval.loader.num_workers={num_workers}',
  177. ]
  178. self.update(_cfg)
  179. def enable_shared_memory(self):
  180. """enable shared memory setting of train and eval dataloader
  181. """
  182. _cfg = [
  183. f'DataLoader.Train.loader.use_shared_memory=True',
  184. f'DataLoader.Eval.loader.use_shared_memory=True',
  185. ]
  186. self.update(_cfg)
  187. def disable_shared_memory(self):
  188. """disable shared memory setting of train and eval dataloader
  189. """
  190. _cfg = [
  191. f'DataLoader.Train.loader.use_shared_memory=False',
  192. f'DataLoader.Eval.loader.use_shared_memory=False',
  193. ]
  194. self.update(_cfg)
  195. def update_device(self, device: str):
  196. """update device setting
  197. Args:
  198. device (str): the running device to set
  199. """
  200. device = device.split(':')[0]
  201. _cfg = [f'Global.device={device}']
  202. self.update(_cfg)
  203. def update_label_dict_path(self, dict_path: str):
  204. """update label dict file path
  205. Args:
  206. dict_path (str): the path of label dict file to set
  207. """
  208. _cfg = [f'PostProcess.Topk.class_id_map_file={abspath(dict_path)}', ]
  209. self.update(_cfg)
  210. def _update_to_static(self, dy2st: bool):
  211. """update config to set dynamic to static mode
  212. Args:
  213. dy2st (bool): whether or not to use the dynamic to static mode.
  214. """
  215. self.update([f'Global.to_static={dy2st}'])
  216. def _update_use_vdl(self, use_vdl: bool):
  217. """update config to set VisualDL
  218. Args:
  219. use_vdl (bool): whether or not to use VisualDL.
  220. """
  221. self.update([f'Global.use_visualdl={use_vdl}'])
  222. def _update_epochs(self, epochs: int):
  223. """update epochs setting
  224. Args:
  225. epochs (int): the epochs number value to set
  226. """
  227. self.update([f'Global.epochs={epochs}'])
  228. def _update_checkpoints(self, resume_path: Union[None, str]):
  229. """update checkpoint setting
  230. Args:
  231. resume_path (None | str): the resume training setting. if is `None`, train from scratch, otherwise,
  232. train from checkpoint file that path is `.pdparams` file.
  233. """
  234. if resume_path is not None:
  235. resume_path = resume_path.replace(".pdparams", "")
  236. self.update([f'Global.checkpoints={resume_path}'])
  237. def _update_output_dir(self, save_dir: str):
  238. """update output directory
  239. Args:
  240. save_dir (str): the path to save outputs.
  241. """
  242. self.update([f'Global.output_dir={abspath(save_dir)}'])
  243. def update_log_interval(self, log_interval: int):
  244. """update log interval(steps)
  245. Args:
  246. log_interval (int): the log interval value to set.
  247. """
  248. self.update([f'Global.print_batch_step={log_interval}'])
  249. def update_eval_interval(self, eval_interval: int):
  250. """update eval interval(epochs)
  251. Args:
  252. eval_interval (int): the eval interval value to set.
  253. """
  254. self.update([f'Global.eval_interval={eval_interval}'])
  255. def update_save_interval(self, save_interval: int):
  256. """update eval interval(epochs)
  257. Args:
  258. save_interval (int): the save interval value to set.
  259. """
  260. self.update([f'Global.save_interval={save_interval}'])
  261. def _update_predict_img(self, infer_img: str, infer_list: str=None):
  262. """update image to be predicted
  263. Args:
  264. infer_img (str): the path to image that to be predicted.
  265. infer_list (str, optional): the path to file that images. Defaults to None.
  266. """
  267. if infer_list:
  268. self.update([f'Infer.infer_list={infer_list}'])
  269. self.update([f'Infer.infer_imgs={infer_img}'])
  270. def _update_save_inference_dir(self, save_inference_dir: str):
  271. """update directory path to save inference model files
  272. Args:
  273. save_inference_dir (str): the directory path to set.
  274. """
  275. self.update(
  276. [f'Global.save_inference_dir={abspath(save_inference_dir)}'])
  277. def _update_inference_model_dir(self, model_dir: str):
  278. """update inference model directory
  279. Args:
  280. model_dir (str): the directory path of inference model fils that used to predict.
  281. """
  282. self.update([f'Global.inference_model_dir={abspath(model_dir)}'])
  283. def _update_infer_img(self, infer_img: str):
  284. """update path of image that would be predict
  285. Args:
  286. infer_img (str): the image path.
  287. """
  288. self.update([f'Global.infer_imgs={infer_img}'])
  289. def _update_infer_device(self, device: str):
  290. """update the device used in predicting
  291. Args:
  292. device (str): the running device setting
  293. """
  294. self.update([f'Global.use_gpu={device.split(":")[0]=="gpu"}'])
  295. def _update_enable_mkldnn(self, enable_mkldnn: bool):
  296. """update whether to enable MKLDNN
  297. Args:
  298. enable_mkldnn (bool): `True` is enable, otherwise is disable.
  299. """
  300. self.update([f'Global.enable_mkldnn={enable_mkldnn}'])
  301. def _update_infer_img_shape(self, img_shape: str):
  302. """update image cropping shape in the preprocessing
  303. Args:
  304. img_shape (str): the shape of cropping in the preprocessing,
  305. i.e. `PreProcess.transform_ops.1.CropImage.size`.
  306. """
  307. self.update([f'PreProcess.transform_ops.1.CropImage.size={img_shape}'])
  308. def _update_save_predict_result(self, save_dir: str):
  309. """update directory that save predicting output
  310. Args:
  311. save_dir (str): the dicrectory path that save predicting output.
  312. """
  313. self.update([f'Infer.save_dir={save_dir}'])
  314. def update_model(self, **kwargs):
  315. """update model settings
  316. """
  317. for k in kwargs:
  318. v = kwargs[k]
  319. self.update([f'Arch.{k}={v}'])
  320. def update_teacher_model(self, **kwargs):
  321. """update teacher model settings
  322. """
  323. for k in kwargs:
  324. v = kwargs[k]
  325. self.update([f'Arch.models.0.Teacher.{k}={v}'])
  326. def update_student_model(self, **kwargs):
  327. """update student model settings
  328. """
  329. for k in kwargs:
  330. v = kwargs[k]
  331. self.update([f'Arch.models.1.Student.{k}={v}'])
  332. def get_epochs_iters(self) -> int:
  333. """get epochs
  334. Returns:
  335. int: the epochs value, i.e., `Global.epochs` in config.
  336. """
  337. return self.dict['Global']['epochs']
  338. def get_log_interval(self) -> int:
  339. """get log interval(steps)
  340. Returns:
  341. int: the log interval value, i.e., `Global.print_batch_step` in config.
  342. """
  343. return self.dict['Global']['print_batch_step']
  344. def get_eval_interval(self) -> int:
  345. """get eval interval(epochs)
  346. Returns:
  347. int: the eval interval value, i.e., `Global.eval_interval` in config.
  348. """
  349. return self.dict['Global']['eval_interval']
  350. def get_save_interval(self) -> int:
  351. """get save interval(epochs)
  352. Returns:
  353. int: the save interval value, i.e., `Global.save_interval` in config.
  354. """
  355. return self.dict['Global']['save_interval']
  356. def get_learning_rate(self) -> float:
  357. """get learning rate
  358. Returns:
  359. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  360. """
  361. return self.dict['Optimizer']['lr']['learning_rate']
  362. def get_warmup_epochs(self) -> int:
  363. """get warmup epochs
  364. Returns:
  365. int: the warmup epochs value, i.e., `Optimizer.lr.warmup_epochs` in config.
  366. """
  367. return self.dict['Optimizer']['lr']['warmup_epoch']
  368. def get_label_dict_path(self) -> str:
  369. """get label dict file path
  370. Returns:
  371. str: the label dict file path, i.e., `PostProcess.Topk.class_id_map_file` in config.
  372. """
  373. return self.dict['PostProcess']['Topk']['class_id_map_file']
  374. def get_batch_size(self, mode='train') -> int:
  375. """get batch size
  376. Args:
  377. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  378. Defaults to 'train'.
  379. Returns:
  380. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  381. """
  382. return self.dict['DataLoader']['Train']['sampler']['batch_size']
  383. def get_qat_epochs_iters(self) -> int:
  384. """get qat epochs
  385. Returns:
  386. int: the epochs value.
  387. """
  388. return self.get_epochs_iters()
  389. def get_qat_learning_rate(self) -> float:
  390. """get qat learning rate
  391. Returns:
  392. float: the learning rate value.
  393. """
  394. return self.get_learning_rate()
  395. def _get_arch_name(self) -> str:
  396. """get architecture name of model
  397. Returns:
  398. str: the model arch name, i.e., `Arch.name` in config.
  399. """
  400. return self.dict["Arch"]["name"]
  401. def _get_dataset_root(self) -> str:
  402. """get root directory of dataset, i.e. `DataLoader.Train.dataset.image_root`
  403. Returns:
  404. str: the root directory of dataset
  405. """
  406. return self.dict["DataLoader"]["Train"]['dataset']['image_root']
  407. def get_train_save_dir(self) -> str:
  408. """get the directory to save output
  409. Returns:
  410. str: the directory to save output
  411. """
  412. return self['Global']['output_dir']