# config.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

from ....utils.misc import abspath
from ..object_det.config import DetConfig
  17. class InstanceSegConfig(DetConfig):
  18. """InstanceSegConfig"""
  19. def load(self, config_path: str):
  20. """load the config from config file
  21. Args:
  22. config_path (str): the config file path.
  23. """
  24. dict_ = self.load_config_literally(config_path)
  25. self.reset_from_dict(dict_)
  26. def dump(self, config_path: str):
  27. """dump the config
  28. Args:
  29. config_path (str): the path to save dumped config.
  30. """
  31. self.dump_literal_config(config_path, self._dict)
  32. def update(self, dict_like_obj: list):
  33. """update self from dict
  34. Args:
  35. dict_like_obj (list): the list of pairs that contain key and value.
  36. """
  37. self.update_from_dict(dict_like_obj, self._dict)
  38. def update_dataset(
  39. self,
  40. dataset_path: str,
  41. dataset_type: str = None,
  42. *,
  43. data_fields: List[str] = None,
  44. image_dir: str = "images",
  45. train_anno_path: str = "annotations/instance_train.json",
  46. val_anno_path: str = "annotations/instance_val.json",
  47. test_anno_path: str = "annotations/instance_val.json",
  48. ):
  49. """update dataset settings
  50. Args:
  51. dataset_path (str): the root path fo dataset.
  52. dataset_type (str, optional): the dataset type. Defaults to None.
  53. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  54. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  55. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  56. Defaults to "annotations/instance_train.json".
  57. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  58. Defaults to "annotations/instance_val.json".
  59. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  60. Defaults to "annotations/instance_val.json".
  61. Raises:
  62. ValueError: the `dataset_type` error.
  63. """
  64. dataset_path = abspath(dataset_path)
  65. if dataset_type is None:
  66. dataset_type = "COCOInstSegDataset"
  67. if dataset_type == "COCOInstSegDataset":
  68. ds_cfg = self._make_dataset_config(
  69. dataset_path,
  70. data_fields,
  71. image_dir,
  72. train_anno_path,
  73. val_anno_path,
  74. test_anno_path,
  75. )
  76. self.set_val("metric", "COCO")
  77. else:
  78. raise ValueError(f"{repr(dataset_type)} is not supported.")
  79. self.update(ds_cfg)
  80. def _make_dataset_config(
  81. self,
  82. dataset_root_path: str,
  83. data_fields: List[str,] = None,
  84. image_dir: str = "images",
  85. train_anno_path: str = "annotations/instance_train.json",
  86. val_anno_path: str = "annotations/instance_val.json",
  87. test_anno_path: str = "annotations/instance_val.json",
  88. ) -> dict:
  89. """construct the dataset config that meets the format requirements
  90. Args:
  91. dataset_root_path (str): the root directory of dataset.
  92. data_fields (list[str,], optional): the data field. Defaults to None.
  93. image_dir (str, optional): _description_. Defaults to "images".
  94. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  95. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  96. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  97. Returns:
  98. dict: the dataset config.
  99. """
  100. data_fields = (
  101. ["image", "gt_bbox", "gt_poly", "gt_class", "is_crowd"]
  102. if data_fields is None
  103. else data_fields
  104. )
  105. return {
  106. "TrainDataset": {
  107. "name": "COCOInstSegDataset",
  108. "image_dir": image_dir,
  109. "anno_path": train_anno_path,
  110. "dataset_dir": dataset_root_path,
  111. "data_fields": data_fields,
  112. },
  113. "EvalDataset": {
  114. "name": "COCOInstSegDataset",
  115. "image_dir": image_dir,
  116. "anno_path": val_anno_path,
  117. "dataset_dir": dataset_root_path,
  118. },
  119. "TestDataset": {
  120. "name": "ImageFolder",
  121. "anno_path": test_anno_path,
  122. "dataset_dir": dataset_root_path,
  123. },
  124. }
  125. def update_ema(
  126. self,
  127. use_ema: bool,
  128. ema_decay: float = 0.9999,
  129. ema_decay_type: str = "exponential",
  130. ema_filter_no_grad: bool = True,
  131. ):
  132. """update EMA setting
  133. Args:
  134. use_ema (bool): whether or not to use EMA
  135. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  136. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  137. ema_filter_no_grad (bool, optional): whether or not to filter the parameters that been set to stop gradient
  138. and are not batch norm parameters. Defaults to True.
  139. """
  140. self.update(
  141. {
  142. "use_ema": use_ema,
  143. "ema_decay": ema_decay,
  144. "ema_decay_type": ema_decay_type,
  145. "ema_filter_no_grad": ema_filter_no_grad,
  146. }
  147. )
  148. def update_learning_rate(self, learning_rate: float):
  149. """update learning rate
  150. Args:
  151. learning_rate (float): the learning rate value to set.
  152. """
  153. self.LearningRate["base_lr"] = learning_rate
  154. def update_warmup_steps(self, warmup_steps: int):
  155. """update warmup steps
  156. Args:
  157. warmup_steps (int): the warmup steps value to set.
  158. """
  159. schedulers = self.LearningRate["schedulers"]
  160. for sch in schedulers:
  161. key = "name" if "name" in sch else "_type_"
  162. if sch[key] == "LinearWarmup":
  163. sch["steps"] = warmup_steps
  164. sch["epochs_first"] = False
  165. def update_warmup_enable(self, use_warmup: bool):
  166. """whether or not to enable learning rate warmup
  167. Args:
  168. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  169. """
  170. schedulers = self.LearningRate["schedulers"]
  171. for sch in schedulers:
  172. if "use_warmup" in sch:
  173. sch["use_warmup"] = use_warmup
  174. def update_cossch_epoch(self, max_epochs: int):
  175. """update max epochs of cosine learning rate scheduler
  176. Args:
  177. max_epochs (int): the max epochs value.
  178. """
  179. schedulers = self.LearningRate["schedulers"]
  180. for sch in schedulers:
  181. key = "name" if "name" in sch else "_type_"
  182. if sch[key] == "CosineDecay":
  183. sch["max_epochs"] = max_epochs
  184. def update_milestone(self, milestones: List[int]):
  185. """update milstone of `PiecewiseDecay` learning scheduler
  186. Args:
  187. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  188. """
  189. schedulers = self.LearningRate["schedulers"]
  190. for sch in schedulers:
  191. key = "name" if "name" in sch else "_type_"
  192. if sch[key] == "PiecewiseDecay":
  193. sch["milestones"] = milestones
  194. def update_batch_size(self, batch_size: int, mode: str = "train"):
  195. """update batch size setting
  196. Args:
  197. batch_size (int): the batch size number to set.
  198. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  199. Defaults to 'train'.
  200. Raises:
  201. ValueError: mode error.
  202. """
  203. assert mode in (
  204. "train",
  205. "eval",
  206. "test",
  207. ), "mode ({}) should be train, eval or test".format(mode)
  208. if mode == "train":
  209. self.TrainReader["batch_size"] = batch_size
  210. elif mode == "eval":
  211. self.EvalReader["batch_size"] = batch_size
  212. else:
  213. self.TestReader["batch_size"] = batch_size
  214. def update_epochs(self, epochs: int):
  215. """update epochs setting
  216. Args:
  217. epochs (int): the epochs number value to set
  218. """
  219. self.update({"epoch": epochs})
  220. def update_device(self, device_type: str):
  221. """update device setting
  222. Args:
  223. device (str): the running device to set
  224. """
  225. if device_type.lower() == "gpu":
  226. self["use_gpu"] = True
  227. elif device_type.lower() == "xpu":
  228. self["use_xpu"] = True
  229. self["use_gpu"] = False
  230. elif device_type.lower() == "npu":
  231. self["use_npu"] = True
  232. self["use_gpu"] = False
  233. elif device_type.lower() == "mlu":
  234. self["use_mlu"] = True
  235. self["use_gpu"] = False
  236. elif device_type.lower() == "gcu":
  237. self["use_gcu"] = True
  238. self["use_gpu"] = False
  239. else:
  240. assert device_type.lower() == "cpu"
  241. self["use_gpu"] = False
  242. def update_save_dir(self, save_dir: str):
  243. """update directory to save outputs
  244. Args:
  245. save_dir (str): the directory to save outputs.
  246. """
  247. self["save_dir"] = abspath(save_dir)
  248. def update_log_interval(self, log_interval: int):
  249. """update log interval(steps)
  250. Args:
  251. log_interval (int): the log interval value to set.
  252. """
  253. self.update({"log_iter": log_interval})
  254. def update_eval_interval(self, eval_interval: int):
  255. """update eval interval(epochs)
  256. Args:
  257. eval_interval (int): the eval interval value to set.
  258. """
  259. self.update({"snapshot_epoch": eval_interval})
  260. def update_save_interval(self, save_interval: int):
  261. """update eval interval(epochs)
  262. Args:
  263. save_interval (int): the save interval value to set.
  264. """
  265. self.update({"snapshot_epoch": save_interval})
  266. def update_log_ranks(self, device):
  267. """update log ranks
  268. Args:
  269. device (str): the running device to set
  270. """
  271. log_ranks = device.split(":")[1]
  272. self.update({"log_ranks": log_ranks})
  273. def update_weights(self, weight_path: str):
  274. """update model weight
  275. Args:
  276. weight_path (str): the path to weight file of model.
  277. """
  278. self["weights"] = weight_path
  279. def update_pretrained_weights(self, pretrain_weights: str):
  280. """update pretrained weight path
  281. Args:
  282. pretrained_model (str): the local path or url of pretrained weight file to set.
  283. """
  284. if not pretrain_weights.startswith(
  285. "http://"
  286. ) and not pretrain_weights.startswith("https://"):
  287. pretrain_weights = abspath(pretrain_weights)
  288. self["pretrain_weights"] = pretrain_weights
  289. def update_num_class(self, num_classes: int):
  290. """update classes number
  291. Args:
  292. num_classes (int): the classes number value to set.
  293. """
  294. self["num_classes"] = num_classes
  295. def update_random_size(self, randomsize):
  296. """update `target_size` of `BatchRandomResize` op in TestReader
  297. Args:
  298. randomsize (list[list[int, int]]): the list of different size scales.
  299. """
  300. self.TestReader["batch_transforms"]["BatchRandomResize"][
  301. "target_size"
  302. ] = randomsize
  303. def update_num_workers(self, num_workers: int):
  304. """update workers number of train and eval dataloader
  305. Args:
  306. num_workers (int): the value of train and eval dataloader workers number to set.
  307. """
  308. self["worker_num"] = num_workers
  309. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  310. """update static assigner epochs value
  311. Args:
  312. static_assigner_epochs (int): the value of static assigner epochs
  313. """
  314. assert "PicoHeadV2" in self
  315. self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
  316. def get_epochs_iters(self) -> int:
  317. """get epochs
  318. Returns:
  319. int: the epochs value, i.e., `Global.epochs` in config.
  320. """
  321. return self.epoch
  322. def get_log_interval(self) -> int:
  323. """get log interval(steps)
  324. Returns:
  325. int: the log interval value, i.e., `Global.print_batch_step` in config.
  326. """
  327. self.log_iter
  328. def get_eval_interval(self) -> int:
  329. """get eval interval(epochs)
  330. Returns:
  331. int: the eval interval value, i.e., `Global.eval_interval` in config.
  332. """
  333. self.snapshot_epoch
  334. def get_save_interval(self) -> int:
  335. """get save interval(epochs)
  336. Returns:
  337. int: the save interval value, i.e., `Global.save_interval` in config.
  338. """
  339. self.snapshot_epoch
  340. def get_learning_rate(self) -> float:
  341. """get learning rate
  342. Returns:
  343. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  344. """
  345. return self.LearningRate["base_lr"]
  346. def get_batch_size(self, mode="train") -> int:
  347. """get batch size
  348. Args:
  349. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  350. Defaults to 'train'.
  351. Returns:
  352. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  353. """
  354. if mode == "train":
  355. return self.TrainReader["batch_size"]
  356. elif mode == "eval":
  357. return self.EvalReader["batch_size"]
  358. elif mode == "test":
  359. return self.TestReader["batch_size"]
  360. else:
  361. raise (f"Unknown mode: {repr(mode)}")
  362. def get_qat_epochs_iters(self) -> int:
  363. """get qat epochs
  364. Returns:
  365. int: the epochs value.
  366. """
  367. return self.epoch // 2.0
  368. def get_qat_learning_rate(self) -> float:
  369. """get qat learning rate
  370. Returns:
  371. float: the learning rate value.
  372. """
  373. return self.LearningRate["base_lr"] // 2.0
  374. def get_train_save_dir(self) -> str:
  375. """get the directory to save output
  376. Returns:
  377. str: the directory to save output
  378. """
  379. return self.save_dir