config.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...base import BaseConfig
  15. from ....utils.misc import abspath
  16. from ....utils import logging
  17. from ..config_helper import PPDetConfigMixin
  18. class DetConfig(BaseConfig, PPDetConfigMixin):
  19. """DetConfig"""
  def load(self, config_path: str):
      """Read a config file and replace the current settings with its content.

      Args:
          config_path (str): the config file path.
      """
      # Parse the file first, then reset this config from the parsed mapping.
      self.reset_from_dict(self.load_config_literally(config_path))
  def dump(self, config_path: str):
      """Write the current settings to a file.

      Args:
          config_path (str): the path to save dumped config.
      """
      current_settings = self._dict
      self.dump_literal_config(config_path, current_settings)
  def update(self, dict_like_obj: list):
      """Merge the given settings into this config.

      Args:
          dict_like_obj (list): the key/value settings to merge. The callers
              in this class pass plain dicts.
      """
      target = self._dict
      self.update_from_dict(dict_like_obj, target)
  def update_dataset(
      self,
      dataset_path: str,
      dataset_type: str = None,
      *,
      data_fields: list[str] = None,
      image_dir: str = "images",
      train_anno_path: str = "annotations/instance_train.json",
      val_anno_path: str = "annotations/instance_val.json",
      test_anno_path: str = "annotations/instance_val.json",
      metric: str = "COCO",
  ):
      """Point the train/eval/test dataset settings at a dataset directory.

      Args:
          dataset_path (str): the root path of the dataset; it is passed
              through `abspath` before being stored.
          dataset_type (str, optional): "COCODetDataset" or
              "KeypointTopDownCocoDataset". `None` is treated as
              "COCODetDataset". Defaults to None.
          data_fields (list[str], optional): the data fields in dataset; only
              used for "COCODetDataset". Defaults to None.
          image_dir (str, optional): the images directory, relative to
              `dataset_path`. Defaults to "images".
          train_anno_path (str, optional): the train annotations file, relative
              to `dataset_path`. Defaults to "annotations/instance_train.json".
          val_anno_path (str, optional): the validation annotations file,
              relative to `dataset_path`.
              Defaults to "annotations/instance_val.json".
          test_anno_path (str, optional): the test annotations file, relative
              to `dataset_path`. Defaults to "annotations/instance_val.json".
          metric (str, optional): evaluation metric, stored under the "metric"
              key. Defaults to "COCO".

      Raises:
          ValueError: if `dataset_type` is not one of the supported types.
      """
      dataset_path = abspath(dataset_path)
      dataset_type = "COCODetDataset" if dataset_type is None else dataset_type

      def _keypoint_split(anno_path):
          # Train and eval splits share the same layout for keypoint datasets.
          return {
              "image_dir": image_dir,
              "anno_path": anno_path,
              "dataset_dir": dataset_path,
          }

      if dataset_type == "KeypointTopDownCocoDataset":
          ds_cfg = {
              "TrainDataset": _keypoint_split(train_anno_path),
              "EvalDataset": _keypoint_split(val_anno_path),
              # The test split only carries the annotation path.
              "TestDataset": {"anno_path": test_anno_path},
          }
      elif dataset_type == "COCODetDataset":
          ds_cfg = self._make_dataset_config(
              dataset_root_path=dataset_path,
              data_fields=data_fields,
              image_dir=image_dir,
              train_anno_path=train_anno_path,
              val_anno_path=val_anno_path,
              test_anno_path=test_anno_path,
          )
      else:
          raise ValueError(f"{repr(dataset_type)} is not supported.")

      self.update(ds_cfg)
      self.set_val("metric", metric)
  99. def _make_dataset_config(
  100. self,
  101. dataset_root_path: str,
  102. data_fields: list[str,] = None,
  103. image_dir: str = "images",
  104. train_anno_path: str = "annotations/instance_train.json",
  105. val_anno_path: str = "annotations/instance_val.json",
  106. test_anno_path: str = "annotations/instance_val.json",
  107. ) -> dict:
  108. """construct the dataset config that meets the format requirements
  109. Args:
  110. dataset_root_path (str): the root directory of dataset.
  111. data_fields (list[str,], optional): the data field. Defaults to None.
  112. image_dir (str, optional): _description_. Defaults to "images".
  113. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  114. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  115. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  116. Returns:
  117. dict: the dataset config.
  118. """
  119. data_fields = (
  120. ["image", "gt_bbox", "gt_class", "is_crowd"]
  121. if data_fields is None
  122. else data_fields
  123. )
  124. return {
  125. "TrainDataset": {
  126. "name": "COCODetDataset",
  127. "image_dir": image_dir,
  128. "anno_path": train_anno_path,
  129. "dataset_dir": dataset_root_path,
  130. "data_fields": data_fields,
  131. },
  132. "EvalDataset": {
  133. "name": "COCODetDataset",
  134. "image_dir": image_dir,
  135. "anno_path": val_anno_path,
  136. "dataset_dir": dataset_root_path,
  137. },
  138. "TestDataset": {
  139. "name": "ImageFolder",
  140. "anno_path": test_anno_path,
  141. "dataset_dir": dataset_root_path,
  142. },
  143. }
  def update_ema(
      self,
      use_ema: bool,
      ema_decay: float = 0.9999,
      ema_decay_type: str = "exponential",
      ema_filter_no_grad: bool = True,
  ):
      """Store the EMA (exponential moving average) settings.

      Args:
          use_ema (bool): whether or not to use EMA.
          ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
          ema_decay_type (str, optional): type of EMA decay.
              Defaults to "exponential".
          ema_filter_no_grad (bool, optional): whether or not to filter the
              parameters that have been set to stop gradient and are not batch
              norm parameters. Defaults to True.
      """
      ema_settings = dict(
          use_ema=use_ema,
          ema_decay=ema_decay,
          ema_decay_type=ema_decay_type,
          ema_filter_no_grad=ema_filter_no_grad,
      )
      self.update(ema_settings)
  def update_learning_rate(self, learning_rate: float):
      """Set the base learning rate.

      Args:
          learning_rate (float): the value stored as `LearningRate["base_lr"]`.
      """
      lr_settings = self.LearningRate
      lr_settings["base_lr"] = learning_rate
  def update_warmup_steps(self, warmup_steps: int):
      """Set the step count of every `LinearWarmup` scheduler.

      Args:
          warmup_steps (int): the warmup steps value to set.
      """
      for scheduler in self.LearningRate["schedulers"]:
          # A scheduler is identified by "name", or by "_type_" when "name"
          # is absent.
          scheduler_type = (
              scheduler["name"] if "name" in scheduler else scheduler["_type_"]
          )
          if scheduler_type != "LinearWarmup":
              continue
          scheduler["steps"] = warmup_steps
          # NOTE(review): presumably this makes the step count take precedence
          # over an epoch-based warmup; confirm against the scheduler.
          scheduler["epochs_first"] = False
  def update_warmup_enable(self, use_warmup: bool):
      """Switch learning rate warmup on or off.

      Only schedulers that already carry a "use_warmup" entry are changed.

      Args:
          use_warmup (bool): `True` enables learning rate warmup and `False`
              disables it.
      """
      for scheduler in self.LearningRate["schedulers"]:
          if "use_warmup" not in scheduler:
              continue
          scheduler["use_warmup"] = use_warmup
  def update_cossch_epoch(self, max_epochs: int):
      """Set `max_epochs` on every `CosineDecay` scheduler.

      Args:
          max_epochs (int): the max epochs value.
      """
      for scheduler in self.LearningRate["schedulers"]:
          # A scheduler is identified by "name", or by "_type_" when "name"
          # is absent.
          scheduler_type = (
              scheduler["name"] if "name" in scheduler else scheduler["_type_"]
          )
          if scheduler_type == "CosineDecay":
              scheduler["max_epochs"] = max_epochs
  def update_milestone(self, milestones: list[int]):
      """Set the milestones of every `PiecewiseDecay` scheduler.

      Args:
          milestones (list[int]): the list of milestone values of the
              `PiecewiseDecay` learning rate scheduler.
      """
      for scheduler in self.LearningRate["schedulers"]:
          # A scheduler is identified by "name", or by "_type_" when "name"
          # is absent.
          scheduler_type = (
              scheduler["name"] if "name" in scheduler else scheduler["_type_"]
          )
          if scheduler_type == "PiecewiseDecay":
              scheduler["milestones"] = milestones
  def update_batch_size(self, batch_size: int, mode: str = "train"):
      """Set the batch size of the train, eval or test reader.

      Args:
          batch_size (int): the batch size number to set.
          mode (str, optional): which reader to change; must be one of
              'train', 'eval', 'test'. Defaults to 'train'.

      Raises:
          ValueError: if `mode` is not one of 'train', 'eval', 'test'.
      """
      # Raise explicitly instead of asserting: asserts vanish under `python -O`,
      # which would let an unknown mode silently change the test reader.
      if mode not in ("train", "eval", "test"):
          raise ValueError("mode ({}) should be train, eval or test".format(mode))

      if mode == "train":
          reader = self.TrainReader
      elif mode == "eval":
          reader = self.EvalReader
      else:
          reader = self.TestReader
      reader["batch_size"] = batch_size
  def update_epochs(self, epochs: int):
      """Set the number of training epochs.

      Args:
          epochs (int): the value stored under the "epoch" key.
      """
      epoch_setting = {"epoch": epochs}
      self.update(epoch_setting)
  def update_device(self, device_type: str):
      """Set the device flags for the requested device type.

      "gpu" turns `use_gpu` on. "xpu", "npu", "mlu" and "gcu" turn their own
      `use_*` flag on and `use_gpu` off. "cpu" turns `use_gpu` off. Matching
      is case-insensitive.

      Args:
          device_type (str): the running device to set.
      """
      kind = device_type.lower()
      if kind == "gpu":
          self["use_gpu"] = True
          return

      accelerator_flags = {
          "xpu": "use_xpu",
          "npu": "use_npu",
          "mlu": "use_mlu",
          "gcu": "use_gcu",
      }
      if kind in accelerator_flags:
          self[accelerator_flags[kind]] = True
      else:
          # Anything that is not a known accelerator has to be the CPU.
          assert kind == "cpu"
      # Every non-GPU device switches the GPU flag off.
      self["use_gpu"] = False
  def update_save_dir(self, save_dir: str):
      """Set the directory that outputs are saved to.

      Args:
          save_dir (str): the directory to save outputs; it is passed through
              `abspath` before being stored under "save_dir".
      """
      resolved_dir = abspath(save_dir)
      self["save_dir"] = resolved_dir
  def update_log_interval(self, log_interval: int):
      """Set the logging interval, in steps.

      Args:
          log_interval (int): the value stored under the "log_iter" key.
      """
      interval_setting = {"log_iter": log_interval}
      self.update(interval_setting)
  def update_eval_interval(self, eval_interval: int):
      """Set the evaluation interval, in epochs.

      Args:
          eval_interval (int): the value stored under the "snapshot_epoch" key.
      """
      interval_setting = {"snapshot_epoch": eval_interval}
      self.update(interval_setting)
  def update_save_interval(self, save_interval: int):
      """Set the checkpoint save interval, in epochs.

      The value is written to "snapshot_epoch", the same key that
      `update_eval_interval` writes.

      Args:
          save_interval (int): the save interval value to set.
      """
      interval_setting = {"snapshot_epoch": save_interval}
      self.update(interval_setting)
  def update_log_ranks(self, device):
      """Set the log ranks from the id part of a device string.

      The ranks are the text between the first and second ":" of `device`
      (e.g. "0,1" for "gpu:0,1").

      Args:
          device (str): the running device, optionally followed by ":" and
              the device ids.
      """
      parts = device.split(":")
      if len(parts) < 2:
          # No id part (e.g. "cpu"): leave "log_ranks" untouched rather than
          # failing with an IndexError.
          return
      self.update({"log_ranks": parts[1]})
  def update_print_mem_info(self, print_mem_info: bool):
      """Set whether memory info is printed.

      Args:
          print_mem_info (bool): the flag; it is stored as its text form
              ("True" or "False") under the "print_mem_info" key.
      """
      assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
      # NOTE(review): the value is stored as a string, so "False" is still a
      # truthy value; confirm the consumer parses it.
      flag_text = str(print_mem_info)
      self.update({"print_mem_info": flag_text})
  def update_shared_memory(self, shared_memeory: bool):
      """Set the shared memory option of the train and eval dataloaders.

      Args:
          shared_memeory (bool): whether or not to use shared memory.
      """
      assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
      # The value used to be written to "print_mem_info" (a copy-paste slip),
      # which overwrote the memory-printing flag and never reached the readers.
      for reader in ("TrainReader", "EvalReader"):
          self.update({reader: {"use_shared_memory": shared_memeory}})
  def update_shuffle(self, shuffle: bool):
      """Set the shuffle option of the train and eval dataloaders.

      Args:
          shuffle (bool): whether or not to shuffle the data.
      """
      assert isinstance(shuffle, bool), "shuffle should be a bool"
      # Train reader first, then eval reader, one update call each.
      for reader in ("TrainReader", "EvalReader"):
          self.update({reader: {"shuffle": shuffle}})
  def update_weights(self, weight_path: str):
      """Set the model weight path.

      Args:
          weight_path (str): the path to the weight file of the model; stored
              unchanged under the "weights" key.
      """
      key, value = "weights", weight_path
      self[key] = value
  def update_pretrained_weights(self, pretrain_weights: str):
      """Set the pretrained weight location.

      Args:
          pretrain_weights (str): the local path or url of the pretrained
              weight file. Values starting with "http://" or "https://" are
              stored as given; anything else is passed through `abspath`.
      """
      is_remote = pretrain_weights.startswith(("http://", "https://"))
      self["pretrain_weights"] = (
          pretrain_weights if is_remote else abspath(pretrain_weights)
      )
  def update_num_class(self, num_classes: int):
      """Set the number of classes.

      When the model name contains "CenterNet", the `num_classes` of every
      `Gt2CenterNetTarget` train sample transform is updated as well.

      Args:
          num_classes (int): the classes number value to set.
      """
      self["num_classes"] = num_classes
      if "CenterNet" not in self.model_name:
          return

      for transform in self["TrainReader"]["sample_transforms"]:
          if "Gt2CenterNetTarget" in transform:
              transform["Gt2CenterNetTarget"]["num_classes"] = num_classes
  def update_random_size(self, randomsize: list[list[int, int]]):
      """Set `target_size` of the `BatchRandomResize` op in TestReader.

      Args:
          randomsize (list[list[int, int]]): the list of different size scales.
      """
      # NOTE(review): `batch_transforms` is indexed by op name here, so it is
      # assumed to be a mapping; confirm against the config layout.
      resize_op = self.TestReader["batch_transforms"]["BatchRandomResize"]
      resize_op["target_size"] = randomsize
  def update_num_workers(self, num_workers: int):
      """Set the dataloader worker count.

      Args:
          num_workers (int): the value stored under the "worker_num" key.
      """
      key, value = "worker_num", num_workers
      self[key] = value
  356. def _recursively_set(self, config: dict, update_dict: dict):
  357. """recursively set config
  358. Args:
  359. config (dict): the original config.
  360. update_dict (dict): to be updated paramenters and its values
  361. Example:
  362. self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
  363. """
  364. assert isinstance(update_dict, dict)
  365. for key in update_dict:
  366. if key not in config:
  367. logging.info(f"A new filed of config to set found: {repr(key)}.")
  368. config[key] = update_dict[key]
  369. elif not isinstance(update_dict[key], dict):
  370. config[key] = update_dict[key]
  371. else:
  372. self._recursively_set(config[key], update_dict[key])
  def update_static_assigner_epochs(self, static_assigner_epochs: int):
      """Set the static assigner epoch of the `PicoHeadV2` section.

      The config must already contain a "PicoHeadV2" section.

      Args:
          static_assigner_epochs (int): the value stored as
              `PicoHeadV2["static_assigner_epoch"]`.
      """
      assert "PicoHeadV2" in self
      head_settings = self.PicoHeadV2
      head_settings["static_assigner_epoch"] = static_assigner_epochs
  def update_HybridEncoder(self, update_dict: dict):
      """Merge settings into the `HybridEncoder` section.

      The config must already contain a "HybridEncoder" section; the merge is
      delegated to `_recursively_set`.

      Args:
          update_dict (dict): the HybridEncoder setting.
      """
      assert "HybridEncoder" in self
      encoder_settings = self.HybridEncoder
      self._recursively_set(encoder_settings, update_dict)
  def get_epochs_iters(self) -> int:
      """Return the number of training epochs.

      Returns:
          int: the epochs value, i.e., the `epoch` entry of the config.
      """
      epochs = self.epoch
      return epochs
  def get_log_interval(self) -> int:
      """Return the logging interval, in steps.

      Returns:
          int: the log interval value, i.e., the `log_iter` entry of the config.
      """
      # The value was read but never returned, so callers always got None.
      return self.log_iter
  def get_eval_interval(self) -> int:
      """Return the evaluation interval, in epochs.

      Returns:
          int: the eval interval value, i.e., the `snapshot_epoch` entry of
              the config.
      """
      # The value was read but never returned, so callers always got None.
      return self.snapshot_epoch
  def get_save_interval(self) -> int:
      """Return the checkpoint save interval, in epochs.

      Returns:
          int: the save interval value, i.e., the `snapshot_epoch` entry of
              the config.
      """
      # The value was read but never returned, so callers always got None.
      return self.snapshot_epoch
  def get_learning_rate(self) -> float:
      """Return the base learning rate.

      Returns:
          float: the learning rate value, i.e., `LearningRate["base_lr"]`.
      """
      lr_settings = self.LearningRate
      return lr_settings["base_lr"]
  def get_batch_size(self, mode="train") -> int:
      """Return the batch size of the train, eval or test reader.

      Args:
          mode (str, optional): which reader to query; must be one of
              'train', 'eval', 'test'. Defaults to 'train'.

      Returns:
          int: the `batch_size` entry of `TrainReader`, `EvalReader` or
              `TestReader`, depending on `mode`.

      Raises:
          ValueError: if `mode` is not one of 'train', 'eval', 'test'.
      """
      if mode == "train":
          return self.TrainReader["batch_size"]
      if mode == "eval":
          return self.EvalReader["batch_size"]
      if mode == "test":
          return self.TestReader["batch_size"]
      # Raising a bare string is itself a TypeError; raise a real exception so
      # the message reaches the caller.
      raise ValueError(f"Unknown mode: {repr(mode)}")
  def get_qat_epochs_iters(self) -> int:
      """Return the number of epochs used for QAT.

      Returns:
          int: half of the configured `epoch`, rounded down.
      """
      # Divide by an int: `// 2.0` produced a float (e.g. 5.0), which does not
      # match the declared return type and is not a valid epoch count.
      return self.epoch // 2
  def get_qat_learning_rate(self) -> float:
      """Return the learning rate used for QAT.

      Returns:
          float: half of `LearningRate["base_lr"]`.
      """
      # True division: floor division (`// 2.0`) turned every learning rate
      # below 2 into 0.0.
      return self.LearningRate["base_lr"] / 2.0
  def get_train_save_dir(self) -> str:
      """Return the directory that training outputs are saved to.

      Returns:
          str: the `save_dir` entry of the config.
      """
      output_dir = self.save_dir
      return output_dir