config.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...base import BaseConfig
  15. from ....utils.misc import abspath
  16. from ....utils import logging
  17. from ..config_helper import PPDetConfigMixin
  18. class DetConfig(BaseConfig, PPDetConfigMixin):
  19. """DetConfig"""
  20. def load(self, config_path: str):
  21. """load the config from config file
  22. Args:
  23. config_path (str): the config file path.
  24. """
  25. dict_ = self.load_config_literally(config_path)
  26. self.reset_from_dict(dict_)
  27. def dump(self, config_path: str):
  28. """dump the config
  29. Args:
  30. config_path (str): the path to save dumped config.
  31. """
  32. self.dump_literal_config(config_path, self._dict)
  33. def update(self, dict_like_obj: list):
  34. """update self from dict
  35. Args:
  36. dict_like_obj (list): the list of pairs that contain key and value.
  37. """
  38. self.update_from_dict(dict_like_obj, self._dict)
  39. def update_dataset(
  40. self,
  41. dataset_path: str,
  42. dataset_type: str = None,
  43. *,
  44. data_fields: list[str] = None,
  45. image_dir: str = "images",
  46. train_anno_path: str = "annotations/instance_train.json",
  47. val_anno_path: str = "annotations/instance_val.json",
  48. test_anno_path: str = "annotations/instance_val.json",
  49. metric: str = "COCO",
  50. ):
  51. """update dataset settings
  52. Args:
  53. dataset_path (str): the root path fo dataset.
  54. dataset_type (str, optional): the dataset type. Defaults to None.
  55. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  56. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  57. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  58. Defaults to "annotations/instance_train.json".
  59. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  60. Defaults to "annotations/instance_val.json".
  61. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  62. Defaults to "annotations/instance_val.json".
  63. metric (str, optional): Evaluation metric. Defaults to "COCO".
  64. Raises:
  65. ValueError: the `dataset_type` error.
  66. """
  67. dataset_path = abspath(dataset_path)
  68. if dataset_type is None:
  69. dataset_type = "COCODetDataset"
  70. if dataset_type == "COCODetDataset":
  71. ds_cfg = self._make_dataset_config(
  72. dataset_path,
  73. data_fields,
  74. image_dir,
  75. train_anno_path,
  76. val_anno_path,
  77. test_anno_path,
  78. )
  79. else:
  80. raise ValueError(f"{repr(dataset_type)} is not supported.")
  81. self.update(ds_cfg)
  82. self.set_val("metric", metric)
  83. def _make_dataset_config(
  84. self,
  85. dataset_root_path: str,
  86. data_fields: list[str,] = None,
  87. image_dir: str = "images",
  88. train_anno_path: str = "annotations/instance_train.json",
  89. val_anno_path: str = "annotations/instance_val.json",
  90. test_anno_path: str = "annotations/instance_val.json",
  91. ) -> dict:
  92. """construct the dataset config that meets the format requirements
  93. Args:
  94. dataset_root_path (str): the root directory of dataset.
  95. data_fields (list[str,], optional): the data field. Defaults to None.
  96. image_dir (str, optional): _description_. Defaults to "images".
  97. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  98. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  99. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  100. Returns:
  101. dict: the dataset config.
  102. """
  103. data_fields = (
  104. ["image", "gt_bbox", "gt_class", "is_crowd"]
  105. if data_fields is None
  106. else data_fields
  107. )
  108. return {
  109. "TrainDataset": {
  110. "name": "COCODetDataset",
  111. "image_dir": image_dir,
  112. "anno_path": train_anno_path,
  113. "dataset_dir": dataset_root_path,
  114. "data_fields": data_fields,
  115. },
  116. "EvalDataset": {
  117. "name": "COCODetDataset",
  118. "image_dir": image_dir,
  119. "anno_path": val_anno_path,
  120. "dataset_dir": dataset_root_path,
  121. },
  122. "TestDataset": {
  123. "name": "ImageFolder",
  124. "anno_path": test_anno_path,
  125. "dataset_dir": dataset_root_path,
  126. },
  127. }
  128. def update_ema(
  129. self,
  130. use_ema: bool,
  131. ema_decay: float = 0.9999,
  132. ema_decay_type: str = "exponential",
  133. ema_filter_no_grad: bool = True,
  134. ):
  135. """update EMA setting
  136. Args:
  137. use_ema (bool): whether or not to use EMA
  138. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  139. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  140. ema_filter_no_grad (bool, optional): whether or not to filter the parameters
  141. that been set to stop gradient and are not batch norm parameters. Defaults to True.
  142. """
  143. self.update(
  144. {
  145. "use_ema": use_ema,
  146. "ema_decay": ema_decay,
  147. "ema_decay_type": ema_decay_type,
  148. "ema_filter_no_grad": ema_filter_no_grad,
  149. }
  150. )
  151. def update_learning_rate(self, learning_rate: float):
  152. """update learning rate
  153. Args:
  154. learning_rate (float): the learning rate value to set.
  155. """
  156. self.LearningRate["base_lr"] = learning_rate
  157. def update_warmup_steps(self, warmup_steps: int):
  158. """update warmup steps
  159. Args:
  160. warmup_steps (int): the warmup steps value to set.
  161. """
  162. schedulers = self.LearningRate["schedulers"]
  163. for sch in schedulers:
  164. key = "name" if "name" in sch else "_type_"
  165. if sch[key] == "LinearWarmup":
  166. sch["steps"] = warmup_steps
  167. sch["epochs_first"] = False
  168. def update_warmup_enable(self, use_warmup: bool):
  169. """whether or not to enable learning rate warmup
  170. Args:
  171. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  172. """
  173. schedulers = self.LearningRate["schedulers"]
  174. for sch in schedulers:
  175. if "use_warmup" in sch:
  176. sch["use_warmup"] = use_warmup
  177. def update_cossch_epoch(self, max_epochs: int):
  178. """update max epochs of cosine learning rate scheduler
  179. Args:
  180. max_epochs (int): the max epochs value.
  181. """
  182. schedulers = self.LearningRate["schedulers"]
  183. for sch in schedulers:
  184. key = "name" if "name" in sch else "_type_"
  185. if sch[key] == "CosineDecay":
  186. sch["max_epochs"] = max_epochs
  187. def update_milestone(self, milestones: list[int]):
  188. """update milstone of `PiecewiseDecay` learning scheduler
  189. Args:
  190. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  191. """
  192. schedulers = self.LearningRate["schedulers"]
  193. for sch in schedulers:
  194. key = "name" if "name" in sch else "_type_"
  195. if sch[key] == "PiecewiseDecay":
  196. sch["milestones"] = milestones
  197. def update_batch_size(self, batch_size: int, mode: str = "train"):
  198. """update batch size setting
  199. Args:
  200. batch_size (int): the batch size number to set.
  201. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  202. Defaults to 'train'.
  203. Raises:
  204. ValueError: mode error.
  205. """
  206. assert mode in (
  207. "train",
  208. "eval",
  209. "test",
  210. ), "mode ({}) should be train, eval or test".format(mode)
  211. if mode == "train":
  212. self.TrainReader["batch_size"] = batch_size
  213. elif mode == "eval":
  214. self.EvalReader["batch_size"] = batch_size
  215. else:
  216. self.TestReader["batch_size"] = batch_size
  217. def update_epochs(self, epochs: int):
  218. """update epochs setting
  219. Args:
  220. epochs (int): the epochs number value to set
  221. """
  222. self.update({"epoch": epochs})
  223. def update_device(self, device_type: str):
  224. """update device setting
  225. Args:
  226. device (str): the running device to set
  227. """
  228. if device_type.lower() == "gpu":
  229. self["use_gpu"] = True
  230. elif device_type.lower() == "xpu":
  231. self["use_xpu"] = True
  232. self["use_gpu"] = False
  233. elif device_type.lower() == "npu":
  234. self["use_npu"] = True
  235. self["use_gpu"] = False
  236. elif device_type.lower() == "mlu":
  237. self["use_mlu"] = True
  238. self["use_gpu"] = False
  239. elif device_type.lower() == "gcu":
  240. self["use_gcu"] = True
  241. self["use_gpu"] = False
  242. else:
  243. assert device_type.lower() == "cpu"
  244. self["use_gpu"] = False
  245. def update_save_dir(self, save_dir: str):
  246. """update directory to save outputs
  247. Args:
  248. save_dir (str): the directory to save outputs.
  249. """
  250. self["save_dir"] = abspath(save_dir)
  251. def update_log_interval(self, log_interval: int):
  252. """update log interval(steps)
  253. Args:
  254. log_interval (int): the log interval value to set.
  255. """
  256. self.update({"log_iter": log_interval})
  257. def update_eval_interval(self, eval_interval: int):
  258. """update eval interval(epochs)
  259. Args:
  260. eval_interval (int): the eval interval value to set.
  261. """
  262. self.update({"snapshot_epoch": eval_interval})
  263. def update_save_interval(self, save_interval: int):
  264. """update eval interval(epochs)
  265. Args:
  266. save_interval (int): the save interval value to set.
  267. """
  268. self.update({"snapshot_epoch": save_interval})
  269. def update_log_ranks(self, device):
  270. """update log ranks
  271. Args:
  272. device (str): the running device to set
  273. """
  274. log_ranks = device.split(":")[1]
  275. self.update({"log_ranks": log_ranks})
  276. def update_print_mem_info(self, print_mem_info: bool):
  277. """setting print memory info"""
  278. assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
  279. self.update({"print_mem_info": f"{print_mem_info}"})
  280. def update_shared_memory(self, shared_memeory: bool):
  281. """update shared memory setting of train and eval dataloader
  282. Args:
  283. shared_memeory (bool): whether or not to use shared memory
  284. """
  285. assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
  286. self.update({"print_mem_info": f"{shared_memeory}"})
  287. def update_shuffle(self, shuffle: bool):
  288. """update shuffle setting of train and eval dataloader
  289. Args:
  290. shuffle (bool): whether or not to shuffle the data
  291. """
  292. assert isinstance(shuffle, bool), "shuffle should be a bool"
  293. self.update({"TrainReader": {"shuffle": shuffle}})
  294. self.update({"EvalReader": {"shuffle": shuffle}})
  295. def update_weights(self, weight_path: str):
  296. """update model weight
  297. Args:
  298. weight_path (str): the path to weight file of model.
  299. """
  300. self["weights"] = weight_path
  301. def update_pretrained_weights(self, pretrain_weights: str):
  302. """update pretrained weight path
  303. Args:
  304. pretrained_model (str): the local path or url of pretrained weight file to set.
  305. """
  306. if not pretrain_weights.startswith(
  307. "http://"
  308. ) and not pretrain_weights.startswith("https://"):
  309. pretrain_weights = abspath(pretrain_weights)
  310. self["pretrain_weights"] = pretrain_weights
  311. def update_num_class(self, num_classes: int):
  312. """update classes number
  313. Args:
  314. num_classes (int): the classes number value to set.
  315. """
  316. self["num_classes"] = num_classes
  317. if "CenterNet" in self.model_name:
  318. for i in range(len(self["TrainReader"]["sample_transforms"])):
  319. if (
  320. "Gt2CenterNetTarget"
  321. in self["TrainReader"]["sample_transforms"][i].keys()
  322. ):
  323. self["TrainReader"]["sample_transforms"][i]["Gt2CenterNetTarget"][
  324. "num_classes"
  325. ] = num_classes
  326. def update_random_size(self, randomsize: list[list[int, int]]):
  327. """update `target_size` of `BatchRandomResize` op in TestReader
  328. Args:
  329. randomsize (list[list[int, int]]): the list of different size scales.
  330. """
  331. self.TestReader["batch_transforms"]["BatchRandomResize"][
  332. "target_size"
  333. ] = randomsize
  334. def update_num_workers(self, num_workers: int):
  335. """update workers number of train and eval dataloader
  336. Args:
  337. num_workers (int): the value of train and eval dataloader workers number to set.
  338. """
  339. self["worker_num"] = num_workers
  340. def _recursively_set(self, config: dict, update_dict: dict):
  341. """recursively set config
  342. Args:
  343. config (dict): the original config.
  344. update_dict (dict): to be updated paramenters and its values
  345. Example:
  346. self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
  347. """
  348. assert isinstance(update_dict, dict)
  349. for key in update_dict:
  350. if key not in config:
  351. logging.info(f"A new filed of config to set found: {repr(key)}.")
  352. config[key] = update_dict[key]
  353. elif not isinstance(update_dict[key], dict):
  354. config[key] = update_dict[key]
  355. else:
  356. self._recursively_set(config[key], update_dict[key])
  357. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  358. """update static assigner epochs value
  359. Args:
  360. static_assigner_epochs (int): the value of static assigner epochs
  361. """
  362. assert "PicoHeadV2" in self
  363. self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
  364. def update_HybridEncoder(self, update_dict: dict):
  365. """update the HybridEncoder neck setting
  366. Args:
  367. update_dict (dict): the HybridEncoder setting.
  368. """
  369. assert "HybridEncoder" in self
  370. self._recursively_set(self.HybridEncoder, update_dict)
  371. def get_epochs_iters(self) -> int:
  372. """get epochs
  373. Returns:
  374. int: the epochs value, i.e., `Global.epochs` in config.
  375. """
  376. return self.epoch
  377. def get_log_interval(self) -> int:
  378. """get log interval(steps)
  379. Returns:
  380. int: the log interval value, i.e., `Global.print_batch_step` in config.
  381. """
  382. self.log_iter
  383. def get_eval_interval(self) -> int:
  384. """get eval interval(epochs)
  385. Returns:
  386. int: the eval interval value, i.e., `Global.eval_interval` in config.
  387. """
  388. self.snapshot_epoch
  389. def get_save_interval(self) -> int:
  390. """get save interval(epochs)
  391. Returns:
  392. int: the save interval value, i.e., `Global.save_interval` in config.
  393. """
  394. self.snapshot_epoch
  395. def get_learning_rate(self) -> float:
  396. """get learning rate
  397. Returns:
  398. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  399. """
  400. return self.LearningRate["base_lr"]
  401. def get_batch_size(self, mode="train") -> int:
  402. """get batch size
  403. Args:
  404. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  405. Defaults to 'train'.
  406. Returns:
  407. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  408. """
  409. if mode == "train":
  410. return self.TrainReader["batch_size"]
  411. elif mode == "eval":
  412. return self.EvalReader["batch_size"]
  413. elif mode == "test":
  414. return self.TestReader["batch_size"]
  415. else:
  416. raise (f"Unknown mode: {repr(mode)}")
  417. def get_qat_epochs_iters(self) -> int:
  418. """get qat epochs
  419. Returns:
  420. int: the epochs value.
  421. """
  422. return self.epoch // 2.0
  423. def get_qat_learning_rate(self) -> float:
  424. """get qat learning rate
  425. Returns:
  426. float: the learning rate value.
  427. """
  428. return self.LearningRate["base_lr"] // 2.0
  429. def get_train_save_dir(self) -> str:
  430. """get the directory to save output
  431. Returns:
  432. str: the directory to save output
  433. """
  434. return self.save_dir