config.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Union
  16. import yaml
  17. from ....utils.misc import abspath
  18. from ...base import BaseConfig
  19. from ..config_utils import load_config, merge_config
  20. class TextRecConfig(BaseConfig):
  21. """Text Recognition Config"""
  22. def update(self, dict_like_obj: list):
  23. """update self
  24. Args:
  25. dict_like_obj (dict): dict of pairs(key0.key1.idx.key2=value).
  26. """
  27. dict_ = merge_config(self.dict, dict_like_obj)
  28. self.reset_from_dict(dict_)
  29. def load(self, config_file_path: str):
  30. """load config from yaml file
  31. Args:
  32. config_file_path (str): the path of yaml file.
  33. Raises:
  34. TypeError: the content of yaml file `config_file_path` error.
  35. """
  36. dict_ = load_config(config_file_path)
  37. if not isinstance(dict_, dict):
  38. raise TypeError
  39. self.reset_from_dict(dict_)
  40. def dump(self, config_file_path: str):
  41. """dump self to yaml file
  42. Args:
  43. config_file_path (str): the path to save self as yaml file.
  44. """
  45. with open(config_file_path, "w", encoding="utf-8") as f:
  46. yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)
  47. def update_dataset(
  48. self,
  49. dataset_path: str,
  50. dataset_type: str = None,
  51. *,
  52. train_list_path: str = None,
  53. ):
  54. """update dataset settings
  55. Args:
  56. dataset_path (str): the root path of dataset.
  57. dataset_type (str, optional): dataset type. Defaults to None.
  58. train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
  59. Raises:
  60. ValueError: the dataset_type error.
  61. """
  62. dataset_path = abspath(dataset_path)
  63. if dataset_type is None:
  64. dataset_type = "TextRecDataset"
  65. if train_list_path:
  66. train_list_path = f"{train_list_path}"
  67. else:
  68. train_list_path = os.path.join(dataset_path, "train.txt")
  69. if (dataset_type == "TextRecDataset") or (dataset_type == "MSTextRecDataset"):
  70. _cfg = {
  71. "Train.dataset.name": dataset_type,
  72. "Train.dataset.data_dir": dataset_path,
  73. "Train.dataset.label_file_list": [train_list_path],
  74. "Eval.dataset.name": "TextRecDataset",
  75. "Eval.dataset.data_dir": dataset_path,
  76. "Eval.dataset.label_file_list": [os.path.join(dataset_path, "val.txt")],
  77. "Global.character_dict_path": os.path.join(dataset_path, "dict.txt"),
  78. }
  79. self.update(_cfg)
  80. elif dataset_type == "SimpleDataSet":
  81. _cfg = {
  82. "Train.dataset.name": dataset_type,
  83. "Train.dataset.data_dir": dataset_path,
  84. "Train.dataset.label_file_list": [train_list_path],
  85. "Eval.dataset.name": "SimpleDataSet",
  86. "Eval.dataset.data_dir": dataset_path,
  87. "Eval.dataset.label_file_list": [os.path.join(dataset_path, "val.txt")],
  88. "Global.character_dict_path": os.path.join(dataset_path, "dict.txt"),
  89. }
  90. self.update(_cfg)
  91. elif dataset_type == "LaTeXOCRDataSet":
  92. _cfg = {
  93. "Train.dataset.name": dataset_type,
  94. "Train.dataset.data_dir": dataset_path,
  95. "Train.dataset.data": os.path.join(dataset_path, "latexocr_train.pkl"),
  96. "Train.dataset.label_file_list": [train_list_path],
  97. "Eval.dataset.name": dataset_type,
  98. "Eval.dataset.data_dir": dataset_path,
  99. "Eval.dataset.data": os.path.join(dataset_path, "latexocr_val.pkl"),
  100. "Eval.dataset.label_file_list": [os.path.join(dataset_path, "val.txt")],
  101. "Global.character_dict_path": os.path.join(dataset_path, "dict.txt"),
  102. }
  103. self.update(_cfg)
  104. else:
  105. raise ValueError(f"{repr(dataset_type)} is not supported.")
  106. def update_dataset_by_list(self, label_file_list, ratio_list):
  107. _cfg = {
  108. "Train.dataset.name": "MSTextRecDataset",
  109. "Train.dataset.label_file_list": label_file_list,
  110. "Train.dataset.ratio_list": ratio_list,
  111. }
  112. self.update(_cfg)
  113. def update_batch_size(self, batch_size: int, mode: str = "train"):
  114. """update batch size setting
  115. Args:
  116. batch_size (int): the batch size number to set.
  117. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  118. Defaults to 'train'.
  119. Raises:
  120. ValueError: mode error.
  121. """
  122. _cfg = {
  123. "Train.loader.batch_size_per_card": batch_size,
  124. "Eval.loader.batch_size_per_card": batch_size,
  125. }
  126. if "sampler" in self.dict["Train"]:
  127. _cfg["Train.sampler.first_bs"] = batch_size
  128. self.update(_cfg)
  129. def update_batch_size_pair(
  130. self, batch_size_train: int, batch_size_val: int, mode: str = "train"
  131. ):
  132. """update batch size setting
  133. Args:
  134. batch_size (int): the batch size number to set.
  135. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
  136. Defaults to 'train'.
  137. Raises:
  138. ValueError: mode error.
  139. """
  140. _cfg = {
  141. "Train.dataset.batch_size_per_pair": batch_size_train,
  142. "Eval.dataset.batch_size_per_pair": batch_size_val,
  143. }
  144. # if "sampler" in self.dict['Train']:
  145. # _cfg['Train.sampler.first_bs'] = 1
  146. self.update(_cfg)
  147. def update_learning_rate(self, learning_rate: float):
  148. """update learning rate
  149. Args:
  150. learning_rate (float): the learning rate value to set.
  151. """
  152. _cfg = {
  153. "Optimizer.lr.learning_rate": learning_rate,
  154. }
  155. self.update(_cfg)
  156. def update_label_dict_path(self, dict_path: str):
  157. """update label dict file path
  158. Args:
  159. dict_path (str): the path to label dict file.
  160. """
  161. _cfg = {
  162. "Global.character_dict_path": abspath(dict_path),
  163. }
  164. self.update(_cfg)
  165. def update_warmup_epochs(self, warmup_epochs: int):
  166. """update warmup epochs
  167. Args:
  168. warmup_epochs (int): the warmup epochs value to set.
  169. """
  170. _cfg = {"Optimizer.lr.warmup_epoch": warmup_epochs}
  171. self.update(_cfg)
  172. def update_pretrained_weights(self, pretrained_model: str):
  173. """update pretrained weight path
  174. Args:
  175. pretrained_model (str): the local path or url of pretrained weight file to set.
  176. """
  177. if pretrained_model:
  178. if not pretrained_model.startswith(
  179. "http://"
  180. ) and not pretrained_model.startswith("https://"):
  181. pretrained_model = abspath(pretrained_model)
  182. self.update(
  183. {"Global.pretrained_model": pretrained_model, "Global.checkpoints": ""}
  184. )
  185. # TODO
  186. def update_class_path(self, class_path: str):
  187. """_summary_
  188. Args:
  189. class_path (str): _description_
  190. """
  191. self.update(
  192. {
  193. "PostProcess.class_path": class_path,
  194. }
  195. )
  196. def _update_amp(self, amp: Union[None, str]):
  197. """update AMP settings
  198. Args:
  199. amp (None | str): the AMP level if it is not None or `OFF`.
  200. """
  201. _cfg = {
  202. "Global.use_amp": amp is not None and amp != "OFF",
  203. "Global.amp_level": amp,
  204. }
  205. self.update(_cfg)
  206. def update_device(self, device: str):
  207. """update device setting
  208. Args:
  209. device (str): the running device to set
  210. """
  211. device = device.split(":")[0]
  212. default_cfg = {
  213. "Global.use_gpu": False,
  214. "Global.use_xpu": False,
  215. "Global.use_npu": False,
  216. "Global.use_mlu": False,
  217. "Global.use_gcu": False,
  218. "Global.use_iluvatar_gpu": False,
  219. }
  220. device_cfg = {
  221. "cpu": {},
  222. "gpu": {"Global.use_gpu": True},
  223. "xpu": {"Global.use_xpu": True},
  224. "mlu": {"Global.use_mlu": True},
  225. "npu": {"Global.use_npu": True},
  226. "gcu": {"Global.use_gcu": True},
  227. "iluvatar_gpu": {"Global.use_iluvatar_gpu": True},
  228. }
  229. default_cfg.update(device_cfg[device])
  230. self.update(default_cfg)
  231. def _update_epochs(self, epochs: int):
  232. """update epochs setting
  233. Args:
  234. epochs (int): the epochs number value to set
  235. """
  236. self.update({"Global.epoch_num": epochs})
  237. def _update_checkpoints(self, resume_path: Union[None, str]):
  238. """update checkpoint setting
  239. Args:
  240. resume_path (None | str): the resume training setting. if is `None`, train from scratch, otherwise,
  241. train from checkpoint file that path is `.pdparams` file.
  242. """
  243. self.update(
  244. {"Global.checkpoints": abspath(resume_path), "Global.pretrained_model": ""}
  245. )
  246. def _update_to_static(self, dy2st: bool):
  247. """update config to set dynamic to static mode
  248. Args:
  249. dy2st (bool): whether or not to use the dynamic to static mode.
  250. """
  251. self.update({"Global.to_static": dy2st})
  252. def _update_use_vdl(self, use_vdl: bool):
  253. """update config to set VisualDL
  254. Args:
  255. use_vdl (bool): whether or not to use VisualDL.
  256. """
  257. self.update({"Global.use_visualdl": use_vdl})
  258. def _update_output_dir(self, save_dir: str):
  259. """update output directory
  260. Args:
  261. save_dir (str): the path to save output.
  262. """
  263. self.update({"Global.save_model_dir": abspath(save_dir)})
  264. # TODO
  265. # def _update_log_interval(self, log_interval):
  266. # self.update({'Global.print_batch_step': log_interval})
  267. def update_log_interval(self, log_interval: int):
  268. """update log interval(by steps)
  269. Args:
  270. log_interval (int): the log interval value to set.
  271. """
  272. self.update({"Global.print_batch_step": log_interval})
  273. # def _update_eval_interval(self, eval_start_step, eval_interval):
  274. # self.update({
  275. # 'Global.eval_batch_step': [eval_start_step, eval_interval]
  276. # })
  277. def update_log_ranks(self, device):
  278. """update log ranks
  279. Args:
  280. device (str): the running device to set
  281. """
  282. log_ranks = device.split(":")[1]
  283. self.update({"Global.log_ranks": log_ranks})
  284. def update_print_mem_info(self, print_mem_info: bool):
  285. """setting print memory info"""
  286. assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
  287. self.update({"Global.print_mem_info": f"{print_mem_info}"})
  288. def update_shared_memory(self, shared_memeory: bool):
  289. """update shared memory setting of train and eval dataloader
  290. Args:
  291. shared_memeory (bool): whether or not to use shared memory
  292. """
  293. assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
  294. _cfg = {
  295. "Train.loader.use_shared_memory": f"{shared_memeory}",
  296. "Train.loader.use_shared_memory": f"{shared_memeory}",
  297. }
  298. self.update(_cfg)
  299. def update_shuffle(self, shuffle: bool):
  300. """update shuffle setting of train and eval dataloader
  301. Args:
  302. shuffle (bool): whether or not to shuffle the data
  303. """
  304. assert isinstance(shuffle, bool), "shuffle should be a bool"
  305. _cfg = {
  306. f"Train.loader.shuffle": shuffle,
  307. f"Train.loader.shuffle": shuffle,
  308. }
  309. self.update(_cfg)
  310. def update_cal_metrics(self, cal_metrics: bool):
  311. """update calculate metrics setting
  312. Args:
  313. cal_metrics (bool): whether or not to calculate metrics during train
  314. """
  315. assert isinstance(cal_metrics, bool), "cal_metrics should be a bool"
  316. self.update({"Global.cal_metric_during_train": cal_metrics})
  317. def update_seed(self, seed: int):
  318. """update seed
  319. Args:
  320. seed (int): the random seed value to set
  321. """
  322. assert isinstance(seed, int), "seed should be an int"
  323. self.update({"Global.seed": seed})
  324. def _update_eval_interval_by_epoch(self, eval_interval):
  325. """update eval interval(by epoch)
  326. Args:
  327. eval_interval (int): the eval interval value to set.
  328. """
  329. self.update({"Global.eval_batch_epoch": eval_interval})
  330. def update_eval_interval(self, eval_interval: int, eval_start_step: int = 0):
  331. """update eval interval(by steps)
  332. Args:
  333. eval_interval (int): the eval interval value to set.
  334. eval_start_step (int, optional): step number from which the evaluation is enabled. Defaults to 0.
  335. """
  336. self._update_eval_interval(eval_start_step, eval_interval)
  337. def _update_save_interval(self, save_interval: int):
  338. """update save interval(by steps)
  339. Args:
  340. save_interval (int): the save interval value to set.
  341. """
  342. self.update({"Global.save_epoch_step": save_interval})
  343. def update_save_interval(self, save_interval: int):
  344. """update save interval(by steps)
  345. Args:
  346. save_interval (int): the save interval value to set.
  347. """
  348. self._update_save_interval(save_interval)
  349. def _update_infer_img(self, infer_img: str, infer_list: str = None):
  350. """update image list to be inferred
  351. Args:
  352. infer_img (str): path to the image file to be inferred. It would be ignored when `infer_list` is be set.
  353. infer_list (str, optional): path to the .txt file containing the paths to image to be inferred.
  354. Defaults to None.
  355. """
  356. if infer_list:
  357. self.update({"Global.infer_list": infer_list})
  358. self.update({"Global.infer_img": infer_img})
  359. def _update_save_inference_dir(self, save_inference_dir: str):
  360. """update the directory saving infer outputs
  361. Args:
  362. save_inference_dir (str): the directory saving infer outputs.
  363. """
  364. self.update({"Global.save_inference_dir": abspath(save_inference_dir)})
  365. def _update_save_res_path(self, save_res_path: str):
  366. """update the .txt file path saving OCR model inference result
  367. Args:
  368. save_res_path (str): the .txt file path saving OCR model inference result.
  369. """
  370. self.update({"Global.save_res_path": abspath(save_res_path)})
  371. def update_num_workers(
  372. self, num_workers: int, modes: Union[str, list] = ["train", "eval"]
  373. ):
  374. """update workers number of train or eval dataloader
  375. Args:
  376. num_workers (int): the value of train and eval dataloader workers number to set.
  377. modes (str | [list], optional): mode. Defaults to ['train', 'eval'].
  378. Raises:
  379. ValueError: mode error. The `mode` should be `train`, `eval` or `['train', 'eval']`.
  380. """
  381. if not isinstance(modes, list):
  382. modes = [modes]
  383. for mode in modes:
  384. if not mode in ("train", "eval"):
  385. raise ValueError
  386. if mode == "train":
  387. self["Train"]["loader"]["num_workers"] = num_workers
  388. else:
  389. self["Eval"]["loader"]["num_workers"] = num_workers
  390. def _get_model_type(self) -> str:
  391. """get model type
  392. Returns:
  393. str: model type, i.e. `Architecture.algorithm` or `Architecture.Models.Student.algorithm`.
  394. """
  395. if "Models" in self.dict["Architecture"]:
  396. return self.dict["Architecture"]["Models"]["Student"]["algorithm"]
  397. return self.dict["Architecture"]["algorithm"]
  398. def get_epochs_iters(self) -> int:
  399. """get epochs
  400. Returns:
  401. int: the epochs value, i.e., `Global.epochs` in config.
  402. """
  403. return self.dict["Global"]["epoch_num"]
  404. def get_learning_rate(self) -> float:
  405. """get learning rate
  406. Returns:
  407. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  408. """
  409. return self.dict["Optimizer"]["lr"]["learning_rate"]
  410. def get_batch_size(self, mode="train") -> int:
  411. """get batch size
  412. Args:
  413. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  414. Defaults to 'train'.
  415. Returns:
  416. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  417. """
  418. return self.dict["Train"]["loader"]["batch_size_per_card"]
  419. def get_qat_epochs_iters(self) -> int:
  420. """get qat epochs
  421. Returns:
  422. int: the epochs value.
  423. """
  424. return self.get_epochs_iters()
  425. def get_qat_learning_rate(self) -> float:
  426. """get qat learning rate
  427. Returns:
  428. float: the learning rate value.
  429. """
  430. return self.get_learning_rate()
  431. def get_label_dict_path(self) -> str:
  432. """get label dict file path
  433. Returns:
  434. str: the label dict file path, i.e., `Global.character_dict_path` in config.
  435. """
  436. return self.dict["Global"]["character_dict_path"]
  437. def _get_dataset_root(self) -> str:
  438. """get root directory of dataset, i.e. `DataLoader.Train.dataset.data_dir`
  439. Returns:
  440. str: the root directory of dataset
  441. """
  442. return self.dict["Train"]["dataset"]["data_dir"]
  443. def _get_infer_shape(self) -> str:
  444. """get resize scale of ResizeImg operation in the evaluation
  445. Returns:
  446. str: resize scale, i.e. `Eval.dataset.transforms.ResizeImg.image_shape`
  447. """
  448. size = None
  449. transforms = self.dict["Eval"]["dataset"]["transforms"]
  450. for op in transforms:
  451. op_name = list(op)[0]
  452. if "ResizeImg" in op_name:
  453. size = op[op_name]["image_shape"]
  454. return ",".join([str(x) for x in size])
  455. def get_train_save_dir(self) -> str:
  456. """get the directory to save output
  457. Returns:
  458. str: the directory to save output
  459. """
  460. return self["Global"]["save_model_dir"]
  461. def get_predict_save_dir(self) -> str:
  462. """get the directory to save output in predicting
  463. Returns:
  464. str: the directory to save output
  465. """
  466. return os.path.dirname(self["Global"]["save_res_path"])