config.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List
  15. from ...base import BaseConfig
  16. from ....utils.misc import abspath
  17. from ..config_helper import PPDetConfigMixin
  18. from ..object_det.config import DetConfig
  19. class InstanceSegConfig(DetConfig):
  20. """InstanceSegConfig"""
  21. def load(self, config_path: str):
  22. """load the config from config file
  23. Args:
  24. config_path (str): the config file path.
  25. """
  26. dict_ = self.load_config_literally(config_path)
  27. self.reset_from_dict(dict_)
  28. def dump(self, config_path: str):
  29. """dump the config
  30. Args:
  31. config_path (str): the path to save dumped config.
  32. """
  33. self.dump_literal_config(config_path, self._dict)
  34. def update(self, dict_like_obj: list):
  35. """update self from dict
  36. Args:
  37. dict_like_obj (list): the list of pairs that contain key and value.
  38. """
  39. self.update_from_dict(dict_like_obj, self._dict)
  40. def update_dataset(
  41. self,
  42. dataset_path: str,
  43. dataset_type: str = None,
  44. *,
  45. data_fields: List[str] = None,
  46. image_dir: str = "images",
  47. train_anno_path: str = "annotations/instance_train.json",
  48. val_anno_path: str = "annotations/instance_val.json",
  49. test_anno_path: str = "annotations/instance_val.json",
  50. ):
  51. """update dataset settings
  52. Args:
  53. dataset_path (str): the root path fo dataset.
  54. dataset_type (str, optional): the dataset type. Defaults to None.
  55. data_fields (list[str], optional): the data fields in dataset. Defaults to None.
  56. image_dir (str, optional): the images file directory that relative to `dataset_path`. Defaults to "images".
  57. train_anno_path (str, optional): the train annotations file that relative to `dataset_path`.
  58. Defaults to "annotations/instance_train.json".
  59. val_anno_path (str, optional): the validation annotations file that relative to `dataset_path`.
  60. Defaults to "annotations/instance_val.json".
  61. test_anno_path (str, optional): the test annotations file that relative to `dataset_path`.
  62. Defaults to "annotations/instance_val.json".
  63. Raises:
  64. ValueError: the `dataset_type` error.
  65. """
  66. dataset_path = abspath(dataset_path)
  67. if dataset_type is None:
  68. dataset_type = "COCOInstSegDataset"
  69. if dataset_type == "COCOInstSegDataset":
  70. ds_cfg = self._make_dataset_config(
  71. dataset_path,
  72. data_fields,
  73. image_dir,
  74. train_anno_path,
  75. val_anno_path,
  76. test_anno_path,
  77. )
  78. self.set_val("metric", "COCO")
  79. else:
  80. raise ValueError(f"{repr(dataset_type)} is not supported.")
  81. self.update(ds_cfg)
  82. def _make_dataset_config(
  83. self,
  84. dataset_root_path: str,
  85. data_fields: List[str,] = None,
  86. image_dir: str = "images",
  87. train_anno_path: str = "annotations/instance_train.json",
  88. val_anno_path: str = "annotations/instance_val.json",
  89. test_anno_path: str = "annotations/instance_val.json",
  90. ) -> dict:
  91. """construct the dataset config that meets the format requirements
  92. Args:
  93. dataset_root_path (str): the root directory of dataset.
  94. data_fields (list[str,], optional): the data field. Defaults to None.
  95. image_dir (str, optional): _description_. Defaults to "images".
  96. train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
  97. val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  98. test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
  99. Returns:
  100. dict: the dataset config.
  101. """
  102. data_fields = (
  103. ["image", "gt_bbox", "gt_poly", "gt_class", "is_crowd"]
  104. if data_fields is None
  105. else data_fields
  106. )
  107. return {
  108. "TrainDataset": {
  109. "name": "COCOInstSegDataset",
  110. "image_dir": image_dir,
  111. "anno_path": train_anno_path,
  112. "dataset_dir": dataset_root_path,
  113. "data_fields": data_fields,
  114. },
  115. "EvalDataset": {
  116. "name": "COCOInstSegDataset",
  117. "image_dir": image_dir,
  118. "anno_path": val_anno_path,
  119. "dataset_dir": dataset_root_path,
  120. },
  121. "TestDataset": {
  122. "name": "ImageFolder",
  123. "anno_path": test_anno_path,
  124. "dataset_dir": dataset_root_path,
  125. },
  126. }
  127. def update_ema(
  128. self,
  129. use_ema: bool,
  130. ema_decay: float = 0.9999,
  131. ema_decay_type: str = "exponential",
  132. ema_filter_no_grad: bool = True,
  133. ):
  134. """update EMA setting
  135. Args:
  136. use_ema (bool): whether or not to use EMA
  137. ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
  138. ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
  139. ema_filter_no_grad (bool, optional): whether or not to filter the parameters that been set to stop gradient
  140. and are not batch norm parameters. Defaults to True.
  141. """
  142. self.update(
  143. {
  144. "use_ema": use_ema,
  145. "ema_decay": ema_decay,
  146. "ema_decay_type": ema_decay_type,
  147. "ema_filter_no_grad": ema_filter_no_grad,
  148. }
  149. )
  150. def update_learning_rate(self, learning_rate: float):
  151. """update learning rate
  152. Args:
  153. learning_rate (float): the learning rate value to set.
  154. """
  155. self.LearningRate["base_lr"] = learning_rate
  156. def update_warmup_steps(self, warmup_steps: int):
  157. """update warmup steps
  158. Args:
  159. warmup_steps (int): the warmup steps value to set.
  160. """
  161. schedulers = self.LearningRate["schedulers"]
  162. for sch in schedulers:
  163. key = "name" if "name" in sch else "_type_"
  164. if sch[key] == "LinearWarmup":
  165. sch["steps"] = warmup_steps
  166. sch["epochs_first"] = False
  167. def update_warmup_enable(self, use_warmup: bool):
  168. """whether or not to enable learning rate warmup
  169. Args:
  170. use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
  171. """
  172. schedulers = self.LearningRate["schedulers"]
  173. for sch in schedulers:
  174. if "use_warmup" in sch:
  175. sch["use_warmup"] = use_warmup
  176. def update_cossch_epoch(self, max_epochs: int):
  177. """update max epochs of cosine learning rate scheduler
  178. Args:
  179. max_epochs (int): the max epochs value.
  180. """
  181. schedulers = self.LearningRate["schedulers"]
  182. for sch in schedulers:
  183. key = "name" if "name" in sch else "_type_"
  184. if sch[key] == "CosineDecay":
  185. sch["max_epochs"] = max_epochs
  186. def update_milestone(self, milestones: List[int]):
  187. """update milstone of `PiecewiseDecay` learning scheduler
  188. Args:
  189. milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
  190. """
  191. schedulers = self.LearningRate["schedulers"]
  192. for sch in schedulers:
  193. key = "name" if "name" in sch else "_type_"
  194. if sch[key] == "PiecewiseDecay":
  195. sch["milestones"] = milestones
  196. def update_batch_size(self, batch_size: int, mode: str = "train"):
  197. """update batch size setting
  198. Args:
  199. batch_size (int): the batch size number to set.
  200. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  201. Defaults to 'train'.
  202. Raises:
  203. ValueError: mode error.
  204. """
  205. assert mode in (
  206. "train",
  207. "eval",
  208. "test",
  209. ), "mode ({}) should be train, eval or test".format(mode)
  210. if mode == "train":
  211. self.TrainReader["batch_size"] = batch_size
  212. elif mode == "eval":
  213. self.EvalReader["batch_size"] = batch_size
  214. else:
  215. self.TestReader["batch_size"] = batch_size
  216. def update_epochs(self, epochs: int):
  217. """update epochs setting
  218. Args:
  219. epochs (int): the epochs number value to set
  220. """
  221. self.update({"epoch": epochs})
  222. def update_device(self, device_type: str):
  223. """update device setting
  224. Args:
  225. device (str): the running device to set
  226. """
  227. if device_type.lower() == "gpu":
  228. self["use_gpu"] = True
  229. elif device_type.lower() == "xpu":
  230. self["use_xpu"] = True
  231. self["use_gpu"] = False
  232. elif device_type.lower() == "npu":
  233. self["use_npu"] = True
  234. self["use_gpu"] = False
  235. elif device_type.lower() == "mlu":
  236. self["use_mlu"] = True
  237. self["use_gpu"] = False
  238. elif device_type.lower() == "gcu":
  239. self["use_gcu"] = True
  240. self["use_gpu"] = False
  241. else:
  242. assert device_type.lower() == "cpu"
  243. self["use_gpu"] = False
  244. def update_save_dir(self, save_dir: str):
  245. """update directory to save outputs
  246. Args:
  247. save_dir (str): the directory to save outputs.
  248. """
  249. self["save_dir"] = abspath(save_dir)
  250. def update_log_interval(self, log_interval: int):
  251. """update log interval(steps)
  252. Args:
  253. log_interval (int): the log interval value to set.
  254. """
  255. self.update({"log_iter": log_interval})
  256. def update_eval_interval(self, eval_interval: int):
  257. """update eval interval(epochs)
  258. Args:
  259. eval_interval (int): the eval interval value to set.
  260. """
  261. self.update({"snapshot_epoch": eval_interval})
  262. def update_save_interval(self, save_interval: int):
  263. """update eval interval(epochs)
  264. Args:
  265. save_interval (int): the save interval value to set.
  266. """
  267. self.update({"snapshot_epoch": save_interval})
  268. def update_log_ranks(self, device):
  269. """update log ranks
  270. Args:
  271. device (str): the running device to set
  272. """
  273. log_ranks = device.split(":")[1]
  274. self.update({"log_ranks": log_ranks})
  275. def update_weights(self, weight_path: str):
  276. """update model weight
  277. Args:
  278. weight_path (str): the path to weight file of model.
  279. """
  280. self["weights"] = weight_path
  281. def update_pretrained_weights(self, pretrain_weights: str):
  282. """update pretrained weight path
  283. Args:
  284. pretrained_model (str): the local path or url of pretrained weight file to set.
  285. """
  286. if not pretrain_weights.startswith(
  287. "http://"
  288. ) and not pretrain_weights.startswith("https://"):
  289. pretrain_weights = abspath(pretrain_weights)
  290. self["pretrain_weights"] = pretrain_weights
  291. def update_num_class(self, num_classes: int):
  292. """update classes number
  293. Args:
  294. num_classes (int): the classes number value to set.
  295. """
  296. self["num_classes"] = num_classes
  297. def update_random_size(self, randomsize):
  298. """update `target_size` of `BatchRandomResize` op in TestReader
  299. Args:
  300. randomsize (list[list[int, int]]): the list of different size scales.
  301. """
  302. self.TestReader["batch_transforms"]["BatchRandomResize"][
  303. "target_size"
  304. ] = randomsize
  305. def update_num_workers(self, num_workers: int):
  306. """update workers number of train and eval dataloader
  307. Args:
  308. num_workers (int): the value of train and eval dataloader workers number to set.
  309. """
  310. self["worker_num"] = num_workers
  311. def update_static_assigner_epochs(self, static_assigner_epochs: int):
  312. """update static assigner epochs value
  313. Args:
  314. static_assigner_epochs (int): the value of static assigner epochs
  315. """
  316. assert "PicoHeadV2" in self
  317. self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
  318. def get_epochs_iters(self) -> int:
  319. """get epochs
  320. Returns:
  321. int: the epochs value, i.e., `Global.epochs` in config.
  322. """
  323. return self.epoch
  324. def get_log_interval(self) -> int:
  325. """get log interval(steps)
  326. Returns:
  327. int: the log interval value, i.e., `Global.print_batch_step` in config.
  328. """
  329. self.log_iter
  330. def get_eval_interval(self) -> int:
  331. """get eval interval(epochs)
  332. Returns:
  333. int: the eval interval value, i.e., `Global.eval_interval` in config.
  334. """
  335. self.snapshot_epoch
  336. def get_save_interval(self) -> int:
  337. """get save interval(epochs)
  338. Returns:
  339. int: the save interval value, i.e., `Global.save_interval` in config.
  340. """
  341. self.snapshot_epoch
  342. def get_learning_rate(self) -> float:
  343. """get learning rate
  344. Returns:
  345. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  346. """
  347. return self.LearningRate["base_lr"]
  348. def get_batch_size(self, mode="train") -> int:
  349. """get batch size
  350. Args:
  351. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  352. Defaults to 'train'.
  353. Returns:
  354. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  355. """
  356. if mode == "train":
  357. return self.TrainReader["batch_size"]
  358. elif mode == "eval":
  359. return self.EvalReader["batch_size"]
  360. elif mode == "test":
  361. return self.TestReader["batch_size"]
  362. else:
  363. raise (f"Unknown mode: {repr(mode)}")
  364. def get_qat_epochs_iters(self) -> int:
  365. """get qat epochs
  366. Returns:
  367. int: the epochs value.
  368. """
  369. return self.epoch // 2.0
  370. def get_qat_learning_rate(self) -> float:
  371. """get qat learning rate
  372. Returns:
  373. float: the learning rate value.
  374. """
  375. return self.LearningRate["base_lr"] // 2.0
  376. def get_train_save_dir(self) -> str:
  377. """get the directory to save output
  378. Returns:
  379. str: the directory to save output
  380. """
  381. return self.save_dir