config.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...base import BaseConfig
  15. from ....utils.misc import abspath
  16. from ..config_helper import PPDetConfigMixin
  17. class InstanceSegConfig(BaseConfig, PPDetConfigMixin):
  18. """ InstanceSegConfig """
  19. def load(self, config_path: str):
  20. """load the config from config file
  21. Args:
  22. config_path (str): the config file path.
  23. """
  24. dict_ = self.load_config_literally(config_path)
  25. self.reset_from_dict(dict_)
  26. def dump(self, config_path: str):
  27. """dump the config
  28. Args:
  29. config_path (str): the path to save dumped config.
  30. """
  31. self.dump_literal_config(config_path, self._dict)
  32. def update(self, dict_like_obj: list):
  33. """update self from dict
  34. Args:
  35. dict_like_obj (list): the list of pairs that contain key and value.
  36. """
  37. self.update_from_dict(dict_like_obj, self._dict)
  38. def update_dataset(self,
  39. dataset_path: str,
  40. dataset_type: str=None,
  41. *,
  42. data_fields: list[str]=None,
  43. image_dir: str="images",
  44. train_anno_path: str="annotations/instance_train.json",
  45. val_anno_path: str="annotations/instance_val.json",
  46. test_anno_path: str="annotations/instance_val.json"):
  47. """update dataset settings
  48. Args:
  49. dataset_path (str): the root path fo dataset.
  50. dataset_type (str, optional): the dataset type. Defaults to None.
  51. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  52. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  53. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  54. Defaults to "annotations/instance_train.json".
  55. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  56. Defaults to "annotations/instance_val.json".
  57. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  58. Defaults to "annotations/instance_val.json".
  59. Raises:
  60. ValueError: the `dataset_type` error.
  61. """
  62. dataset_path = abspath(dataset_path)
  63. if dataset_type is None:
  64. dataset_type = 'COCOInstSegDataset'
  65. if dataset_type == 'COCOInstSegDataset':
  66. ds_cfg = self._make_dataset_config(dataset_path, data_fields,
  67. image_dir, train_anno_path,
  68. val_anno_path, test_anno_path)
  69. self.set_val('metric', 'COCO')
  70. else:
  71. raise ValueError(f"{repr(dataset_type)} is not supported.")
  72. self.update(ds_cfg)
  73. def _make_dataset_config(
  74. self,
  75. dataset_root_path: str,
  76. data_fields: list[str, ]=None,
  77. image_dir: str="images",
  78. train_anno_path: str="annotations/instance_train.json",
  79. val_anno_path: str="annotations/instance_val.json",
  80. test_anno_path: str="annotations/instance_val.json") -> dict:
  81. """construct the dataset config that meets the format requirements
  82. Args:
  83. dataset_root_path (str): the root directory of dataset.
  84. data_fields (list[str,], optional): the data field. Defaults to None.
  85. image_dir (str, optional): _description_. Defaults to "images".
  86. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  87. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  88. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  89. Returns:
  90. dict: the dataset config.
  91. """
  92. data_fields = ['image', 'gt_bbox', 'gt_poly', 'gt_class',
  93. 'is_crowd'] if data_fields is None else data_fields
  94. return {
  95. 'TrainDataset': {
  96. 'name': 'COCOInstSegDataset',
  97. 'image_dir': image_dir,
  98. 'anno_path': train_anno_path,
  99. 'dataset_dir': dataset_root_path,
  100. 'data_fields': data_fields
  101. },
  102. 'EvalDataset': {
  103. 'name': 'COCOInstSegDataset',
  104. 'image_dir': image_dir,
  105. 'anno_path': val_anno_path,
  106. 'dataset_dir': dataset_root_path
  107. },
  108. 'TestDataset': {
  109. 'name': 'ImageFolder',
  110. 'anno_path': test_anno_path,
  111. 'dataset_dir': dataset_root_path
  112. },
  113. }
  114. def update_ema(self,
  115. use_ema: bool,
  116. ema_decay: float=0.9999,
  117. ema_decay_type: str="exponential",
  118. ema_filter_no_grad: bool=True):
  119. """update EMA setting
  120. Args:
  121. use_ema (bool): whether or not to use EMA
  122. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  123. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  124. ema_filter_no_grad (bool, optional): whether or not to filter the parameters that been set to stop gradient
  125. and are not batch norm parameters. Defaults to True.
  126. """
  127. self.update({
  128. 'use_ema': use_ema,
  129. 'ema_decay': ema_decay,
  130. 'ema_decay_type': ema_decay_type,
  131. 'ema_filter_no_grad': ema_filter_no_grad
  132. })
  133. def update_learning_rate(self, learning_rate: float):
  134. """update learning rate
  135. Args:
  136. learning_rate (float): the learning rate value to set.
  137. """
  138. self.LearningRate['base_lr'] = learning_rate
  139. def update_warmup_steps(self, warmup_steps: int):
  140. """update warmup steps
  141. Args:
  142. warmup_steps (int): the warmup steps value to set.
  143. """
  144. schedulers = self.LearningRate['schedulers']
  145. for sch in schedulers:
  146. key = 'name' if 'name' in sch else '_type_'
  147. if sch[key] == 'LinearWarmup':
  148. sch['steps'] = warmup_steps
  149. sch['epochs_first'] = False
  150. def update_warmup_enable(self, use_warmup: bool):
  151. """whether or not to enable learning rate warmup
  152. Args:
  153. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  154. """
  155. schedulers = self.LearningRate['schedulers']
  156. for sch in schedulers:
  157. if 'use_warmup' in sch:
  158. sch['use_warmup'] = use_warmup
  159. def update_cossch_epoch(self, max_epochs: int):
  160. """update max epochs of cosine learning rate scheduler
  161. Args:
  162. max_epochs (int): the max epochs value.
  163. """
  164. schedulers = self.LearningRate['schedulers']
  165. for sch in schedulers:
  166. key = 'name' if 'name' in sch else '_type_'
  167. if sch[key] == 'CosineDecay':
  168. sch['max_epochs'] = max_epochs
  169. def update_milestone(self, milestones: list[int]):
  170. """update milstone of `PiecewiseDecay` learning scheduler
  171. Args:
  172. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  173. """
  174. schedulers = self.LearningRate['schedulers']
  175. for sch in schedulers:
  176. key = 'name' if 'name' in sch else '_type_'
  177. if sch[key] == 'PiecewiseDecay':
  178. sch['milestones'] = milestones
  179. def update_batch_size(self, batch_size: int, mode: str='train'):
  180. """update batch size setting
  181. Args:
  182. batch_size (int): the batch size number to set.
  183. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  184. Defaults to 'train'.
  185. Raises:
  186. ValueError: mode error.
  187. """
  188. assert mode in ('train', 'eval', 'test'), \
  189. 'mode ({}) should be train, eval or test'.format(mode)
  190. if mode == 'train':
  191. self.TrainReader['batch_size'] = batch_size
  192. elif mode == 'eval':
  193. self.EvalReader['batch_size'] = batch_size
  194. else:
  195. self.TestReader['batch_size'] = batch_size
  196. def update_epochs(self, epochs: int):
  197. """update epochs setting
  198. Args:
  199. epochs (int): the epochs number value to set
  200. """
  201. self.update({'epoch': epochs})
  202. def update_device(self, device_type: str):
  203. """update device setting
  204. Args:
  205. device (str): the running device to set
  206. """
  207. if device_type.lower() == "gpu":
  208. self['use_gpu'] = True
  209. else:
  210. assert device_type.lower() == "cpu"
  211. self['use_gpu'] = False
  212. def update_save_dir(self, save_dir: str):
  213. """update directory to save outputs
  214. Args:
  215. save_dir (str): the directory to save outputs.
  216. """
  217. self['save_dir'] = abspath(save_dir)
  218. def update_log_interval(self, log_interval: int):
  219. """update log interval(steps)
  220. Args:
  221. log_interval (int): the log interval value to set.
  222. """
  223. self.update({'log_iter': log_interval})
  224. def update_eval_interval(self, eval_interval: int):
  225. """update eval interval(epochs)
  226. Args:
  227. eval_interval (int): the eval interval value to set.
  228. """
  229. self.update({'snapshot_epoch': eval_interval})
  230. def update_save_interval(self, save_interval: int):
  231. """update eval interval(epochs)
  232. Args:
  233. save_interval (int): the save interval value to set.
  234. """
  235. self.update({'snapshot_epoch': save_interval})
  236. def update_weights(self, weight_path: str):
  237. """update model weight
  238. Args:
  239. weight_path (str): the path to weight file of model.
  240. """
  241. self['weights'] = abspath(weight_path)
  242. def update_pretrained_weights(self, pretrain_weights: str):
  243. """update pretrained weight path
  244. Args:
  245. pretrained_model (str): the local path or url of pretrained weight file to set.
  246. """
  247. if not pretrain_weights.startswith(
  248. 'http://') and not pretrain_weights.startswith('https://'):
  249. pretrain_weights = abspath(pretrain_weights)
  250. self['pretrain_weights'] = pretrain_weights
  251. def update_num_class(self, num_classes: int):
  252. """update classes number
  253. Args:
  254. num_classes (int): the classes number value to set.
  255. """
  256. self['num_classes'] = num_classes
  257. def update_random_size(self, randomsize: list[list[int, int]]):
  258. """update `target_size` of `BatchRandomResize` op in TestReader
  259. Args:
  260. randomsize (list[list[int, int]]): the list of different size scales.
  261. """
  262. self.TestReader['batch_transforms']['BatchRandomResize'][
  263. 'target_size'] = randomsize
  264. def update_num_workers(self, num_workers: int):
  265. """update workers number of train and eval dataloader
  266. Args:
  267. num_workers (int): the value of train and eval dataloader workers number to set.
  268. """
  269. self['worker_num'] = num_workers
  270. def enable_shared_memory(self):
  271. """enable shared memory setting of train and eval dataloader
  272. """
  273. self.update({'TrainReader': {'use_shared_memory': True}})
  274. self.update({'EvalReader': {'use_shared_memory': True}})
  275. def disable_shared_memory(self):
  276. """disable shared memory setting of train and eval dataloader
  277. """
  278. self.update({'TrainReader': {'use_shared_memory': False}})
  279. self.update({'EvalReader': {'use_shared_memory': False}})
  280. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  281. """update static assigner epochs value
  282. Args:
  283. static_assigner_epochs (int): the value of static assigner epochs
  284. """
  285. assert 'PicoHeadV2' in self
  286. self.PicoHeadV2['static_assigner_epoch'] = static_assigner_epochs
  287. def get_epochs_iters(self) -> int:
  288. """get epochs
  289. Returns:
  290. int: the epochs value, i.e., `Global.epochs` in config.
  291. """
  292. return self.epoch
  293. def get_log_interval(self) -> int:
  294. """get log interval(steps)
  295. Returns:
  296. int: the log interval value, i.e., `Global.print_batch_step` in config.
  297. """
  298. self.log_iter
  299. def get_eval_interval(self) -> int:
  300. """get eval interval(epochs)
  301. Returns:
  302. int: the eval interval value, i.e., `Global.eval_interval` in config.
  303. """
  304. self.snapshot_epoch
  305. def get_save_interval(self) -> int:
  306. """get save interval(epochs)
  307. Returns:
  308. int: the save interval value, i.e., `Global.save_interval` in config.
  309. """
  310. self.snapshot_epoch
  311. def get_learning_rate(self) -> float:
  312. """get learning rate
  313. Returns:
  314. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  315. """
  316. return self.LearningRate['base_lr']
  317. def get_batch_size(self, mode='train') -> int:
  318. """get batch size
  319. Args:
  320. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  321. Defaults to 'train'.
  322. Returns:
  323. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  324. """
  325. if mode == 'train':
  326. return self.TrainReader['batch_size']
  327. elif mode == 'eval':
  328. return self.EvalReader['batch_size']
  329. elif mode == 'test':
  330. return self.TestReader['batch_size']
  331. else:
  332. raise (f"Unknown mode: {repr(mode)}")
  333. def get_qat_epochs_iters(self) -> int:
  334. """get qat epochs
  335. Returns:
  336. int: the epochs value.
  337. """
  338. return self.epoch // 2.0
  339. def get_qat_learning_rate(self) -> float:
  340. """get qat learning rate
  341. Returns:
  342. float: the learning rate value.
  343. """
  344. return self.LearningRate['base_lr'] // 2.0
  345. def get_train_save_dir(self) -> str:
  346. """get the directory to save output
  347. Returns:
  348. str: the directory to save output
  349. """
  350. return self.save_dir