config.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List
  15. from ....utils import logging
  16. from ....utils.misc import abspath
  17. from ...base import BaseConfig
  18. from ..config_helper import PPDetConfigMixin
  19. class DetConfig(BaseConfig, PPDetConfigMixin):
  20. """DetConfig"""
  21. def load(self, config_path: str):
  22. """load the config from config file
  23. Args:
  24. config_path (str): the config file path.
  25. """
  26. dict_ = self.load_config_literally(config_path)
  27. self.reset_from_dict(dict_)
  28. def dump(self, config_path: str):
  29. """dump the config
  30. Args:
  31. config_path (str): the path to save dumped config.
  32. """
  33. self.dump_literal_config(config_path, self._dict)
  34. def update(self, dict_like_obj: list):
  35. """update self from dict
  36. Args:
  37. dict_like_obj (list): the list of pairs that contain key and value.
  38. """
  39. self.update_from_dict(dict_like_obj, self._dict)
  40. def update_dataset(
  41. self,
  42. dataset_path: str,
  43. dataset_type: str = None,
  44. *,
  45. data_fields: List[str] = None,
  46. image_dir: str = "images",
  47. train_anno_path: str = "annotations/instance_train.json",
  48. val_anno_path: str = "annotations/instance_val.json",
  49. test_anno_path: str = "annotations/instance_val.json",
  50. metric: str = "COCO",
  51. ):
  52. """update dataset settings
  53. Args:
  54. dataset_path (str): the root path fo dataset.
  55. dataset_type (str, optional): the dataset type. Defaults to None.
  56. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  57. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  58. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  59. Defaults to "annotations/instance_train.json".
  60. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  61. Defaults to "annotations/instance_val.json".
  62. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  63. Defaults to "annotations/instance_val.json".
  64. metric (str, optional): Evaluation metric. Defaults to "COCO".
  65. Raises:
  66. ValueError: the `dataset_type` error.
  67. """
  68. dataset_path = abspath(dataset_path)
  69. if dataset_type is None:
  70. dataset_type = "COCODetDataset"
  71. if dataset_type == "COCODetDataset":
  72. ds_cfg = self._make_dataset_config(
  73. dataset_path,
  74. data_fields,
  75. image_dir,
  76. train_anno_path,
  77. val_anno_path,
  78. test_anno_path,
  79. )
  80. elif dataset_type == "KeypointTopDownCocoDataset":
  81. ds_cfg = {
  82. "TrainDataset": {
  83. "image_dir": image_dir,
  84. "anno_path": train_anno_path,
  85. "dataset_dir": dataset_path,
  86. },
  87. "EvalDataset": {
  88. "image_dir": image_dir,
  89. "anno_path": val_anno_path,
  90. "dataset_dir": dataset_path,
  91. },
  92. "TestDataset": {
  93. "anno_path": test_anno_path,
  94. },
  95. }
  96. else:
  97. raise ValueError(f"{repr(dataset_type)} is not supported.")
  98. self.update(ds_cfg)
  99. self.set_val("metric", metric)
  100. def _make_dataset_config(
  101. self,
  102. dataset_root_path: str,
  103. data_fields: List[str,] = None,
  104. image_dir: str = "images",
  105. train_anno_path: str = "annotations/instance_train.json",
  106. val_anno_path: str = "annotations/instance_val.json",
  107. test_anno_path: str = "annotations/instance_val.json",
  108. ) -> dict:
  109. """construct the dataset config that meets the format requirements
  110. Args:
  111. dataset_root_path (str): the root directory of dataset.
  112. data_fields (list[str,], optional): the data field. Defaults to None.
  113. image_dir (str, optional): _description_. Defaults to "images".
  114. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  115. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  116. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  117. Returns:
  118. dict: the dataset config.
  119. """
  120. data_fields = (
  121. ["image", "gt_bbox", "gt_class", "is_crowd"]
  122. if data_fields is None
  123. else data_fields
  124. )
  125. return {
  126. "TrainDataset": {
  127. "name": "COCODetDataset",
  128. "image_dir": image_dir,
  129. "anno_path": train_anno_path,
  130. "dataset_dir": dataset_root_path,
  131. "data_fields": data_fields,
  132. },
  133. "EvalDataset": {
  134. "name": "COCODetDataset",
  135. "image_dir": image_dir,
  136. "anno_path": val_anno_path,
  137. "dataset_dir": dataset_root_path,
  138. },
  139. "TestDataset": {
  140. "name": "ImageFolder",
  141. "anno_path": test_anno_path,
  142. "dataset_dir": dataset_root_path,
  143. },
  144. }
  145. def update_ema(
  146. self,
  147. use_ema: bool,
  148. ema_decay: float = 0.9999,
  149. ema_decay_type: str = "exponential",
  150. ema_filter_no_grad: bool = True,
  151. ):
  152. """update EMA setting
  153. Args:
  154. use_ema (bool): whether or not to use EMA
  155. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  156. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  157. ema_filter_no_grad (bool, optional): whether or not to filter the parameters
  158. that been set to stop gradient and are not batch norm parameters. Defaults to True.
  159. """
  160. self.update(
  161. {
  162. "use_ema": use_ema,
  163. "ema_decay": ema_decay,
  164. "ema_decay_type": ema_decay_type,
  165. "ema_filter_no_grad": ema_filter_no_grad,
  166. }
  167. )
  168. def update_learning_rate(self, learning_rate: float):
  169. """update learning rate
  170. Args:
  171. learning_rate (float): the learning rate value to set.
  172. """
  173. self.LearningRate["base_lr"] = learning_rate
  174. def update_warmup_steps(self, warmup_steps: int):
  175. """update warmup steps
  176. Args:
  177. warmup_steps (int): the warmup steps value to set.
  178. """
  179. schedulers = self.LearningRate["schedulers"]
  180. for sch in schedulers:
  181. key = "name" if "name" in sch else "_type_"
  182. if sch[key] == "LinearWarmup":
  183. sch["steps"] = warmup_steps
  184. sch["epochs_first"] = False
  185. def update_warmup_enable(self, use_warmup: bool):
  186. """whether or not to enable learning rate warmup
  187. Args:
  188. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  189. """
  190. schedulers = self.LearningRate["schedulers"]
  191. for sch in schedulers:
  192. if "use_warmup" in sch:
  193. sch["use_warmup"] = use_warmup
  194. def update_cossch_epoch(self, max_epochs: int):
  195. """update max epochs of cosine learning rate scheduler
  196. Args:
  197. max_epochs (int): the max epochs value.
  198. """
  199. schedulers = self.LearningRate["schedulers"]
  200. for sch in schedulers:
  201. key = "name" if "name" in sch else "_type_"
  202. if sch[key] == "CosineDecay":
  203. sch["max_epochs"] = max_epochs
  204. def update_milestone(self, milestones: List[int]):
  205. """update milstone of `PiecewiseDecay` learning scheduler
  206. Args:
  207. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  208. """
  209. schedulers = self.LearningRate["schedulers"]
  210. for sch in schedulers:
  211. key = "name" if "name" in sch else "_type_"
  212. if sch[key] == "PiecewiseDecay":
  213. sch["milestones"] = milestones
  214. def update_batch_size(self, batch_size: int, mode: str = "train"):
  215. """update batch size setting
  216. Args:
  217. batch_size (int): the batch size number to set.
  218. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  219. Defaults to 'train'.
  220. Raises:
  221. ValueError: mode error.
  222. """
  223. assert mode in (
  224. "train",
  225. "eval",
  226. "test",
  227. ), "mode ({}) should be train, eval or test".format(mode)
  228. if mode == "train":
  229. self.TrainReader["batch_size"] = batch_size
  230. elif mode == "eval":
  231. self.EvalReader["batch_size"] = batch_size
  232. else:
  233. self.TestReader["batch_size"] = batch_size
  234. def update_epochs(self, epochs: int):
  235. """update epochs setting
  236. Args:
  237. epochs (int): the epochs number value to set
  238. """
  239. self.update({"epoch": epochs})
  240. def update_device(self, device_type: str):
  241. """update device setting
  242. Args:
  243. device (str): the running device to set
  244. """
  245. if device_type.lower() == "gpu":
  246. self["use_gpu"] = True
  247. elif device_type.lower() == "xpu":
  248. self["use_xpu"] = True
  249. self["use_gpu"] = False
  250. elif device_type.lower() == "npu":
  251. self["use_npu"] = True
  252. self["use_gpu"] = False
  253. elif device_type.lower() == "mlu":
  254. self["use_mlu"] = True
  255. self["use_gpu"] = False
  256. elif device_type.lower() == "gcu":
  257. self["use_gcu"] = True
  258. self["use_gpu"] = False
  259. else:
  260. assert device_type.lower() == "cpu"
  261. self["use_gpu"] = False
  262. def update_save_dir(self, save_dir: str):
  263. """update directory to save outputs
  264. Args:
  265. save_dir (str): the directory to save outputs.
  266. """
  267. self["save_dir"] = abspath(save_dir)
  268. def update_log_interval(self, log_interval: int):
  269. """update log interval(steps)
  270. Args:
  271. log_interval (int): the log interval value to set.
  272. """
  273. self.update({"log_iter": log_interval})
  274. def update_eval_interval(self, eval_interval: int):
  275. """update eval interval(epochs)
  276. Args:
  277. eval_interval (int): the eval interval value to set.
  278. """
  279. self.update({"snapshot_epoch": eval_interval})
  280. def update_save_interval(self, save_interval: int):
  281. """update eval interval(epochs)
  282. Args:
  283. save_interval (int): the save interval value to set.
  284. """
  285. self.update({"snapshot_epoch": save_interval})
  286. def update_log_ranks(self, device):
  287. """update log ranks
  288. Args:
  289. device (str): the running device to set
  290. """
  291. log_ranks = device.split(":")[1]
  292. self.update({"log_ranks": log_ranks})
  293. def update_print_mem_info(self, print_mem_info: bool):
  294. """setting print memory info"""
  295. assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
  296. self.update({"print_mem_info": f"{print_mem_info}"})
  297. def update_shared_memory(self, shared_memeory: bool):
  298. """update shared memory setting of train and eval dataloader
  299. Args:
  300. shared_memeory (bool): whether or not to use shared memory
  301. """
  302. assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
  303. self.update({"print_mem_info": f"{shared_memeory}"})
  304. def update_shuffle(self, shuffle: bool):
  305. """update shuffle setting of train and eval dataloader
  306. Args:
  307. shuffle (bool): whether or not to shuffle the data
  308. """
  309. assert isinstance(shuffle, bool), "shuffle should be a bool"
  310. self.update({"TrainReader": {"shuffle": shuffle}})
  311. self.update({"EvalReader": {"shuffle": shuffle}})
  312. def update_weights(self, weight_path: str):
  313. """update model weight
  314. Args:
  315. weight_path (str): the path to weight file of model.
  316. """
  317. self["weights"] = weight_path
  318. def update_pretrained_weights(self, pretrain_weights: str):
  319. """update pretrained weight path
  320. Args:
  321. pretrained_model (str): the local path or url of pretrained weight file to set.
  322. """
  323. if not pretrain_weights.startswith(
  324. "http://"
  325. ) and not pretrain_weights.startswith("https://"):
  326. pretrain_weights = abspath(pretrain_weights)
  327. self["pretrain_weights"] = pretrain_weights
  328. def update_num_class(self, num_classes: int):
  329. """update classes number
  330. Args:
  331. num_classes (int): the classes number value to set.
  332. """
  333. self["num_classes"] = num_classes
  334. if "CenterNet" in self.model_name:
  335. for i in range(len(self["TrainReader"]["sample_transforms"])):
  336. if (
  337. "Gt2CenterNetTarget"
  338. in self["TrainReader"]["sample_transforms"][i].keys()
  339. ):
  340. self["TrainReader"]["sample_transforms"][i]["Gt2CenterNetTarget"][
  341. "num_classes"
  342. ] = num_classes
  343. def update_random_size(self, randomsize):
  344. """update `target_size` of `BatchRandomResize` op in TestReader
  345. Args:
  346. randomsize (list[list[int, int]]): the list of different size scales.
  347. """
  348. self.TestReader["batch_transforms"]["BatchRandomResize"][
  349. "target_size"
  350. ] = randomsize
  351. def update_num_workers(self, num_workers: int):
  352. """update workers number of train and eval dataloader
  353. Args:
  354. num_workers (int): the value of train and eval dataloader workers number to set.
  355. """
  356. self["worker_num"] = num_workers
  357. def _recursively_set(self, config: dict, update_dict: dict):
  358. """recursively set config
  359. Args:
  360. config (dict): the original config.
  361. update_dict (dict): to be updated parameters and its values
  362. Example:
  363. self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
  364. """
  365. assert isinstance(update_dict, dict)
  366. for key in update_dict:
  367. if key not in config:
  368. logging.info(f"A new filed of config to set found: {repr(key)}.")
  369. config[key] = update_dict[key]
  370. elif not isinstance(update_dict[key], dict):
  371. config[key] = update_dict[key]
  372. else:
  373. self._recursively_set(config[key], update_dict[key])
  374. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  375. """update static assigner epochs value
  376. Args:
  377. static_assigner_epochs (int): the value of static assigner epochs
  378. """
  379. assert "PicoHeadV2" in self
  380. self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
  381. def update_HybridEncoder(self, update_dict: dict):
  382. """update the HybridEncoder neck setting
  383. Args:
  384. update_dict (dict): the HybridEncoder setting.
  385. """
  386. assert "HybridEncoder" in self
  387. self._recursively_set(self.HybridEncoder, update_dict)
  388. def get_epochs_iters(self) -> int:
  389. """get epochs
  390. Returns:
  391. int: the epochs value, i.e., `Global.epochs` in config.
  392. """
  393. return self.epoch
  394. def get_log_interval(self) -> int:
  395. """get log interval(steps)
  396. Returns:
  397. int: the log interval value, i.e., `Global.print_batch_step` in config.
  398. """
  399. self.log_iter
  400. def get_eval_interval(self) -> int:
  401. """get eval interval(epochs)
  402. Returns:
  403. int: the eval interval value, i.e., `Global.eval_interval` in config.
  404. """
  405. self.snapshot_epoch
  406. def get_save_interval(self) -> int:
  407. """get save interval(epochs)
  408. Returns:
  409. int: the save interval value, i.e., `Global.save_interval` in config.
  410. """
  411. self.snapshot_epoch
  412. def get_learning_rate(self) -> float:
  413. """get learning rate
  414. Returns:
  415. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  416. """
  417. return self.LearningRate["base_lr"]
  418. def get_batch_size(self, mode="train") -> int:
  419. """get batch size
  420. Args:
  421. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  422. Defaults to 'train'.
  423. Returns:
  424. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  425. """
  426. if mode == "train":
  427. return self.TrainReader["batch_size"]
  428. elif mode == "eval":
  429. return self.EvalReader["batch_size"]
  430. elif mode == "test":
  431. return self.TestReader["batch_size"]
  432. else:
  433. raise (f"Unknown mode: {repr(mode)}")
  434. def get_qat_epochs_iters(self) -> int:
  435. """get qat epochs
  436. Returns:
  437. int: the epochs value.
  438. """
  439. return self.epoch // 2.0
  440. def get_qat_learning_rate(self) -> float:
  441. """get qat learning rate
  442. Returns:
  443. float: the learning rate value.
  444. """
  445. return self.LearningRate["base_lr"] // 2.0
  446. def get_train_save_dir(self) -> str:
  447. """get the directory to save output
  448. Returns:
  449. str: the directory to save output
  450. """
  451. return self.save_dir