config.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Union
  15. import yaml
  16. from ....utils.misc import abspath
  17. from ...base import BaseConfig
  class ClsConfig(BaseConfig):
      """Image Classification Task Config.

      Wraps a PaddleClas-style config dict. Most ``update_*`` methods build
      ``"dotted.key=value"`` override strings and apply them through
      :meth:`update`; the ``get_*`` methods read values back from ``self.dict``.
      """

      def update(self, list_like_obj: list):
          """Apply dotted-key overrides to this config.

          Args:
              list_like_obj (list): list of pairs (key0.key1.idx.key2=value), such as:
                  [
                      'topk=2',
                      'VALID.transforms.1.ResizeImage.resize_short=300'
                  ]
          """
          # Imported inside the method rather than at module level, presumably so
          # that paddleclas is only required once overrides are actually applied.
          from paddleclas.ppcls.utils.config import override_config

          dict_ = override_config(self.dict, list_like_obj)
          self.reset_from_dict(dict_)

      def load(self, config_file_path: str):
          """Load config from a YAML file.

          Args:
              config_file_path (str): the path of yaml file.

          Raises:
              TypeError: the content of yaml file `config_file_path` is not a mapping.
          """
          # NOTE(review): `yaml.Loader` can construct arbitrary Python objects;
          # only load trusted config files (consider `yaml.SafeLoader`).
          with open(config_file_path, "rb") as f:
              dict_ = yaml.load(f, Loader=yaml.Loader)
          if not isinstance(dict_, dict):
              raise TypeError(
                  f"The content of {config_file_path!r} should be a dict, "
                  f"but got {type(dict_).__name__}."
              )
          self.reset_from_dict(dict_)

      def dump(self, config_file_path: str):
          """Dump this config to a YAML file.

          Args:
              config_file_path (str): the path to save self as yaml file.
          """
          with open(config_file_path, "w", encoding="utf-8") as f:
              # sort_keys=False keeps the dict's insertion order in the output.
              yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)

      def update_dataset(
          self,
          dataset_path: str,
          dataset_type: str = None,
          *,
          train_list_path: str = None,
      ):
          """Update the train/eval dataset settings.

          Args:
              dataset_path (str): the root path of dataset.
              dataset_type (str, optional): dataset type, "ClsDataset" or
                  "MLClsDataset". Defaults to None, meaning "ClsDataset".
              train_list_path (str, optional): the path of the train annotation
                  file. Defaults to None, meaning `<dataset_path>/train.txt`.

          Raises:
              ValueError: the dataset_type is not supported.
          """
          dataset_path = abspath(dataset_path)
          if dataset_type is None:
              dataset_type = "ClsDataset"
          if not train_list_path:
              train_list_path = f"{dataset_path}/train.txt"
          if dataset_type not in ("ClsDataset", "MLClsDataset"):
              raise ValueError(f"{repr(dataset_type)} is not supported.")
          # The eval split and the label map are always expected at fixed
          # file names under the dataset root.
          ds_cfg = [
              f"DataLoader.Train.dataset.name={dataset_type}",
              f"DataLoader.Train.dataset.image_root={dataset_path}",
              f"DataLoader.Train.dataset.cls_label_path={train_list_path}",
              f"DataLoader.Eval.dataset.name={dataset_type}",
              f"DataLoader.Eval.dataset.image_root={dataset_path}",
              f"DataLoader.Eval.dataset.cls_label_path={dataset_path}/val.txt",
              f"Infer.PostProcess.class_id_map_file={dataset_path}/label.txt",
          ]
          self.update(ds_cfg)

      def update_batch_size(self, batch_size: int, mode: str = "train"):
          """Update the batch size setting.

          Args:
              batch_size (int): the batch size number to set.
              mode (str, optional): the mode that to be set batch size, must be one
                  of 'train', 'eval', 'test'. Defaults to 'train'.

          Raises:
              ValueError: `mode` error.
          """
          if mode == "train":
              sampler = self.dict["DataLoader"]["Train"]["sampler"]
              if sampler.get("batch_size", False):
                  _cfg = [f"DataLoader.Train.sampler.batch_size={batch_size}"]
              else:
                  # The train sampler has no `batch_size` field, so set `first_bs`
                  # and switch the dataset to MultiScaleDataset. Both overrides
                  # must be applied (previously the first one was overwritten
                  # and the requested batch size was silently dropped).
                  _cfg = [
                      f"DataLoader.Train.sampler.first_bs={batch_size}",
                      "DataLoader.Train.dataset.name=MultiScaleDataset",
                  ]
          elif mode == "eval":
              _cfg = [f"DataLoader.Eval.sampler.batch_size={batch_size}"]
          elif mode == "test":
              _cfg = [f"DataLoader.Infer.batch_size={batch_size}"]
          else:
              raise ValueError("The input `mode` should be train, eval or test.")
          self.update(_cfg)

      def update_learning_rate(self, learning_rate: float):
          """Update the learning rate.

          Writes `Optimizer.lr.learning_rate` if present, otherwise
          `Optimizer.lr.max_learning_rate`.

          Args:
              learning_rate (float): the learning rate value to set.

          Raises:
              ValueError: neither learning-rate field is present in `Optimizer.lr`.
          """
          lr_cfg = self.dict["Optimizer"]["lr"]
          if lr_cfg.get("learning_rate", None) is not None:
              _cfg = [f"Optimizer.lr.learning_rate={learning_rate}"]
          elif lr_cfg.get("max_learning_rate", None) is not None:
              _cfg = [f"Optimizer.lr.max_learning_rate={learning_rate}"]
          else:
              raise ValueError("unsupported lr format")
          self.update(_cfg)

      def update_warmup_epochs(self, warmup_epochs: int):
          """Update warmup epochs (`Optimizer.lr.warmup_epoch`).

          Args:
              warmup_epochs (int): the warmup epochs value to set.
          """
          _cfg = [f"Optimizer.lr.warmup_epoch={warmup_epochs}"]
          self.update(_cfg)

      def update_pretrained_weights(self, pretrained_model: str):
          """Update the pretrained weight path.

          Args:
              pretrained_model (str | None): the local path or url of the
                  pretrained weight file to set; `None` disables pretrained
                  weights; "default" (case-insensitive) sets `Arch.pretrained=True`
                  and clears `Global.pretrained_model`.
          """
          assert isinstance(pretrained_model, (str, type(None))), (
              "The 'pretrained_model' should be a string, indicating the path to the "
              "'*.pdparams' file, or 'None', indicating that no pretrained model to be used."
          )
          if pretrained_model is None:
              self.update(["Global.pretrained_model=None"])
              self.update(["Arch.pretrained=False"])
          elif pretrained_model.lower() == "default":
              # Presumably lets the architecture load its own built-in weights.
              self.update(["Global.pretrained_model=None"])
              self.update(["Arch.pretrained=True"])
          else:
              if not pretrained_model.startswith(("http://", "https://")):
                  # Local paths are made absolute and stored without the
                  # ".pdparams" suffix (presumably re-appended by the loader).
                  pretrained_model = abspath(pretrained_model.replace(".pdparams", ""))
              self.update([f"Global.pretrained_model={pretrained_model}"])

      def update_num_classes(self, num_classes: int):
          """Update the number of classes.

          Also updates the Teacher/Student sub-models of a "DistillationModel"
          and the `MLDecoder` query/class numbers when those fields exist.

          Args:
              num_classes (int): the classes number value to set.
          """
          update_str_list = [f"Arch.class_num={num_classes}"]
          if self._get_arch_name() == "DistillationModel":
              update_str_list.append(f"Arch.models.0.Teacher.class_num={num_classes}")
              update_str_list.append(f"Arch.models.1.Student.class_num={num_classes}")
          if self.dict.get("MLDecoder", None) is not None:
              self.update_ml_query_num(num_classes)
              self.update_ml_class_num(num_classes)
          self.update(update_str_list)

      def update_ml_query_num(self, query_num: int):
          """Update the MLDecoder query number.

          Only applied when `MLDecoder.query_num` already exists in the config.

          Args:
              query_num (int): the query number value to set; query_num should be
                  less than or equal to num_classes.
          """
          base_query_num = self.dict.get("MLDecoder", {}).get("query_num", None)
          if base_query_num is not None:
              _cfg = [f"MLDecoder.query_num={query_num}"]
              self.update(_cfg)

      def update_ml_class_num(self, class_num: int):
          """Update the MLDecoder class number.

          Only applied when `MLDecoder.class_num` already exists in the config.

          Args:
              class_num (int): the classes number value to set.
          """
          base_class_num = self.dict.get("MLDecoder", {}).get("class_num", None)
          if base_class_num is not None:
              _cfg = [f"MLDecoder.class_num={class_num}"]
              self.update(_cfg)

      def _update_slim_config(self, slim_config_path: str):
          """Update slim settings from the `Slim` section of another YAML file.

          Args:
              slim_config_path (str): the path to slim config yaml file.
          """
          # NOTE(review): `yaml.Loader` is unsafe for untrusted files.
          with open(slim_config_path, "rb") as f:
              slim_config = yaml.load(f, Loader=yaml.Loader)["Slim"]
          self.update([f"Slim={slim_config}"])

      def _update_amp(self, amp: Union[None, str]):
          """Update AMP settings.

          Args:
              amp (None | str): the AMP level; `None` or "OFF" removes the `AMP`
                  section, any other value enables AMP at that level.

          Raises:
              ValueError: AMP is requested but the config has no `AMP` field.
          """
          if amp is None or amp == "OFF":
              if "AMP" in self.dict:
                  self._dict.pop("AMP")
          else:
              if "AMP" not in self.dict:
                  raise ValueError("Config must have AMP information.")
              _cfg = ["AMP.use_amp=True", f"AMP.level={amp}"]
              self.update(_cfg)

      def update_num_workers(self, num_workers: int):
          """Update the workers number of train and eval dataloader.

          Args:
              num_workers (int): the value of train and eval dataloader workers number to set.
          """
          _cfg = [
              f"DataLoader.Train.loader.num_workers={num_workers}",
              f"DataLoader.Eval.loader.num_workers={num_workers}",
          ]
          self.update(_cfg)

      def update_shared_memory(self, shared_memeory: bool):
          """Update the shared memory setting of train and eval dataloader.

          Args:
              shared_memeory (bool): whether or not to use shared memory.
          """
          assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
          _cfg = [
              f"DataLoader.Train.loader.use_shared_memory={shared_memeory}",
              f"DataLoader.Eval.loader.use_shared_memory={shared_memeory}",
          ]
          self.update(_cfg)

      def update_shuffle(self, shuffle: bool):
          """Update the shuffle setting of train and eval dataloader.

          Args:
              shuffle (bool): whether or not to shuffle the data.
          """
          assert isinstance(shuffle, bool), "shuffle should be a bool"
          _cfg = [
              f"DataLoader.Train.loader.shuffle={shuffle}",
              f"DataLoader.Eval.loader.shuffle={shuffle}",
          ]
          self.update(_cfg)

      def update_dali(self, dali: bool):
          """Update the global DALI switch (`Global.use_dali`).

          Args:
              dali (bool): whether or not to use DALI.
          """
          assert isinstance(dali, bool), "dali should be a bool"
          self.update([f"Global.use_dali={dali}"])

      def update_seed(self, seed: int):
          """Update the random seed.

          Args:
              seed (int): the random seed value to set.
          """
          _cfg = [f"Global.seed={seed}"]
          self.update(_cfg)

      def update_device(self, device: str):
          """Update the device setting.

          Args:
              device (str): the running device; only the part before ":" is stored.
          """
          device = device.split(":")[0]
          _cfg = [f"Global.device={device}"]
          self.update(_cfg)

      def update_label_dict_path(self, dict_path: str):
          """Update the label dict file path.

          Args:
              dict_path (str): the path of label dict file to set.
          """
          _cfg = [
              f"PostProcess.Topk.class_id_map_file={abspath(dict_path)}",
          ]
          self.update(_cfg)

      def _update_to_static(self, dy2st: bool):
          """Update config to set dynamic to static mode.

          Args:
              dy2st (bool): whether or not to use the dynamic to static mode.
          """
          self.update([f"Global.to_static={dy2st}"])

      def _update_use_vdl(self, use_vdl: bool):
          """Update config to set VisualDL.

          Args:
              use_vdl (bool): whether or not to use VisualDL.
          """
          self.update([f"Global.use_visualdl={use_vdl}"])

      def _update_epochs(self, epochs: int):
          """Update the epochs setting.

          Args:
              epochs (int): the epochs number value to set.
          """
          self.update([f"Global.epochs={epochs}"])

      def _update_checkpoints(self, resume_path: Union[None, str]):
          """Update the checkpoint setting.

          Args:
              resume_path (None | str): the resume training setting. If `None`,
                  train from scratch; otherwise train from the checkpoint whose
                  path is the `.pdparams` file (stored without that suffix).
          """
          if resume_path is not None:
              resume_path = resume_path.replace(".pdparams", "")
          self.update([f"Global.checkpoints={resume_path}"])

      def _update_output_dir(self, save_dir: str):
          """Update the output directory.

          Args:
              save_dir (str): the path to save outputs.
          """
          self.update([f"Global.output_dir={abspath(save_dir)}"])

      def update_log_interval(self, log_interval: int):
          """Update the log interval (steps).

          Args:
              log_interval (int): the log interval value to set.
          """
          self.update([f"Global.print_batch_step={log_interval}"])

      def update_eval_interval(self, eval_interval: int):
          """Update the eval interval (epochs).

          Args:
              eval_interval (int): the eval interval value to set.
          """
          self.update([f"Global.eval_interval={eval_interval}"])

      def update_save_interval(self, save_interval: int):
          """Update the save interval (epochs).

          Args:
              save_interval (int): the save interval value to set.
          """
          self.update([f"Global.save_interval={save_interval}"])

      def update_log_ranks(self, device):
          """Update the log ranks.

          Args:
              device (str): the running device, expected to look like
                  "<type>:<ids>"; the part after the first ":" (up to any
                  further ":") is stored as `Global.log_ranks`. A device with
                  no ":" (e.g. "cpu") leaves the config unchanged instead of
                  raising IndexError.
          """
          parts = device.split(":")
          if len(parts) < 2:
              # No explicit device ids, so there are no ranks to record.
              return
          log_ranks = parts[1]
          self.update([f'Global.log_ranks="{log_ranks}"'])

      def update_print_mem_info(self, print_mem_info: bool):
          """Set whether to print memory info (`Global.print_mem_info`)."""
          assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
          self.update([f"Global.print_mem_info={print_mem_info}"])

      def _update_predict_img(self, infer_img: str, infer_list: str = None):
          """Update the image(s) to be predicted.

          Args:
              infer_img (str): the path to image that to be predicted.
              infer_list (str, optional): the path to a file listing images.
                  Only set when truthy. Defaults to None.
          """
          if infer_list:
              self.update([f"Infer.infer_list={infer_list}"])
          self.update([f"Infer.infer_imgs={infer_img}"])

      def _update_save_inference_dir(self, save_inference_dir: str):
          """Update the directory path to save inference model files.

          Args:
              save_inference_dir (str): the directory path to set.
          """
          self.update([f"Global.save_inference_dir={abspath(save_inference_dir)}"])

      def _update_inference_model_dir(self, model_dir: str):
          """Update the inference model directory.

          Args:
              model_dir (str): the directory path of inference model files used to predict.
          """
          self.update([f"Global.inference_model_dir={abspath(model_dir)}"])

      def _update_infer_img(self, infer_img: str):
          """Update the path of the image to be predicted.

          Args:
              infer_img (str): the image path.
          """
          self.update([f"Global.infer_imgs={infer_img}"])

      def _update_infer_device(self, device: str):
          """Update the device used in predicting.

          Args:
              device (str): the running device setting; `Global.use_gpu` is True
                  only when the part before ":" is "gpu".
          """
          use_gpu = device.split(":")[0] == "gpu"
          self.update([f"Global.use_gpu={use_gpu}"])

      def _update_enable_mkldnn(self, enable_mkldnn: bool):
          """Update whether to enable MKLDNN.

          Args:
              enable_mkldnn (bool): `True` is enable, otherwise is disable.
          """
          self.update([f"Global.enable_mkldnn={enable_mkldnn}"])

      def _update_infer_img_shape(self, img_shape: str):
          """Update the image cropping shape in the preprocessing.

          Args:
              img_shape (str): the shape of cropping in the preprocessing,
                  i.e. `PreProcess.transform_ops.1.CropImage.size`.
          """
          self.update([f"PreProcess.transform_ops.1.CropImage.size={img_shape}"])

      def _update_save_predict_result(self, save_dir: str):
          """Update the directory that saves predicting output.

          Args:
              save_dir (str): the directory path that saves predicting output.
          """
          self.update([f"Infer.save_dir={save_dir}"])

      def update_model(self, **kwargs):
          """Update model settings: each keyword `k=v` sets `Arch.k=v`."""
          for k, v in kwargs.items():
              self.update([f"Arch.{k}={v}"])

      def update_teacher_model(self, **kwargs):
          """Update teacher model settings: each `k=v` sets `Arch.models.0.Teacher.k=v`."""
          for k, v in kwargs.items():
              self.update([f"Arch.models.0.Teacher.{k}={v}"])

      def update_student_model(self, **kwargs):
          """Update student model settings: each `k=v` sets `Arch.models.1.Student.k=v`."""
          for k, v in kwargs.items():
              self.update([f"Arch.models.1.Student.{k}={v}"])

      def get_epochs_iters(self) -> int:
          """Get epochs.

          Returns:
              int: the epochs value, i.e., `Global.epochs` in config.
          """
          return self.dict["Global"]["epochs"]

      def get_log_interval(self) -> int:
          """Get the log interval (steps).

          Returns:
              int: the log interval value, i.e., `Global.print_batch_step` in config.
          """
          return self.dict["Global"]["print_batch_step"]

      def get_eval_interval(self) -> int:
          """Get the eval interval (epochs).

          Returns:
              int: the eval interval value, i.e., `Global.eval_interval` in config.
          """
          return self.dict["Global"]["eval_interval"]

      def get_save_interval(self) -> int:
          """Get the save interval (epochs).

          Returns:
              int: the save interval value, i.e., `Global.save_interval` in config.
          """
          return self.dict["Global"]["save_interval"]

      def get_learning_rate(self) -> float:
          """Get the learning rate.

          Returns:
              float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
          """
          return self.dict["Optimizer"]["lr"]["learning_rate"]

      def get_warmup_epochs(self) -> int:
          """Get warmup epochs.

          Returns:
              int: the warmup epochs value, i.e., `Optimizer.lr.warmup_epoch` in config.
          """
          return self.dict["Optimizer"]["lr"]["warmup_epoch"]

      def get_label_dict_path(self) -> str:
          """Get the label dict file path.

          Returns:
              str: the label dict file path, i.e., `PostProcess.Topk.class_id_map_file` in config.
          """
          return self.dict["PostProcess"]["Topk"]["class_id_map_file"]

      def get_batch_size(self, mode="train") -> int:
          """Get the batch size of `mode`.

          Reads the same fields that `update_batch_size` writes.

          Args:
              mode (str, optional): the mode to get the batch size for, must be
                  one of 'train', 'eval', 'test'. Defaults to 'train'.

          Returns:
              int: for 'train', `DataLoader.Train.sampler.batch_size` (or
                  `DataLoader.Train.sampler.first_bs` when the sampler has no
                  `batch_size`); for 'eval', `DataLoader.Eval.sampler.batch_size`;
                  for 'test', `DataLoader.Infer.batch_size`.

          Raises:
              ValueError: `mode` error.
          """
          loader_cfg = self.dict["DataLoader"]
          if mode == "train":
              sampler = loader_cfg["Train"]["sampler"]
              if "batch_size" in sampler:
                  return sampler["batch_size"]
              return sampler["first_bs"]
          if mode == "eval":
              return loader_cfg["Eval"]["sampler"]["batch_size"]
          if mode == "test":
              return loader_cfg["Infer"]["batch_size"]
          raise ValueError("The input `mode` should be train, eval or test.")

      def get_qat_epochs_iters(self) -> int:
          """Get QAT epochs (same as `get_epochs_iters`).

          Returns:
              int: the epochs value.
          """
          return self.get_epochs_iters()

      def get_qat_learning_rate(self) -> float:
          """Get the QAT learning rate (same as `get_learning_rate`).

          Returns:
              float: the learning rate value.
          """
          return self.get_learning_rate()

      def _get_arch_name(self) -> str:
          """Get the architecture name of the model.

          Returns:
              str: the model arch name, i.e., `Arch.name` in config.
          """
          return self.dict["Arch"]["name"]

      def _get_dataset_root(self) -> str:
          """Get the root directory of the dataset, i.e. `DataLoader.Train.dataset.image_root`.

          Returns:
              str: the root directory of dataset.
          """
          return self.dict["DataLoader"]["Train"]["dataset"]["image_root"]

      def get_train_save_dir(self) -> str:
          """Get the directory to save training output.

          Returns:
              str: the directory to save output, i.e., `Global.output_dir` in config.
          """
          return self.dict["Global"]["output_dir"]