trainer.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
# !/usr/bin/env python3
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Author: PaddlePaddle Authors
"""
  11. import os
  12. from pathlib import Path
  13. import paddle
  14. from ..base.trainer import BaseTrainer
  15. from ..base.train_deamon import BaseTrainDeamon
  16. from ...utils.config import AttrDict
  17. from .support_models import SUPPORT_MODELS
  18. class TextDetTrainer(BaseTrainer):
  19. """ Text Detection Model Trainer """
  20. support_models = SUPPORT_MODELS
  21. def build_deamon(self, config: AttrDict) -> "TextDetTrainDeamon":
  22. """build deamon thread for saving training outputs timely
  23. Args:
  24. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  25. Returns:
  26. TextDetTrainDeamon: the training deamon thread object for saving training outputs timely.
  27. """
  28. return TextDetTrainDeamon(config)
  29. def update_config(self):
  30. """update training config
  31. """
  32. if self.train_config.log_interval:
  33. self.pdx_config.update_log_interval(self.train_config.log_interval)
  34. if self.train_config.eval_interval:
  35. self.pdx_config._update_eval_interval_by_epoch(
  36. self.train_config.eval_interval)
  37. if self.train_config.save_interval:
  38. self.pdx_config.update_save_interval(
  39. self.train_config.save_interval)
  40. self.pdx_config.update_dataset(self.global_config.dataset_dir,
  41. "TextDetDataset")
  42. if self.train_config.pretrain_weight_path:
  43. self.pdx_config.update_pretrained_weights(
  44. self.train_config.pretrain_weight_path)
  45. if self.train_config.batch_size is not None:
  46. self.pdx_config.update_batch_size(self.train_config.batch_size)
  47. if self.train_config.learning_rate is not None:
  48. self.pdx_config.update_learning_rate(
  49. self.train_config.learning_rate)
  50. if self.train_config.epochs_iters is not None:
  51. self.pdx_config._update_epochs(self.train_config.epochs_iters)
  52. if self.train_config.resume_path is not None and self.train_config.resume_path != "":
  53. self.pdx_config._update_checkpoints(self.train_config.resume_path)
  54. if self.global_config.output is not None:
  55. self.pdx_config._update_output_dir(self.global_config.output)
  56. def get_train_kwargs(self) -> dict:
  57. """get key-value arguments of model training function
  58. Returns:
  59. dict: the arguments of training function.
  60. """
  61. return {"device": self.get_device()}
  62. class TextDetTrainDeamon(BaseTrainDeamon):
  63. """ TableRecTrainDeamon """
  64. def __init__(self, *args, **kwargs):
  65. super().__init__(*args, **kwargs)
  66. def get_the_pdparams_suffix(self):
  67. """ get the suffix of pdparams file """
  68. return "pdparams"
  69. def get_the_pdema_suffix(self):
  70. """ get the suffix of pdema file """
  71. return "pdema"
  72. def get_the_pdopt_suffix(self):
  73. """ get the suffix of pdopt file """
  74. return "pdopt"
  75. def get_the_pdstates_suffix(self):
  76. """ get the suffix of pdstates file """
  77. return "states"
  78. def get_ith_ckp_prefix(self, epoch_id):
  79. """ get the prefix of the epoch_id checkpoint file """
  80. return f"iter_epoch_{epoch_id}"
  81. def get_best_ckp_prefix(self):
  82. """ get the prefix of the best checkpoint file """
  83. return "best_accuracy"
  84. def get_score(self, pdstates_path):
  85. """ get the score by pdstates file """
  86. if not Path(pdstates_path).exists():
  87. return 0
  88. return paddle.load(pdstates_path)['best_model_dict']['hmean']
  89. def get_epoch_id_by_pdparams_prefix(self, pdparams_prefix):
  90. """ get the epoch_id by pdparams file """
  91. return int(pdparams_prefix.split(".")[0].split("_")[-1])