config.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import yaml
  16. from typing import Union
  17. from ...base import BaseConfig
  18. from ....utils.misc import abspath
  19. from ..config_utils import load_config, merge_config
  20. class TextRecConfig(BaseConfig):
  21. """ Text Recognition Config """
  def update(self, dict_like_obj: dict):
      """Merge a set of dotted-key overrides into this config.

      Args:
          dict_like_obj (dict): mapping of dotted keys to values, e.g.
              ``{'key0.key1.idx.key2': value}``.
      """
      # The merge result replaces the current content via reset_from_dict.
      merged = merge_config(self.dict, dict_like_obj)
      self.reset_from_dict(merged)
  def load(self, config_file_path: str):
      """Load the config from a yaml file.

      Args:
          config_file_path (str): the path of yaml file.

      Raises:
          TypeError: if the content loaded from `config_file_path` is not a dict.
      """
      loaded = load_config(config_file_path)
      if not isinstance(loaded, dict):
          # Say what was wrong instead of raising a bare TypeError.
          raise TypeError(
              f"The content of {config_file_path!r} should be a dict, "
              f"but got {type(loaded).__name__}.")
      self.reset_from_dict(loaded)
  def dump(self, config_file_path: str):
      """Write this config to a yaml file.

      Args:
          config_file_path (str): the path of the yaml file to write.
      """
      with open(config_file_path, 'w', encoding='utf-8') as yaml_file:
          # Block style, with keys kept in their current order.
          yaml.dump(
              self.dict,
              yaml_file,
              default_flow_style=False,
              sort_keys=False)
  def update_dataset(
          self,
          dataset_path: str,
          dataset_type: str=None,
          *,
          train_list_path: str=None, ):
      """Point the train/eval dataset settings at `dataset_path`.

      Args:
          dataset_path (str): the root path of dataset.
          dataset_type (str, optional): one of 'TextRecDataset', 'MSTextRecDataset'
              or 'LaTeXOCRDataSet'. Defaults to None, which means 'TextRecDataset'.
          train_list_path (str, optional): the path of the train annotation file.
              Defaults to None, which means `<dataset_path>/train.txt`.

      Raises:
          ValueError: if `dataset_type` is not one of the supported names.
      """
      dataset_path = abspath(dataset_path)
      dataset_type = 'TextRecDataset' if dataset_type is None else dataset_type
      if train_list_path:
          train_list_path = f"{train_list_path}"
      else:
          train_list_path = os.path.join(dataset_path, 'train.txt')

      if dataset_type in ('TextRecDataset', "MSTextRecDataset"):
          # The eval dataset name is always 'TextRecDataset' in this branch.
          overrides = {
              'Train.dataset.name': dataset_type,
              'Train.dataset.data_dir': dataset_path,
              'Train.dataset.label_file_list': [train_list_path],
              'Eval.dataset.name': 'TextRecDataset',
              'Eval.dataset.data_dir': dataset_path,
              'Eval.dataset.label_file_list':
              [os.path.join(dataset_path, 'val.txt')],
              'Global.character_dict_path':
              os.path.join(dataset_path, 'dict.txt'),
          }
      elif dataset_type == "LaTeXOCRDataSet":
          # This type additionally sets the `data` entries to .pkl files.
          overrides = {
              'Train.dataset.name': dataset_type,
              'Train.dataset.data_dir': dataset_path,
              'Train.dataset.data':
              os.path.join(dataset_path, "latexocr_train.pkl"),
              'Train.dataset.label_file_list': [train_list_path],
              'Eval.dataset.name': dataset_type,
              'Eval.dataset.data_dir': dataset_path,
              'Eval.dataset.data':
              os.path.join(dataset_path, "latexocr_val.pkl"),
              'Eval.dataset.label_file_list':
              [os.path.join(dataset_path, 'val.txt')],
              'Global.character_dict_path':
              os.path.join(dataset_path, 'dict.txt'),
          }
      else:
          raise ValueError(f"{dataset_type!r} is not supported.")
      self.update(overrides)
  def update_batch_size(self, batch_size: int, mode: str='train'):
      """Set the per-card batch size of the train and eval loaders.

      If the `Train` section contains a `sampler` entry, `Train.sampler.first_bs`
      is set to the same value.

      Args:
          batch_size (int): the batch size number to set.
          mode (str, optional): not read by this implementation; both loaders are
              always updated. Defaults to 'train'.
      """
      overrides = dict.fromkeys(
          ('Train.loader.batch_size_per_card',
           'Eval.loader.batch_size_per_card'),
          batch_size)
      if "sampler" in self.dict['Train']:
          overrides['Train.sampler.first_bs'] = batch_size
      self.update(overrides)
  def update_batch_size_pair(self,
                             batch_size_train: int,
                             batch_size_val: int,
                             mode: str='train'):
      """Set `batch_size_per_pair` of the train and eval datasets.

      Args:
          batch_size_train (int): value for `Train.dataset.batch_size_per_pair`.
          batch_size_val (int): value for `Eval.dataset.batch_size_per_pair`.
          mode (str, optional): not read by this implementation. Defaults to 'train'.
      """
      overrides = {}
      overrides['Train.dataset.batch_size_per_pair'] = batch_size_train
      overrides['Eval.dataset.batch_size_per_pair'] = batch_size_val
      self.update(overrides)
  def update_learning_rate(self, learning_rate: float):
      """Set `Optimizer.lr.learning_rate`.

      Args:
          learning_rate (float): the learning rate value to set.
      """
      self.update({'Optimizer.lr.learning_rate': learning_rate})
  def update_label_dict_path(self, dict_path: str):
      """Set `Global.character_dict_path` to `dict_path` passed through `abspath`.

      Args:
          dict_path (str): the path to label dict file.
      """
      resolved = abspath(dict_path)
      self.update({'Global.character_dict_path': resolved})
  def update_warmup_epochs(self, warmup_epochs: int):
      """Set `Optimizer.lr.warmup_epoch`.

      Args:
          warmup_epochs (int): the warmup epochs value to set.
      """
      overrides = dict([('Optimizer.lr.warmup_epoch', warmup_epochs)])
      self.update(overrides)
  def update_pretrained_weights(self, pretrained_model: str):
      """Set `Global.pretrained_model` and clear `Global.checkpoints`.

      A non-empty value that does not start with 'http://' or 'https://' is
      passed through `abspath` first; URLs and empty values are stored as given.

      Args:
          pretrained_model (str): the local path or url of pretrained weight file to set.
      """
      is_local_path = bool(pretrained_model) and not pretrained_model.startswith(
          ('http://', 'https://'))
      if is_local_path:
          pretrained_model = abspath(pretrained_model)
      self.update({
          'Global.pretrained_model': pretrained_model,
          'Global.checkpoints': '',
      })
  170. # TODO
  def update_class_path(self, class_path: str):
      """Set `PostProcess.class_path`.

      Args:
          class_path (str): the value to store under `PostProcess.class_path`.
      """
      overrides = {'PostProcess.class_path': class_path}
      self.update(overrides)
  177. def _update_amp(self, amp: Union[None, str]):
  178. """update AMP settings
  179. Args:
  180. amp (None | str): the AMP level if it is not None or `OFF`.
  181. """
  182. _cfg = {
  183. 'Global.use_amp': amp is not None and amp != 'OFF',
  184. 'Global.amp_level': amp,
  185. }
  186. self.update(_cfg)
  def update_device(self, device: str):
      """Set the `Global.use_*` device flags from a device string.

      Only the text before the first ':' is used. All four flags (gpu, xpu, npu,
      mlu) are written; the one matching the device is True, the rest False.
      'cpu' leaves every flag False.

      Args:
          device (str): the running device to set, e.g. 'gpu:0' or 'cpu'.

      Raises:
          KeyError: if the device kind is not one of cpu, gpu, xpu, mlu, npu.
      """
      kind = device.split(':')[0]
      flag_by_kind = {
          'cpu': None,
          'gpu': 'Global.use_gpu',
          'xpu': 'Global.use_xpu',
          'mlu': 'Global.use_mlu',
          'npu': 'Global.use_npu',
      }
      enabled_flag = flag_by_kind[kind]
      all_flags = ('Global.use_gpu', 'Global.use_xpu', 'Global.use_npu',
                   'Global.use_mlu')
      self.update({flag: flag == enabled_flag for flag in all_flags})
  216. def _update_epochs(self, epochs: int):
  217. """update epochs setting
  218. Args:
  219. epochs (int): the epochs number value to set
  220. """
  221. self.update({'Global.epoch_num': epochs})
  def _update_checkpoints(self, resume_path: Union[None, str]):
      """Set `Global.checkpoints` and clear `Global.pretrained_model`.

      Args:
          resume_path (None | str): the checkpoint to resume from; it is passed
              through `abspath` before being stored.
      """
      # NOTE(review): a None `resume_path` is forwarded to `abspath` as is;
      # confirm that helper accepts None.
      checkpoint = abspath(resume_path)
      self.update({
          'Global.checkpoints': checkpoint,
          'Global.pretrained_model': '',
      })
  232. def _update_to_static(self, dy2st: bool):
  233. """update config to set dynamic to static mode
  234. Args:
  235. dy2st (bool): whether or not to use the dynamic to static mode.
  236. """
  237. self.update({'Global.to_static': dy2st})
  238. def _update_use_vdl(self, use_vdl: bool):
  239. """update config to set VisualDL
  240. Args:
  241. use_vdl (bool): whether or not to use VisualDL.
  242. """
  243. self.update({'Global.use_visualdl': use_vdl})
  def _update_output_dir(self, save_dir: str):
      """Set `Global.save_model_dir` to `save_dir` passed through `abspath`.

      Args:
          save_dir (str): the path to save output.
      """
      output_dir = abspath(save_dir)
      self.update({'Global.save_model_dir': output_dir})
  250. # TODO
  251. # def _update_log_interval(self, log_interval):
  252. # self.update({'Global.print_batch_step': log_interval})
  def update_log_interval(self, log_interval: int):
      """Set `Global.print_batch_step`, the log interval in steps.

      Args:
          log_interval (int): the log interval value to set.
      """
      overrides = {'Global.print_batch_step': log_interval}
      self.update(overrides)
  259. # def _update_eval_interval(self, eval_start_step, eval_interval):
  260. # self.update({
  261. # 'Global.eval_batch_step': [eval_start_step, eval_interval]
  262. # })
  def update_log_ranks(self, device):
      """Set `Global.log_ranks` from the id part of a device string.

      Args:
          device (str): the running device to set, e.g. 'gpu:0,1'. The text
              between the first and second ':' is stored as `Global.log_ranks`.
              A device string without ':' (e.g. 'cpu') carries no ids, so the
              config is left unchanged.
      """
      parts = device.split(':')
      if len(parts) < 2:
          # No id section to read; indexing parts[1] here would raise IndexError.
          return
      self.update({'Global.log_ranks': parts[1]})
  def update_print_mem_info(self, print_mem_info: bool):
      """Set `Global.print_mem_info`.

      The flag is stored as its string form ('True' / 'False'), not as a bool.

      Args:
          print_mem_info (bool): whether or not to print memory info.
      """
      assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
      flag_text = f'{print_mem_info}'
      self.update({'Global.print_mem_info': flag_text})
  def update_shared_memory(self, shared_memeory: bool):
      """Set `use_shared_memory` of the train and eval dataloaders.

      The flag is stored as its string form ('True' / 'False').

      Args:
          shared_memeory (bool): whether or not to use shared memory
      """
      assert isinstance(shared_memeory,
                        bool), "shared_memeory should be a bool"
      flag_text = f'{shared_memeory}'
      # The second entry must target the Eval loader; listing the Train key
      # twice collapses into one entry and leaves Eval unset.
      _cfg = {
          'Train.loader.use_shared_memory': flag_text,
          'Eval.loader.use_shared_memory': flag_text,
      }
      self.update(_cfg)
  def update_shuffle(self, shuffle: bool):
      """Set `shuffle` of the train and eval dataloaders.

      Args:
          shuffle (bool): whether or not to shuffle the data
      """
      assert isinstance(shuffle, bool), "shuffle should be a bool"
      # The second entry must target the Eval loader; listing the Train key
      # twice collapses into one entry and leaves Eval unset.
      _cfg = {
          'Train.loader.shuffle': shuffle,
          'Eval.loader.shuffle': shuffle,
      }
      self.update(_cfg)
  def update_cal_metrics(self, cal_metrics: bool):
      """Set `Global.cal_metric_during_train`.

      Args:
          cal_metrics (bool): whether or not to calculate metrics during train
      """
      assert isinstance(cal_metrics, bool), "cal_metrics should be a bool"
      overrides = {'Global.cal_metric_during_train': cal_metrics}
      self.update(overrides)
  def update_seed(self, seed: int):
      """Set `Global.seed`.

      Args:
          seed (int): the random seed value to set
      """
      assert isinstance(seed, int), "seed should be an int"
      overrides = {'Global.seed': seed}
      self.update(overrides)
  312. def _update_eval_interval_by_epoch(self, eval_interval):
  313. """update eval interval(by epoch)
  314. Args:
  315. eval_interval (int): the eval interval value to set.
  316. """
  317. self.update({'Global.eval_batch_epoch': eval_interval})
  def update_eval_interval(self, eval_interval: int, eval_start_step: int=0):
      """Set the eval interval in steps.

      Args:
          eval_interval (int): the eval interval value to set.
          eval_start_step (int, optional): step number from which the evaluation is enabled. Defaults to 0.
      """
      # `_update_eval_interval` appears in this file only as commented-out code.
      # Use it when the instance provides it (e.g. from a definition not shown
      # here); otherwise apply the same setting that commented-out helper wrote,
      # instead of failing with AttributeError.
      helper = getattr(self, '_update_eval_interval', None)
      if callable(helper):
          helper(eval_start_step, eval_interval)
          return
      self.update({
          'Global.eval_batch_step': [eval_start_step, eval_interval]
      })
  325. def _update_save_interval(self, save_interval: int):
  326. """update save interval(by steps)
  327. Args:
  328. save_interval (int): the save interval value to set.
  329. """
  330. self.update({'Global.save_epoch_step': save_interval})
  def update_save_interval(self, save_interval: int):
      """Public entry point that forwards to `_update_save_interval`.

      Args:
          save_interval (int): the save interval value to set.
      """
      apply_interval = self._update_save_interval
      apply_interval(save_interval)
  337. def _update_infer_img(self, infer_img: str, infer_list: str=None):
  338. """update image list to be infered
  339. Args:
  340. infer_img (str): path to the image file to be infered. It would be ignored when `infer_list` is be set.
  341. infer_list (str, optional): path to the .txt file containing the paths to image to be infered.
  342. Defaults to None.
  343. """
  344. if infer_list:
  345. self.update({'Global.infer_list': infer_list})
  346. self.update({'Global.infer_img': infer_img})
  def _update_save_inference_dir(self, save_inference_dir: str):
      """Set `Global.save_inference_dir` to the given path passed through `abspath`.

      Args:
          save_inference_dir (str): the directory saving infer outputs.
      """
      inference_dir = abspath(save_inference_dir)
      self.update({'Global.save_inference_dir': inference_dir})
  def _update_save_res_path(self, save_res_path: str):
      """Set `Global.save_res_path` to the given path passed through `abspath`.

      Args:
          save_res_path (str): the .txt file path saving OCR model inference result.
      """
      result_path = abspath(save_res_path)
      self.update({'Global.save_res_path': result_path})
  def update_num_workers(self,
                         num_workers: int,
                         modes: Union[str, list]=None):
      """Set the worker count of the train and/or eval dataloader.

      All modes are validated before anything is written, so an invalid mode
      leaves the config untouched.

      Args:
          num_workers (int): the value of train and eval dataloader workers number to set.
          modes (str | [list], optional): 'train', 'eval', or a list of them.
              Defaults to None, which means ['train', 'eval'].

      Raises:
          ValueError: if any mode is not 'train' or 'eval'.
      """
      if modes is None:
          # Fresh list per call; avoids a shared mutable default argument.
          modes = ['train', 'eval']
      elif not isinstance(modes, list):
          modes = [modes]

      invalid = [mode for mode in modes if mode not in ('train', 'eval')]
      if invalid:
          raise ValueError(
              f"Unsupported mode(s) {invalid!r}: `modes` should be 'train', "
              f"'eval' or ['train', 'eval'].")

      for mode in modes:
          section = 'Train' if mode == 'train' else 'Eval'
          self[section]['loader']['num_workers'] = num_workers
  378. def _get_model_type(self) -> str:
  379. """get model type
  380. Returns:
  381. str: model type, i.e. `Architecture.algorithm` or `Architecture.Models.Student.algorithm`.
  382. """
  383. if 'Models' in self.dict['Architecture']:
  384. return self.dict['Architecture']['Models']['Student']['algorithm']
  385. return self.dict['Architecture']['algorithm']
  def get_epochs_iters(self) -> int:
      """Return the configured number of epochs.

      Returns:
          int: the epochs value, i.e., `Global.epoch_num` in config.
      """
      global_cfg = self.dict['Global']
      return global_cfg['epoch_num']
  def get_learning_rate(self) -> float:
      """Return the configured learning rate.

      Returns:
          float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
      """
      lr_cfg = self.dict['Optimizer']['lr']
      return lr_cfg['learning_rate']
  def get_batch_size(self, mode='train') -> int:
      """Return the per-card batch size of the train loader.

      Args:
          mode (str, optional): not read by this implementation; the train value
              is returned regardless. Defaults to 'train'.

      Returns:
          int: `Train.loader.batch_size_per_card` in config.
      """
      train_loader = self.dict['Train']['loader']
      return train_loader['batch_size_per_card']
  def get_qat_epochs_iters(self) -> int:
      """Return the qat epochs, which is whatever `get_epochs_iters` returns.

      Returns:
          int: the epochs value.
      """
      epochs = self.get_epochs_iters()
      return epochs
  def get_qat_learning_rate(self) -> float:
      """Return the qat learning rate, which is whatever `get_learning_rate` returns.

      Returns:
          float: the learning rate value.
      """
      learning_rate = self.get_learning_rate()
      return learning_rate
  def get_label_dict_path(self) -> str:
      """Return the configured label dict file path.

      Returns:
          str: the label dict file path, i.e., `Global.character_dict_path` in config.
      """
      global_cfg = self.dict['Global']
      return global_cfg['character_dict_path']
  425. def _get_dataset_root(self) -> str:
  426. """get root directory of dataset, i.e. `DataLoader.Train.dataset.data_dir`
  427. Returns:
  428. str: the root directory of dataset
  429. """
  430. return self.dict['Train']['dataset']['data_dir']
  431. def _get_infer_shape(self) -> str:
  432. """get resize scale of ResizeImg operation in the evaluation
  433. Returns:
  434. str: resize scale, i.e. `Eval.dataset.transforms.ResizeImg.image_shape`
  435. """
  436. size = None
  437. transforms = self.dict['Eval']['dataset']['transforms']
  438. for op in transforms:
  439. op_name = list(op)[0]
  440. if 'ResizeImg' in op_name:
  441. size = op[op_name]['image_shape']
  442. return ','.join([str(x) for x in size])
  def get_train_save_dir(self) -> str:
      """Return the directory where training output is saved.

      Returns:
          str: the value of `Global.save_model_dir`.
      """
      global_cfg = self['Global']
      return global_cfg['save_model_dir']
  def get_predict_save_dir(self) -> str:
      """Return the directory where prediction output is saved.

      Returns:
          str: the directory part of `Global.save_res_path`.
      """
      result_path = self['Global']['save_res_path']
      return os.path.dirname(result_path)