config.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from ...base import BaseConfig
  15. from ....utils.misc import abspath
  16. from ....utils import logging
  17. from ..config_helper import PPDetConfigMixin
  18. class DetConfig(BaseConfig, PPDetConfigMixin):
  19. """DetConfig"""
  20. def load(self, config_path: str):
  21. """load the config from config file
  22. Args:
  23. config_path (str): the config file path.
  24. """
  25. dict_ = self.load_config_literally(config_path)
  26. self.reset_from_dict(dict_)
  27. def dump(self, config_path: str):
  28. """dump the config
  29. Args:
  30. config_path (str): the path to save dumped config.
  31. """
  32. self.dump_literal_config(config_path, self._dict)
  33. def update(self, dict_like_obj: list):
  34. """update self from dict
  35. Args:
  36. dict_like_obj (list): the list of pairs that contain key and value.
  37. """
  38. self.update_from_dict(dict_like_obj, self._dict)
  39. def update_dataset(
  40. self,
  41. dataset_path: str,
  42. dataset_type: str = None,
  43. *,
  44. data_fields: list[str] = None,
  45. image_dir: str = "images",
  46. train_anno_path: str = "annotations/instance_train.json",
  47. val_anno_path: str = "annotations/instance_val.json",
  48. test_anno_path: str = "annotations/instance_val.json",
  49. ):
  50. """update dataset settings
  51. Args:
  52. dataset_path (str): the root path fo dataset.
  53. dataset_type (str, optional): the dataset type. Defaults to None.
  54. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  55. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  56. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  57. Defaults to "annotations/instance_train.json".
  58. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  59. Defaults to "annotations/instance_val.json".
  60. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  61. Defaults to "annotations/instance_val.json".
  62. Raises:
  63. ValueError: the `dataset_type` error.
  64. """
  65. dataset_path = abspath(dataset_path)
  66. if dataset_type is None:
  67. dataset_type = "COCODetDataset"
  68. if dataset_type == "COCODetDataset":
  69. ds_cfg = self._make_dataset_config(
  70. dataset_path,
  71. data_fields,
  72. image_dir,
  73. train_anno_path,
  74. val_anno_path,
  75. test_anno_path,
  76. )
  77. self.set_val("metric", "COCO")
  78. else:
  79. raise ValueError(f"{repr(dataset_type)} is not supported.")
  80. self.update(ds_cfg)
  81. def _make_dataset_config(
  82. self,
  83. dataset_root_path: str,
  84. data_fields: list[str,] = None,
  85. image_dir: str = "images",
  86. train_anno_path: str = "annotations/instance_train.json",
  87. val_anno_path: str = "annotations/instance_val.json",
  88. test_anno_path: str = "annotations/instance_val.json",
  89. ) -> dict:
  90. """construct the dataset config that meets the format requirements
  91. Args:
  92. dataset_root_path (str): the root directory of dataset.
  93. data_fields (list[str,], optional): the data field. Defaults to None.
  94. image_dir (str, optional): _description_. Defaults to "images".
  95. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  96. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  97. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  98. Returns:
  99. dict: the dataset config.
  100. """
  101. data_fields = (
  102. ["image", "gt_bbox", "gt_class", "is_crowd"]
  103. if data_fields is None
  104. else data_fields
  105. )
  106. return {
  107. "TrainDataset": {
  108. "name": "COCODetDataset",
  109. "image_dir": image_dir,
  110. "anno_path": train_anno_path,
  111. "dataset_dir": dataset_root_path,
  112. "data_fields": data_fields,
  113. },
  114. "EvalDataset": {
  115. "name": "COCODetDataset",
  116. "image_dir": image_dir,
  117. "anno_path": val_anno_path,
  118. "dataset_dir": dataset_root_path,
  119. },
  120. "TestDataset": {
  121. "name": "ImageFolder",
  122. "anno_path": test_anno_path,
  123. "dataset_dir": dataset_root_path,
  124. },
  125. }
  126. def update_ema(
  127. self,
  128. use_ema: bool,
  129. ema_decay: float = 0.9999,
  130. ema_decay_type: str = "exponential",
  131. ema_filter_no_grad: bool = True,
  132. ):
  133. """update EMA setting
  134. Args:
  135. use_ema (bool): whether or not to use EMA
  136. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  137. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  138. ema_filter_no_grad (bool, optional): whether or not to filter the parameters
  139. that been set to stop gradient and are not batch norm parameters. Defaults to True.
  140. """
  141. self.update(
  142. {
  143. "use_ema": use_ema,
  144. "ema_decay": ema_decay,
  145. "ema_decay_type": ema_decay_type,
  146. "ema_filter_no_grad": ema_filter_no_grad,
  147. }
  148. )
  149. def update_learning_rate(self, learning_rate: float):
  150. """update learning rate
  151. Args:
  152. learning_rate (float): the learning rate value to set.
  153. """
  154. self.LearningRate["base_lr"] = learning_rate
  155. def update_warmup_steps(self, warmup_steps: int):
  156. """update warmup steps
  157. Args:
  158. warmup_steps (int): the warmup steps value to set.
  159. """
  160. schedulers = self.LearningRate["schedulers"]
  161. for sch in schedulers:
  162. key = "name" if "name" in sch else "_type_"
  163. if sch[key] == "LinearWarmup":
  164. sch["steps"] = warmup_steps
  165. sch["epochs_first"] = False
  166. def update_warmup_enable(self, use_warmup: bool):
  167. """whether or not to enable learning rate warmup
  168. Args:
  169. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  170. """
  171. schedulers = self.LearningRate["schedulers"]
  172. for sch in schedulers:
  173. if "use_warmup" in sch:
  174. sch["use_warmup"] = use_warmup
  175. def update_cossch_epoch(self, max_epochs: int):
  176. """update max epochs of cosine learning rate scheduler
  177. Args:
  178. max_epochs (int): the max epochs value.
  179. """
  180. schedulers = self.LearningRate["schedulers"]
  181. for sch in schedulers:
  182. key = "name" if "name" in sch else "_type_"
  183. if sch[key] == "CosineDecay":
  184. sch["max_epochs"] = max_epochs
  185. def update_milestone(self, milestones: list[int]):
  186. """update milstone of `PiecewiseDecay` learning scheduler
  187. Args:
  188. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  189. """
  190. schedulers = self.LearningRate["schedulers"]
  191. for sch in schedulers:
  192. key = "name" if "name" in sch else "_type_"
  193. if sch[key] == "PiecewiseDecay":
  194. sch["milestones"] = milestones
  195. def update_batch_size(self, batch_size: int, mode: str = "train"):
  196. """update batch size setting
  197. Args:
  198. batch_size (int): the batch size number to set.
  199. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  200. Defaults to 'train'.
  201. Raises:
  202. ValueError: mode error.
  203. """
  204. assert mode in (
  205. "train",
  206. "eval",
  207. "test",
  208. ), "mode ({}) should be train, eval or test".format(mode)
  209. if mode == "train":
  210. self.TrainReader["batch_size"] = batch_size
  211. elif mode == "eval":
  212. self.EvalReader["batch_size"] = batch_size
  213. else:
  214. self.TestReader["batch_size"] = batch_size
  215. def update_epochs(self, epochs: int):
  216. """update epochs setting
  217. Args:
  218. epochs (int): the epochs number value to set
  219. """
  220. self.update({"epoch": epochs})
  221. def update_device(self, device_type: str):
  222. """update device setting
  223. Args:
  224. device (str): the running device to set
  225. """
  226. if device_type.lower() == "gpu":
  227. self["use_gpu"] = True
  228. elif device_type.lower() == "xpu":
  229. self["use_xpu"] = True
  230. self["use_gpu"] = False
  231. elif device_type.lower() == "npu":
  232. self["use_npu"] = True
  233. self["use_gpu"] = False
  234. elif device_type.lower() == "mlu":
  235. self["use_mlu"] = True
  236. self["use_gpu"] = False
  237. elif device_type.lower() == "gcu":
  238. self["use_gcu"] = True
  239. self["use_gpu"] = False
  240. else:
  241. assert device_type.lower() == "cpu"
  242. self["use_gpu"] = False
  243. def update_save_dir(self, save_dir: str):
  244. """update directory to save outputs
  245. Args:
  246. save_dir (str): the directory to save outputs.
  247. """
  248. self["save_dir"] = abspath(save_dir)
  249. def update_log_interval(self, log_interval: int):
  250. """update log interval(steps)
  251. Args:
  252. log_interval (int): the log interval value to set.
  253. """
  254. self.update({"log_iter": log_interval})
  255. def update_eval_interval(self, eval_interval: int):
  256. """update eval interval(epochs)
  257. Args:
  258. eval_interval (int): the eval interval value to set.
  259. """
  260. self.update({"snapshot_epoch": eval_interval})
  261. def update_save_interval(self, save_interval: int):
  262. """update eval interval(epochs)
  263. Args:
  264. save_interval (int): the save interval value to set.
  265. """
  266. self.update({"snapshot_epoch": save_interval})
  267. def update_log_ranks(self, device):
  268. """update log ranks
  269. Args:
  270. device (str): the running device to set
  271. """
  272. log_ranks = device.split(":")[1]
  273. self.update({"log_ranks": log_ranks})
  274. def update_print_mem_info(self, print_mem_info: bool):
  275. """setting print memory info"""
  276. assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
  277. self.update({"print_mem_info": f"{print_mem_info}"})
  278. def update_shared_memory(self, shared_memeory: bool):
  279. """update shared memory setting of train and eval dataloader
  280. Args:
  281. shared_memeory (bool): whether or not to use shared memory
  282. """
  283. assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
  284. self.update({"print_mem_info": f"{shared_memeory}"})
  285. def update_shuffle(self, shuffle: bool):
  286. """update shuffle setting of train and eval dataloader
  287. Args:
  288. shuffle (bool): whether or not to shuffle the data
  289. """
  290. assert isinstance(shuffle, bool), "shuffle should be a bool"
  291. self.update({"TrainReader": {"shuffle": shuffle}})
  292. self.update({"EvalReader": {"shuffle": shuffle}})
  293. def update_weights(self, weight_path: str):
  294. """update model weight
  295. Args:
  296. weight_path (str): the path to weight file of model.
  297. """
  298. self["weights"] = weight_path
  299. def update_pretrained_weights(self, pretrain_weights: str):
  300. """update pretrained weight path
  301. Args:
  302. pretrained_model (str): the local path or url of pretrained weight file to set.
  303. """
  304. if not pretrain_weights.startswith(
  305. "http://"
  306. ) and not pretrain_weights.startswith("https://"):
  307. pretrain_weights = abspath(pretrain_weights)
  308. self["pretrain_weights"] = pretrain_weights
  309. def update_num_class(self, num_classes: int):
  310. """update classes number
  311. Args:
  312. num_classes (int): the classes number value to set.
  313. """
  314. self["num_classes"] = num_classes
  315. if "CenterNet" in self.model_name:
  316. for i in range(len(self["TrainReader"]["sample_transforms"])):
  317. if (
  318. "Gt2CenterNetTarget"
  319. in self["TrainReader"]["sample_transforms"][i].keys()
  320. ):
  321. self["TrainReader"]["sample_transforms"][i]["Gt2CenterNetTarget"][
  322. "num_classes"
  323. ] = num_classes
  324. def update_random_size(self, randomsize: list[list[int, int]]):
  325. """update `target_size` of `BatchRandomResize` op in TestReader
  326. Args:
  327. randomsize (list[list[int, int]]): the list of different size scales.
  328. """
  329. self.TestReader["batch_transforms"]["BatchRandomResize"][
  330. "target_size"
  331. ] = randomsize
  332. def update_num_workers(self, num_workers: int):
  333. """update workers number of train and eval dataloader
  334. Args:
  335. num_workers (int): the value of train and eval dataloader workers number to set.
  336. """
  337. self["worker_num"] = num_workers
  338. def _recursively_set(self, config: dict, update_dict: dict):
  339. """recursively set config
  340. Args:
  341. config (dict): the original config.
  342. update_dict (dict): to be updated paramenters and its values
  343. Example:
  344. self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
  345. """
  346. assert isinstance(update_dict, dict)
  347. for key in update_dict:
  348. if key not in config:
  349. logging.info(f"A new filed of config to set found: {repr(key)}.")
  350. config[key] = update_dict[key]
  351. elif not isinstance(update_dict[key], dict):
  352. config[key] = update_dict[key]
  353. else:
  354. self._recursively_set(config[key], update_dict[key])
  355. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  356. """update static assigner epochs value
  357. Args:
  358. static_assigner_epochs (int): the value of static assigner epochs
  359. """
  360. assert "PicoHeadV2" in self
  361. self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
  362. def update_HybridEncoder(self, update_dict: dict):
  363. """update the HybridEncoder neck setting
  364. Args:
  365. update_dict (dict): the HybridEncoder setting.
  366. """
  367. assert "HybridEncoder" in self
  368. self._recursively_set(self.HybridEncoder, update_dict)
  369. def get_epochs_iters(self) -> int:
  370. """get epochs
  371. Returns:
  372. int: the epochs value, i.e., `Global.epochs` in config.
  373. """
  374. return self.epoch
  375. def get_log_interval(self) -> int:
  376. """get log interval(steps)
  377. Returns:
  378. int: the log interval value, i.e., `Global.print_batch_step` in config.
  379. """
  380. self.log_iter
  381. def get_eval_interval(self) -> int:
  382. """get eval interval(epochs)
  383. Returns:
  384. int: the eval interval value, i.e., `Global.eval_interval` in config.
  385. """
  386. self.snapshot_epoch
  387. def get_save_interval(self) -> int:
  388. """get save interval(epochs)
  389. Returns:
  390. int: the save interval value, i.e., `Global.save_interval` in config.
  391. """
  392. self.snapshot_epoch
  393. def get_learning_rate(self) -> float:
  394. """get learning rate
  395. Returns:
  396. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  397. """
  398. return self.LearningRate["base_lr"]
  399. def get_batch_size(self, mode="train") -> int:
  400. """get batch size
  401. Args:
  402. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  403. Defaults to 'train'.
  404. Returns:
  405. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  406. """
  407. if mode == "train":
  408. return self.TrainReader["batch_size"]
  409. elif mode == "eval":
  410. return self.EvalReader["batch_size"]
  411. elif mode == "test":
  412. return self.TestReader["batch_size"]
  413. else:
  414. raise (f"Unknown mode: {repr(mode)}")
  415. def get_qat_epochs_iters(self) -> int:
  416. """get qat epochs
  417. Returns:
  418. int: the epochs value.
  419. """
  420. return self.epoch // 2.0
  421. def get_qat_learning_rate(self) -> float:
  422. """get qat learning rate
  423. Returns:
  424. float: the learning rate value.
  425. """
  426. return self.LearningRate["base_lr"] // 2.0
  427. def get_train_save_dir(self) -> str:
  428. """get the directory to save output
  429. Returns:
  430. str: the directory to save output
  431. """
  432. return self.save_dir