config.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import yaml
  16. from typing import Union
  17. from ...base import BaseConfig
  18. from ....utils.misc import abspath
  19. from ..config_utils import merge_config
  20. class VideoDetConfig(BaseConfig):
  21. """Image Det Task Config"""
  22. def update(self, dict_like_obj: list):
  23. """update self
  24. Args:
  25. dict_like_obj (list): list of pairs(key0.key1.idx.key2=value)
  26. """
  27. dict_ = merge_config(self.dict, dict_like_obj)
  28. self.reset_from_dict(dict_)
  29. def load(self, config_file_path: str):
  30. """load config from yaml file
  31. Args:
  32. config_file_path (str): the path of yaml file.
  33. Raises:
  34. TypeError: the content of yaml file `config_file_path` error.
  35. """
  36. dict_ = yaml.load(open(config_file_path, "rb"), Loader=yaml.Loader)
  37. if not isinstance(dict_, dict):
  38. raise TypeError
  39. self.reset_from_dict(dict_)
  40. def dump(self, config_file_path: str):
  41. """dump self to yaml file
  42. Args:
  43. config_file_path (str): the path to save self as yaml file.
  44. """
  45. with open(config_file_path, "w", encoding="utf-8") as f:
  46. yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)
  47. def update_dataset(
  48. self,
  49. dataset_path: str,
  50. dataset_type: str = None,
  51. *,
  52. train_list_path: str = None,
  53. ):
  54. """update dataset settings
  55. Args:
  56. dataset_path (str): the root path of dataset.
  57. dataset_type (str, optional): dataset type. Defaults to None.
  58. train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
  59. Raises:
  60. ValueError: the dataset_type error.
  61. """
  62. dataset_path = abspath(dataset_path)
  63. if dataset_type is None:
  64. dataset_type = "VideoDetDataset"
  65. if train_list_path:
  66. train_list_path = f"{train_list_path}"
  67. else:
  68. train_list_path = f"{dataset_path}/train.txt"
  69. if dataset_type in ["VideoDetDataset"]:
  70. _cfg = {
  71. "DATASET.train.image_dir": dataset_path,
  72. "DATASET.train.file_path": os.path.join(dataset_path, "train.txt"),
  73. "DATASET.valid.image_dir": dataset_path,
  74. "DATASET.valid.file_path": os.path.join(dataset_path, "val.txt"),
  75. "DATASET.test.image_dir": dataset_path,
  76. "DATASET.test.file_path": os.path.join(dataset_path, "val.txt"),
  77. "METRIC.gt_folder": os.path.join(dataset_path, "val.txt"),
  78. "label_dict_path": os.path.join(dataset_path, "label_map.txt"),
  79. }
  80. else:
  81. raise ValueError(f"{repr(dataset_type)} is not supported.")
  82. self.update(_cfg)
  83. def update_batch_size(self, batch_size: int, mode: str = "train"):
  84. """update batch size setting
  85. Args:
  86. batch_size (int): the batch size number to set.
  87. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  88. Defaults to 'train'.
  89. Raises:
  90. ValueError: `mode` error.
  91. """
  92. if mode == "train":
  93. _cfg = {"DATASET.batch_size": batch_size}
  94. elif mode == "eval":
  95. _cfg = {"DATASET.test_batch_size": batch_size}
  96. elif mode == "test":
  97. _cfg = {"DATASET.test_batch_size": batch_size}
  98. else:
  99. raise ValueError("The input `mode` should be train, eval or test.")
  100. self.update(_cfg)
  101. def update_learning_rate(self, learning_rate: float):
  102. """update learning rate
  103. Args:
  104. learning_rate (float): the learning rate value to set.
  105. """
  106. if (
  107. self._dict["OPTIMIZER"]["learning_rate"].get("learning_rate", None)
  108. is not None
  109. ):
  110. _cfg = {"OPTIMIZER.learning_rate.learning_rate": learning_rate}
  111. else:
  112. raise ValueError("unsupported lr format")
  113. self.update(_cfg)
  114. def update_num_classes(self, num_classes: int):
  115. """update classes number
  116. Args:
  117. num_classes (int): the classes number value to set.
  118. """
  119. update_str_list = {"MODEL.backbone.num_class": num_classes}
  120. self.update(update_str_list)
  121. update_str_list = {"MODEL.loss.num_classes": num_classes}
  122. self.update(update_str_list)
  123. def update_label_list(self, label_path: str):
  124. """update label list
  125. Args:
  126. label_list (str): the path of label list file to set.
  127. """
  128. with open(label_path, "r") as f:
  129. lines = [line.strip().split(" ") for line in f.readlines()]
  130. sorted_lines = sorted(lines, key=lambda x: int(x[1]))
  131. label_list = [line[0] for line in sorted_lines]
  132. f.close()
  133. self.update({"label_list": label_list})
  134. def update_pretrained_weights(self, pretrained_model: str):
  135. """update pretrained weight path
  136. Args:
  137. pretrained_model (str): the local path or url of pretrained weight file to set.
  138. """
  139. assert isinstance(
  140. pretrained_model, (str, type(None))
  141. ), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \
  142. indicating that no pretrained model to be used."
  143. if pretrained_model is None:
  144. self.update({"Global.pretrained_model", None})
  145. else:
  146. if pretrained_model.lower() == "default":
  147. self.update({"Global.pretrained_model", None})
  148. else:
  149. if not pretrained_model.startswith(("http://", "https://")):
  150. pretrained_model = abspath(pretrained_model)
  151. self.update({"Global.pretrained_model": pretrained_model})
  152. def _update_slim_config(self, slim_config_path: str):
  153. """update slim settings
  154. Args:
  155. slim_config_path (str): the path to slim config yaml file.
  156. """
  157. slim_config = yaml.load(open(slim_config_path, "rb"), Loader=yaml.Loader)[
  158. "Slim"
  159. ]
  160. self.update({"Slim": slim_config})
  161. def _update_amp(self, amp: Union[None, str]):
  162. """update AMP settings
  163. Args:
  164. amp (None | str): the AMP settings.
  165. Raises:
  166. ValueError: AMP setting `amp` error, missing field `AMP`.
  167. """
  168. if amp is None or amp == "OFF":
  169. if "AMP" in self.dict:
  170. self._dict.pop("AMP")
  171. else:
  172. if "AMP" not in self.dict:
  173. raise ValueError("Config must have AMP information.")
  174. _cfg = {"AMP.use_amp": True, "AMP.level": amp}
  175. self.update(_cfg)
  176. def update_num_workers(self, num_workers: int):
  177. """update workers number of train and eval dataloader
  178. Args:
  179. num_workers (int): the value of train and eval dataloader workers number to set.
  180. """
  181. _cfg = {
  182. "DATASET.num_workers": num_workers,
  183. }
  184. self.update(_cfg)
  185. def update_shared_memory(self, shared_memeory: bool):
  186. """update shared memory setting of train and eval dataloader
  187. Args:
  188. shared_memeory (bool): whether or not to use shared memory
  189. """
  190. assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
  191. _cfg = {
  192. "DataLoader.Train.loader.use_shared_memory": shared_memeory,
  193. "DataLoader.Eval.loader.use_shared_memory": shared_memeory,
  194. }
  195. self.update(_cfg)
  196. def update_shuffle(self, shuffle: bool):
  197. """update shuffle setting of train and eval dataloader
  198. Args:
  199. shuffle (bool): whether or not to shuffle the data
  200. """
  201. assert isinstance(shuffle, bool), "shuffle should be a bool"
  202. _cfg = [
  203. f"DataLoader.Train.loader.shuffle={shuffle}",
  204. f"DataLoader.Eval.loader.shuffle={shuffle}",
  205. ]
  206. self.update(_cfg)
  207. def update_dali(self, dali: bool):
  208. """enable DALI setting of train and eval dataloader
  209. Args:
  210. dali (bool): whether or not to use DALI
  211. """
  212. assert isinstance(dali, bool), "dali should be a bool"
  213. _cfg = [
  214. f"Global.use_dali={dali}",
  215. f"Global.use_dali={dali}",
  216. ]
  217. self.update(_cfg)
  218. def update_seed(self, seed: int):
  219. """update seed
  220. Args:
  221. seed (int): the random seed value to set
  222. """
  223. _cfg = {"Global.seed": seed}
  224. self.update(_cfg)
  225. def update_device(self, device: str):
  226. """update device setting
  227. Args:
  228. device (str): the running device to set
  229. """
  230. device = device.split(":")[0]
  231. _cfg = {"Global.device": device}
  232. self.update(_cfg)
  233. def update_label_dict_path(self, dict_path: str):
  234. """update label dict file path
  235. Args:
  236. dict_path (str): the path of label dict file to set
  237. """
  238. _cfg = {
  239. "PostProcess.Topk.class_id_map_file": {abspath(dict_path)},
  240. }
  241. self.update(_cfg)
  242. def _update_to_static(self, dy2st: bool):
  243. """update config to set dynamic to static mode
  244. Args:
  245. dy2st (bool): whether or not to use the dynamic to static mode.
  246. """
  247. self.update({"to_static": dy2st})
  248. def _update_use_vdl(self, use_vdl: bool):
  249. """update config to set VisualDL
  250. Args:
  251. use_vdl (bool): whether or not to use VisualDL.
  252. """
  253. self.update({"Global.use_visuald": use_vdl})
  254. def _update_epochs(self, epochs: int):
  255. """update epochs setting
  256. Args:
  257. epochs (int): the epochs number value to set
  258. """
  259. self.update({"epochs": epochs})
  260. def _update_checkpoints(self, resume_path: Union[None, str]):
  261. """update checkpoint setting
  262. Args:
  263. resume_path (None | str): the resume training setting. if is `None`, train from scratch, otherwise,
  264. train from checkpoint file that path is `.pdparams` file.
  265. """
  266. if resume_path is not None:
  267. resume_path = resume_path.replace(".pdparams", "")
  268. self.update({"Global.checkpoints": resume_path})
  269. def _update_output_dir(self, save_dir: str):
  270. """update output directory
  271. Args:
  272. save_dir (str): the path to save outputs.
  273. """
  274. self.update({"output_dir": abspath(save_dir)})
  275. self.update({"METRIC.result_path": abspath(save_dir)})
  276. def update_log_interval(self, log_interval: int):
  277. """update log interval(steps)
  278. Args:
  279. log_interval (int): the log interval value to set.
  280. """
  281. self.update({"log_interval": log_interval})
  282. def update_eval_interval(self, eval_interval: int):
  283. """update eval interval(epochs)
  284. Args:
  285. eval_interval (int): the eval interval value to set.
  286. """
  287. self.update({"val_interval": eval_interval})
  288. def update_save_interval(self, save_interval: int):
  289. """update eval interval(epochs)
  290. Args:
  291. save_interval (int): the save interval value to set.
  292. """
  293. self.update({"save_interval": save_interval})
  294. def update_log_ranks(self, device):
  295. """update log ranks
  296. Args:
  297. device (str): the running device to set
  298. """
  299. log_ranks = device.split(":")[1]
  300. self.update({"Global.log_ranks": log_ranks})
  301. def update_print_mem_info(self, print_mem_info: bool):
  302. """setting print memory info"""
  303. assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
  304. self.update({"Global.print_mem_info": print_mem_info})
  305. def _update_predict_video(self, infer_video: str, infer_list: str = None):
  306. """update video to be predicted
  307. Args:
  308. infer_video (str): the path to image that to be predicted.
  309. infer_list (str, optional): the path to file that videos. Defaults to None.
  310. """
  311. if infer_list:
  312. self.update({"Infer.infer_list": infer_list})
  313. self.update({"Infer.infer_videos": infer_video})
  314. def _update_save_inference_dir(self, save_inference_dir: str):
  315. """update directory path to save inference model files
  316. Args:
  317. save_inference_dir (str): the directory path to set.
  318. """
  319. self.update({"Global.save_inference_dir": abspath(save_inference_dir)})
  320. def _update_inference_model_dir(self, model_dir: str):
  321. """update inference model directory
  322. Args:
  323. model_dir (str): the directory path of inference model fils that used to predict.
  324. """
  325. self.update({"Global.inference_model_dir": abspath(model_dir)})
  326. def _update_infer_video(self, infer_video: str):
  327. """update path of image that would be predict
  328. Args:
  329. infer_video (str): the image path.
  330. """
  331. self.update({"Global.infer_videos": infer_video})
  332. def _update_infer_device(self, device: str):
  333. """update the device used in predicting
  334. Args:
  335. device (str): the running device setting
  336. """
  337. self.update({"Global.use_gpu": device.split(":")[0] == "gpu"})
  338. def _update_enable_mkldnn(self, enable_mkldnn: bool):
  339. """update whether to enable MKLDNN
  340. Args:
  341. enable_mkldnn (bool): `True` is enable, otherwise is disable.
  342. """
  343. self.update({"Global.enable_mkldnn": enable_mkldnn})
  344. def _update_infer_video_shape(self, img_shape: str):
  345. """update image cropping shape in the preprocessing
  346. Args:
  347. img_shape (str): the shape of cropping in the preprocessing,
  348. i.e. `PreProcess.transform_ops.1.CropImage.size`.
  349. """
  350. self.update({"INFERENCE.target_size": img_shape})
  351. def _update_save_predict_result(self, save_dir: str):
  352. """update directory that save predicting output
  353. Args:
  354. save_dir (str): the dicrectory path that save predicting output.
  355. """
  356. self.update({"Infer.save_dir": save_dir})
  357. def get_epochs_iters(self) -> int:
  358. """get epochs
  359. Returns:
  360. int: the epochs value, i.e., `Global.epochs` in config.
  361. """
  362. return self.dict["Global"]["epochs"]
  363. def get_log_interval(self) -> int:
  364. """get log interval(steps)
  365. Returns:
  366. int: the log interval value, i.e., `Global.print_batch_step` in config.
  367. """
  368. return self.dict["Global"]["print_batch_step"]
  369. def get_eval_interval(self) -> int:
  370. """get eval interval(epochs)
  371. Returns:
  372. int: the eval interval value, i.e., `Global.eval_interval` in config.
  373. """
  374. return self.dict["Global"]["eval_interval"]
  375. def get_save_interval(self) -> int:
  376. """get save interval(epochs)
  377. Returns:
  378. int: the save interval value, i.e., `Global.save_interval` in config.
  379. """
  380. return self.dict["Global"]["save_interval"]
  381. def get_learning_rate(self) -> float:
  382. """get learning rate
  383. Returns:
  384. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  385. """
  386. return self.dict["Optimizer"]["lr"]["learning_rate"]
  387. def get_warmup_epochs(self) -> int:
  388. """get warmup epochs
  389. Returns:
  390. int: the warmup epochs value, i.e., `Optimizer.lr.warmup_epochs` in config.
  391. """
  392. return self.dict["Optimizer"]["lr"]["warmup_epoch"]
  393. def get_label_dict_path(self) -> str:
  394. """get label dict file path
  395. Returns:
  396. str: the label dict file path, i.e., `PostProcess.Topk.class_id_map_file` in config.
  397. """
  398. return self.dict["PostProcess"]["Topk"]["class_id_map_file"]
  399. def get_batch_size(self, mode="train") -> int:
  400. """get batch size
  401. Args:
  402. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  403. Defaults to 'train'.
  404. Returns:
  405. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  406. """
  407. return self.dict["DataLoader"]["Train"]["sampler"]["batch_size"]
  408. def get_qat_epochs_iters(self) -> int:
  409. """get qat epochs
  410. Returns:
  411. int: the epochs value.
  412. """
  413. return self.get_epochs_iters()
  414. def get_qat_learning_rate(self) -> float:
  415. """get qat learning rate
  416. Returns:
  417. float: the learning rate value.
  418. """
  419. return self.get_learning_rate()
  420. def _get_arch_name(self) -> str:
  421. """get architecture name of model
  422. Returns:
  423. str: the model arch name, i.e., `Arch.name` in config.
  424. """
  425. return self.dict["Arch"]["name"]
  426. def _get_dataset_root(self) -> str:
  427. """get root directory of dataset, i.e. `DataLoader.Train.dataset.video_root`
  428. Returns:
  429. str: the root directory of dataset
  430. """
  431. return self.dict["DataLoader"]["Train"]["dataset"]["video_root"]
  432. def get_train_save_dir(self) -> str:
  433. """get the directory to save output
  434. Returns:
  435. str: the directory to save output
  436. """
  437. return self["output_dir"]