config.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...base import BaseConfig
  15. from ....utils.misc import abspath
  16. from ....utils import logging
  17. from ..config_helper import PPDetConfigMixin
  18. class DetConfig(BaseConfig, PPDetConfigMixin):
  19. """ DetConfig """
  20. def load(self, config_path: str):
  21. """load the config from config file
  22. Args:
  23. config_path (str): the config file path.
  24. """
  25. dict_ = self.load_config_literally(config_path)
  26. self.reset_from_dict(dict_)
  27. def dump(self, config_path: str):
  28. """dump the config
  29. Args:
  30. config_path (str): the path to save dumped config.
  31. """
  32. self.dump_literal_config(config_path, self._dict)
  33. def update(self, dict_like_obj: list):
  34. """update self from dict
  35. Args:
  36. dict_like_obj (list): the list of pairs that contain key and value.
  37. """
  38. self.update_from_dict(dict_like_obj, self._dict)
  39. def update_dataset(self,
  40. dataset_path: str,
  41. dataset_type: str=None,
  42. *,
  43. data_fields: list[str]=None,
  44. image_dir: str="images",
  45. train_anno_path: str="annotations/instance_train.json",
  46. val_anno_path: str="annotations/instance_val.json",
  47. test_anno_path: str="annotations/instance_val.json"):
  48. """update dataset settings
  49. Args:
  50. dataset_path (str): the root path fo dataset.
  51. dataset_type (str, optional): the dataset type. Defaults to None.
  52. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  53. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  54. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  55. Defaults to "annotations/instance_train.json".
  56. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  57. Defaults to "annotations/instance_val.json".
  58. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  59. Defaults to "annotations/instance_val.json".
  60. Raises:
  61. ValueError: the `dataset_type` error.
  62. """
  63. dataset_path = abspath(dataset_path)
  64. if dataset_type is None:
  65. dataset_type = 'COCODetDataset'
  66. if dataset_type == 'COCODetDataset':
  67. ds_cfg = self._make_dataset_config(dataset_path, data_fields,
  68. image_dir, train_anno_path,
  69. val_anno_path, test_anno_path)
  70. self.set_val('metric', 'COCO')
  71. else:
  72. raise ValueError(f"{repr(dataset_type)} is not supported.")
  73. self.update(ds_cfg)
  74. def _make_dataset_config(
  75. self,
  76. dataset_root_path: str,
  77. data_fields: list[str, ]=None,
  78. image_dir: str="images",
  79. train_anno_path: str="annotations/instance_train.json",
  80. val_anno_path: str="annotations/instance_val.json",
  81. test_anno_path: str="annotations/instance_val.json") -> dict:
  82. """construct the dataset config that meets the format requirements
  83. Args:
  84. dataset_root_path (str): the root directory of dataset.
  85. data_fields (list[str,], optional): the data field. Defaults to None.
  86. image_dir (str, optional): _description_. Defaults to "images".
  87. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  88. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  89. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  90. Returns:
  91. dict: the dataset config.
  92. """
  93. data_fields = ['image', 'gt_bbox', 'gt_class',
  94. 'is_crowd'] if data_fields is None else data_fields
  95. return {
  96. 'TrainDataset': {
  97. 'name': 'COCODetDataset',
  98. 'image_dir': image_dir,
  99. 'anno_path': train_anno_path,
  100. 'dataset_dir': dataset_root_path,
  101. 'data_fields': data_fields
  102. },
  103. 'EvalDataset': {
  104. 'name': 'COCODetDataset',
  105. 'image_dir': image_dir,
  106. 'anno_path': val_anno_path,
  107. 'dataset_dir': dataset_root_path
  108. },
  109. 'TestDataset': {
  110. 'name': 'ImageFolder',
  111. 'anno_path': test_anno_path,
  112. 'dataset_dir': dataset_root_path
  113. },
  114. }
  115. def update_ema(self,
  116. use_ema: bool,
  117. ema_decay: float=0.9999,
  118. ema_decay_type: str="exponential",
  119. ema_filter_no_grad: bool=True):
  120. """update EMA setting
  121. Args:
  122. use_ema (bool): whether or not to use EMA
  123. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  124. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  125. ema_filter_no_grad (bool, optional): whether or not to filter the parameters
  126. that been set to stop gradient and are not batch norm parameters. Defaults to True.
  127. """
  128. self.update({
  129. 'use_ema': use_ema,
  130. 'ema_decay': ema_decay,
  131. 'ema_decay_type': ema_decay_type,
  132. 'ema_filter_no_grad': ema_filter_no_grad
  133. })
  134. def update_learning_rate(self, learning_rate: float):
  135. """update learning rate
  136. Args:
  137. learning_rate (float): the learning rate value to set.
  138. """
  139. self.LearningRate['base_lr'] = learning_rate
  140. def update_warmup_steps(self, warmup_steps: int):
  141. """update warmup steps
  142. Args:
  143. warmup_steps (int): the warmup steps value to set.
  144. """
  145. schedulers = self.LearningRate['schedulers']
  146. for sch in schedulers:
  147. key = 'name' if 'name' in sch else '_type_'
  148. if sch[key] == 'LinearWarmup':
  149. sch['steps'] = warmup_steps
  150. sch['epochs_first'] = False
  151. def update_warmup_enable(self, use_warmup: bool):
  152. """whether or not to enable learning rate warmup
  153. Args:
  154. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  155. """
  156. schedulers = self.LearningRate['schedulers']
  157. for sch in schedulers:
  158. if 'use_warmup' in sch:
  159. sch['use_warmup'] = use_warmup
  160. def update_cossch_epoch(self, max_epochs: int):
  161. """update max epochs of cosine learning rate scheduler
  162. Args:
  163. max_epochs (int): the max epochs value.
  164. """
  165. schedulers = self.LearningRate['schedulers']
  166. for sch in schedulers:
  167. key = 'name' if 'name' in sch else '_type_'
  168. if sch[key] == 'CosineDecay':
  169. sch['max_epochs'] = max_epochs
  170. def update_milestone(self, milestones: list[int]):
  171. """update milstone of `PiecewiseDecay` learning scheduler
  172. Args:
  173. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  174. """
  175. schedulers = self.LearningRate['schedulers']
  176. for sch in schedulers:
  177. key = 'name' if 'name' in sch else '_type_'
  178. if sch[key] == 'PiecewiseDecay':
  179. sch['milestones'] = milestones
  180. def update_batch_size(self, batch_size: int, mode: str='train'):
  181. """update batch size setting
  182. Args:
  183. batch_size (int): the batch size number to set.
  184. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  185. Defaults to 'train'.
  186. Raises:
  187. ValueError: mode error.
  188. """
  189. assert mode in ('train', 'eval', 'test'), \
  190. 'mode ({}) should be train, eval or test'.format(mode)
  191. if mode == 'train':
  192. self.TrainReader['batch_size'] = batch_size
  193. elif mode == 'eval':
  194. self.EvalReader['batch_size'] = batch_size
  195. else:
  196. self.TestReader['batch_size'] = batch_size
  197. def update_epochs(self, epochs: int):
  198. """update epochs setting
  199. Args:
  200. epochs (int): the epochs number value to set
  201. """
  202. self.update({'epoch': epochs})
  203. def update_device(self, device_type: str):
  204. """update device setting
  205. Args:
  206. device (str): the running device to set
  207. """
  208. if device_type.lower() == "gpu":
  209. self['use_gpu'] = True
  210. elif device_type.lower() == "xpu":
  211. self['use_xpu'] = True
  212. self['use_gpu'] = False
  213. elif device_type.lower() == "npu":
  214. self['use_npu'] = True
  215. self['use_gpu'] = False
  216. elif device_type.lower() == "mlu":
  217. self['use_mlu'] = True
  218. self['use_gpu'] = False
  219. else:
  220. assert device_type.lower() == "cpu"
  221. self['use_gpu'] = False
  222. def update_save_dir(self, save_dir: str):
  223. """update directory to save outputs
  224. Args:
  225. save_dir (str): the directory to save outputs.
  226. """
  227. self['save_dir'] = abspath(save_dir)
  228. def update_log_interval(self, log_interval: int):
  229. """update log interval(steps)
  230. Args:
  231. log_interval (int): the log interval value to set.
  232. """
  233. self.update({'log_iter': log_interval})
  234. def update_eval_interval(self, eval_interval: int):
  235. """update eval interval(epochs)
  236. Args:
  237. eval_interval (int): the eval interval value to set.
  238. """
  239. self.update({'snapshot_epoch': eval_interval})
  240. def update_save_interval(self, save_interval: int):
  241. """update eval interval(epochs)
  242. Args:
  243. save_interval (int): the save interval value to set.
  244. """
  245. self.update({'snapshot_epoch': save_interval})
  246. def update_log_ranks(self, device):
  247. """update log ranks
  248. Args:
  249. device (str): the running device to set
  250. """
  251. log_ranks = device.split(':')[1]
  252. self.update({'log_ranks': log_ranks})
  253. def update_print_mem_info(self, print_mem_info: bool):
  254. """setting print memory info"""
  255. assert isinstance(print_mem_info,
  256. bool), "print_mem_info should be a bool"
  257. self.update({'print_mem_info': f'{print_mem_info}'})
  258. def update_shared_memory(self, shared_memeory: bool):
  259. """update shared memory setting of train and eval dataloader
  260. Args:
  261. shared_memeory (bool): whether or not to use shared memory
  262. """
  263. assert isinstance(shared_memeory,
  264. bool), "shared_memeory should be a bool"
  265. self.update({'print_mem_info': f'{shared_memeory}'})
  266. def update_shuffle(self, shuffle: bool):
  267. """update shuffle setting of train and eval dataloader
  268. Args:
  269. shuffle (bool): whether or not to shuffle the data
  270. """
  271. assert isinstance(shuffle, bool), "shuffle should be a bool"
  272. self.update({'TrainReader': {'shuffle': shuffle}})
  273. self.update({'EvalReader': {'shuffle': shuffle}})
  274. def update_weights(self, weight_path: str):
  275. """update model weight
  276. Args:
  277. weight_path (str): the path to weight file of model.
  278. """
  279. self['weights'] = abspath(weight_path)
  280. def update_pretrained_weights(self, pretrain_weights: str):
  281. """update pretrained weight path
  282. Args:
  283. pretrained_model (str): the local path or url of pretrained weight file to set.
  284. """
  285. if not pretrain_weights.startswith(
  286. 'http://') and not pretrain_weights.startswith('https://'):
  287. pretrain_weights = abspath(pretrain_weights)
  288. self['pretrain_weights'] = pretrain_weights
  289. def update_num_class(self, num_classes: int):
  290. """update classes number
  291. Args:
  292. num_classes (int): the classes number value to set.
  293. """
  294. self['num_classes'] = num_classes
  295. def update_random_size(self, randomsize: list[list[int, int]]):
  296. """update `target_size` of `BatchRandomResize` op in TestReader
  297. Args:
  298. randomsize (list[list[int, int]]): the list of different size scales.
  299. """
  300. self.TestReader['batch_transforms']['BatchRandomResize'][
  301. 'target_size'] = randomsize
  302. def update_num_workers(self, num_workers: int):
  303. """update workers number of train and eval dataloader
  304. Args:
  305. num_workers (int): the value of train and eval dataloader workers number to set.
  306. """
  307. self['worker_num'] = num_workers
  308. def _recursively_set(self, config: dict, update_dict: dict):
  309. """recursively set config
  310. Args:
  311. config (dict): the original config.
  312. update_dict (dict): to be updated paramenters and its values
  313. Example:
  314. self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
  315. """
  316. assert isinstance(update_dict, dict)
  317. for key in update_dict:
  318. if key not in config:
  319. logging.info(
  320. f'A new filed of config to set found: {repr(key)}.')
  321. config[key] = update_dict[key]
  322. elif not isinstance(update_dict[key], dict):
  323. config[key] = update_dict[key]
  324. else:
  325. self._recursively_set(config[key], update_dict[key])
  326. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  327. """update static assigner epochs value
  328. Args:
  329. static_assigner_epochs (int): the value of static assigner epochs
  330. """
  331. assert 'PicoHeadV2' in self
  332. self.PicoHeadV2['static_assigner_epoch'] = static_assigner_epochs
  333. def update_HybridEncoder(self, update_dict: dict):
  334. """update the HybridEncoder neck setting
  335. Args:
  336. update_dict (dict): the HybridEncoder setting.
  337. """
  338. assert 'HybridEncoder' in self
  339. self._recursively_set(self.HybridEncoder, update_dict)
  340. def get_epochs_iters(self) -> int:
  341. """get epochs
  342. Returns:
  343. int: the epochs value, i.e., `Global.epochs` in config.
  344. """
  345. return self.epoch
  346. def get_log_interval(self) -> int:
  347. """get log interval(steps)
  348. Returns:
  349. int: the log interval value, i.e., `Global.print_batch_step` in config.
  350. """
  351. self.log_iter
  352. def get_eval_interval(self) -> int:
  353. """get eval interval(epochs)
  354. Returns:
  355. int: the eval interval value, i.e., `Global.eval_interval` in config.
  356. """
  357. self.snapshot_epoch
  358. def get_save_interval(self) -> int:
  359. """get save interval(epochs)
  360. Returns:
  361. int: the save interval value, i.e., `Global.save_interval` in config.
  362. """
  363. self.snapshot_epoch
  364. def get_learning_rate(self) -> float:
  365. """get learning rate
  366. Returns:
  367. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  368. """
  369. return self.LearningRate['base_lr']
  370. def get_batch_size(self, mode='train') -> int:
  371. """get batch size
  372. Args:
  373. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  374. Defaults to 'train'.
  375. Returns:
  376. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  377. """
  378. if mode == 'train':
  379. return self.TrainReader['batch_size']
  380. elif mode == 'eval':
  381. return self.EvalReader['batch_size']
  382. elif mode == 'test':
  383. return self.TestReader['batch_size']
  384. else:
  385. raise (f"Unknown mode: {repr(mode)}")
  386. def get_qat_epochs_iters(self) -> int:
  387. """get qat epochs
  388. Returns:
  389. int: the epochs value.
  390. """
  391. return self.epoch // 2.0
  392. def get_qat_learning_rate(self) -> float:
  393. """get qat learning rate
  394. Returns:
  395. float: the learning rate value.
  396. """
  397. return self.LearningRate['base_lr'] // 2.0
  398. def get_train_save_dir(self) -> str:
  399. """get the directory to save output
  400. Returns:
  401. str: the directory to save output
  402. """
  403. return self.save_dir