config.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. from ...base import BaseConfig
  12. from ....utils.misc import abspath
  13. from ..config_helper import PPDetConfigMixin
class InstanceSegConfig(BaseConfig, PPDetConfigMixin):
    """Config handler for PaddleDetection-style instance segmentation models.

    Wraps a PPDet YAML-style config dict (held in ``self._dict`` by
    ``BaseConfig``) and exposes typed helpers to load/dump the config and to
    update dataset, optimizer, reader and checkpoint settings.
    """
  16. def load(self, config_path: str):
  17. """load the config from config file
  18. Args:
  19. config_path (str): the config file path.
  20. """
  21. dict_ = self.load_config_literally(config_path)
  22. self.reset_from_dict(dict_)
  23. def dump(self, config_path: str):
  24. """dump the config
  25. Args:
  26. config_path (str): the path to save dumped config.
  27. """
  28. self.dump_literal_config(config_path, self._dict)
  29. def update(self, dict_like_obj: list):
  30. """update self from dict
  31. Args:
  32. dict_like_obj (list): the list of pairs that contain key and value.
  33. """
  34. self.update_from_dict(dict_like_obj, self._dict)
  35. def update_dataset(self,
  36. dataset_path: str,
  37. dataset_type: str=None,
  38. *,
  39. data_fields: list[str]=None,
  40. image_dir: str="images",
  41. train_anno_path: str="annotations/instance_train.json",
  42. val_anno_path: str="annotations/instance_val.json",
  43. test_anno_path: str="annotations/instance_val.json"):
  44. """update dataset settings
  45. Args:
  46. dataset_path (str): the root path fo dataset.
  47. dataset_type (str, optional): the dataset type. Defaults to None.
  48. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  49. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  50. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  51. Defaults to "annotations/instance_train.json".
  52. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  53. Defaults to "annotations/instance_val.json".
  54. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  55. Defaults to "annotations/instance_val.json".
  56. Raises:
  57. ValueError: the `dataset_type` error.
  58. """
  59. dataset_path = abspath(dataset_path)
  60. if dataset_type is None:
  61. dataset_type = 'COCOInstSegDataset'
  62. if dataset_type == 'COCOInstSegDataset':
  63. ds_cfg = self._make_dataset_config(dataset_path, data_fields,
  64. image_dir, train_anno_path,
  65. val_anno_path, test_anno_path)
  66. self.set_val('metric', 'COCO')
  67. else:
  68. raise ValueError(f"{repr(dataset_type)} is not supported.")
  69. self.update(ds_cfg)
  70. def _make_dataset_config(
  71. self,
  72. dataset_root_path: str,
  73. data_fields: list[str, ]=None,
  74. image_dir: str="images",
  75. train_anno_path: str="annotations/instance_train.json",
  76. val_anno_path: str="annotations/instance_val.json",
  77. test_anno_path: str="annotations/instance_val.json") -> dict:
  78. """construct the dataset config that meets the format requirements
  79. Args:
  80. dataset_root_path (str): the root directory of dataset.
  81. data_fields (list[str,], optional): the data field. Defaults to None.
  82. image_dir (str, optional): _description_. Defaults to "images".
  83. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  84. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  85. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  86. Returns:
  87. dict: the dataset config.
  88. """
  89. data_fields = ['image', 'gt_bbox', 'gt_poly', 'gt_class',
  90. 'is_crowd'] if data_fields is None else data_fields
  91. return {
  92. 'TrainDataset': {
  93. 'name': 'COCOInstSegDataset',
  94. 'image_dir': image_dir,
  95. 'anno_path': train_anno_path,
  96. 'dataset_dir': dataset_root_path,
  97. 'data_fields': data_fields
  98. },
  99. 'EvalDataset': {
  100. 'name': 'COCOInstSegDataset',
  101. 'image_dir': image_dir,
  102. 'anno_path': val_anno_path,
  103. 'dataset_dir': dataset_root_path
  104. },
  105. 'TestDataset': {
  106. 'name': 'ImageFolder',
  107. 'anno_path': test_anno_path,
  108. 'dataset_dir': dataset_root_path
  109. },
  110. }
  111. def update_ema(self,
  112. use_ema: bool,
  113. ema_decay: float=0.9999,
  114. ema_decay_type: str="exponential",
  115. ema_filter_no_grad: bool=True):
  116. """update EMA setting
  117. Args:
  118. use_ema (bool): whether or not to use EMA
  119. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  120. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  121. ema_filter_no_grad (bool, optional): whether or not to filter the parameters that been set to stop gradient
  122. and are not batch norm parameters. Defaults to True.
  123. """
  124. self.update({
  125. 'use_ema': use_ema,
  126. 'ema_decay': ema_decay,
  127. 'ema_decay_type': ema_decay_type,
  128. 'ema_filter_no_grad': ema_filter_no_grad
  129. })
  130. def update_learning_rate(self, learning_rate: float):
  131. """update learning rate
  132. Args:
  133. learning_rate (float): the learning rate value to set.
  134. """
  135. self.LearningRate['base_lr'] = learning_rate
  136. def update_warmup_steps(self, warmup_steps: int):
  137. """update warmup steps
  138. Args:
  139. warmup_steps (int): the warmup steps value to set.
  140. """
  141. schedulers = self.LearningRate['schedulers']
  142. for sch in schedulers:
  143. key = 'name' if 'name' in sch else '_type_'
  144. if sch[key] == 'LinearWarmup':
  145. sch['steps'] = warmup_steps
  146. sch['epochs_first'] = False
  147. def update_warmup_enable(self, use_warmup: bool):
  148. """whether or not to enable learning rate warmup
  149. Args:
  150. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  151. """
  152. schedulers = self.LearningRate['schedulers']
  153. for sch in schedulers:
  154. if 'use_warmup' in sch:
  155. sch['use_warmup'] = use_warmup
  156. def update_cossch_epoch(self, max_epochs: int):
  157. """update max epochs of cosine learning rate scheduler
  158. Args:
  159. max_epochs (int): the max epochs value.
  160. """
  161. schedulers = self.LearningRate['schedulers']
  162. for sch in schedulers:
  163. key = 'name' if 'name' in sch else '_type_'
  164. if sch[key] == 'CosineDecay':
  165. sch['max_epochs'] = max_epochs
  166. def update_milestone(self, milestones: list[int]):
  167. """update milstone of `PiecewiseDecay` learning scheduler
  168. Args:
  169. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  170. """
  171. schedulers = self.LearningRate['schedulers']
  172. for sch in schedulers:
  173. key = 'name' if 'name' in sch else '_type_'
  174. if sch[key] == 'PiecewiseDecay':
  175. sch['milestones'] = milestones
  176. def update_batch_size(self, batch_size: int, mode: str='train'):
  177. """update batch size setting
  178. Args:
  179. batch_size (int): the batch size number to set.
  180. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  181. Defaults to 'train'.
  182. Raises:
  183. ValueError: mode error.
  184. """
  185. assert mode in ('train', 'eval', 'test'), \
  186. 'mode ({}) should be train, eval or test'.format(mode)
  187. if mode == 'train':
  188. self.TrainReader['batch_size'] = batch_size
  189. elif mode == 'eval':
  190. self.EvalReader['batch_size'] = batch_size
  191. else:
  192. self.TestReader['batch_size'] = batch_size
  193. def update_epochs(self, epochs: int):
  194. """update epochs setting
  195. Args:
  196. epochs (int): the epochs number value to set
  197. """
  198. self.update({'epoch': epochs})
  199. def update_device(self, device_type: str):
  200. """update device setting
  201. Args:
  202. device (str): the running device to set
  203. """
  204. if device_type.lower() == "gpu":
  205. self['use_gpu'] = True
  206. else:
  207. assert device_type.lower() == "cpu"
  208. self['use_gpu'] = False
  209. def update_save_dir(self, save_dir: str):
  210. """update directory to save outputs
  211. Args:
  212. save_dir (str): the directory to save outputs.
  213. """
  214. self['save_dir'] = abspath(save_dir)
  215. def update_log_interval(self, log_interval: int):
  216. """update log interval(steps)
  217. Args:
  218. log_interval (int): the log interval value to set.
  219. """
  220. self.update({'log_iter': log_interval})
  221. def update_eval_interval(self, eval_interval: int):
  222. """update eval interval(epochs)
  223. Args:
  224. eval_interval (int): the eval interval value to set.
  225. """
  226. self.update({'snapshot_epoch': eval_interval})
  227. def update_save_interval(self, save_interval: int):
  228. """update eval interval(epochs)
  229. Args:
  230. save_interval (int): the save interval value to set.
  231. """
  232. self.update({'snapshot_epoch': save_interval})
  233. def update_weights(self, weight_path: str):
  234. """update model weight
  235. Args:
  236. weight_path (str): the path to weight file of model.
  237. """
  238. self['weights'] = abspath(weight_path)
  239. def update_pretrained_weights(self, pretrain_weights: str):
  240. """update pretrained weight path
  241. Args:
  242. pretrained_model (str): the local path or url of pretrained weight file to set.
  243. """
  244. if not pretrain_weights.startswith(
  245. 'http://') and not pretrain_weights.startswith('https://'):
  246. pretrain_weights = abspath(pretrain_weights)
  247. self['pretrain_weights'] = pretrain_weights
  248. def update_num_class(self, num_classes: int):
  249. """update classes number
  250. Args:
  251. num_classes (int): the classes number value to set.
  252. """
  253. self['num_classes'] = num_classes
  254. def update_random_size(self, randomsize: list[list[int, int]]):
  255. """update `target_size` of `BatchRandomResize` op in TestReader
  256. Args:
  257. randomsize (list[list[int, int]]): the list of different size scales.
  258. """
  259. self.TestReader['batch_transforms']['BatchRandomResize'][
  260. 'target_size'] = randomsize
  261. def update_num_workers(self, num_workers: int):
  262. """update workers number of train and eval dataloader
  263. Args:
  264. num_workers (int): the value of train and eval dataloader workers number to set.
  265. """
  266. self['worker_num'] = num_workers
  267. def enable_shared_memory(self):
  268. """enable shared memory setting of train and eval dataloader
  269. """
  270. self.update({'TrainReader': {'use_shared_memory': True}})
  271. self.update({'EvalReader': {'use_shared_memory': True}})
  272. def disable_shared_memory(self):
  273. """disable shared memory setting of train and eval dataloader
  274. """
  275. self.update({'TrainReader': {'use_shared_memory': False}})
  276. self.update({'EvalReader': {'use_shared_memory': False}})
  277. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  278. """update static assigner epochs value
  279. Args:
  280. static_assigner_epochs (int): the value of static assigner epochs
  281. """
  282. assert 'PicoHeadV2' in self
  283. self.PicoHeadV2['static_assigner_epoch'] = static_assigner_epochs
  284. def get_epochs_iters(self) -> int:
  285. """get epochs
  286. Returns:
  287. int: the epochs value, i.e., `Global.epochs` in config.
  288. """
  289. return self.epoch
  290. def get_log_interval(self) -> int:
  291. """get log interval(steps)
  292. Returns:
  293. int: the log interval value, i.e., `Global.print_batch_step` in config.
  294. """
  295. self.log_iter
  296. def get_eval_interval(self) -> int:
  297. """get eval interval(epochs)
  298. Returns:
  299. int: the eval interval value, i.e., `Global.eval_interval` in config.
  300. """
  301. self.snapshot_epoch
  302. def get_save_interval(self) -> int:
  303. """get save interval(epochs)
  304. Returns:
  305. int: the save interval value, i.e., `Global.save_interval` in config.
  306. """
  307. self.snapshot_epoch
  308. def get_learning_rate(self) -> float:
  309. """get learning rate
  310. Returns:
  311. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  312. """
  313. return self.LearningRate['base_lr']
  314. def get_batch_size(self, mode='train') -> int:
  315. """get batch size
  316. Args:
  317. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  318. Defaults to 'train'.
  319. Returns:
  320. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  321. """
  322. if mode == 'train':
  323. return self.TrainReader['batch_size']
  324. elif mode == 'eval':
  325. return self.EvalReader['batch_size']
  326. elif mode == 'test':
  327. return self.TestReader['batch_size']
  328. else:
  329. raise (f"Unknown mode: {repr(mode)}")
  330. def get_qat_epochs_iters(self) -> int:
  331. """get qat epochs
  332. Returns:
  333. int: the epochs value.
  334. """
  335. return self.epoch // 2.0
  336. def get_qat_learning_rate(self) -> float:
  337. """get qat learning rate
  338. Returns:
  339. float: the learning rate value.
  340. """
  341. return self.LearningRate['base_lr'] // 2.0
  342. def get_train_save_dir(self) -> str:
  343. """get the directory to save output
  344. Returns:
  345. str: the directory to save output
  346. """
  347. return self.save_dir