trainer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import time
  17. import tarfile
  18. from pathlib import Path
  19. import paddle
  20. from ..base import BaseTrainer, BaseTrainDeamon
  21. from ...utils.config import AttrDict
  22. from .support_models import SUPPORT_MODELS
  23. class TSFCTrainer(BaseTrainer):
  24. """ TS Forecast Model Trainer """
  25. support_models = SUPPORT_MODELS
  26. def build_deamon(self, config: AttrDict) -> "TSFCTrainDeamon":
  27. """build deamon thread for saving training outputs timely
  28. Args:
  29. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  30. Returns:
  31. TSFCTrainDeamon: the training deamon thread object for saving training outputs timely.
  32. """
  33. return TSFCTrainDeamon(config)
  34. def train(self):
  35. """firstly, update and dump train config, then train model
  36. """
  37. self.update_config()
  38. self.dump_config()
  39. train_result = self.pdx_model.train(**self.get_train_kwargs())
  40. assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
  41. training!"
  42. self.make_tar_file()
  43. def make_tar_file(self):
  44. """make tar file to package the training outputs
  45. """
  46. tar_path = Path(
  47. self.global_config.output) / "best_accuracy.pdparams.tar"
  48. with tarfile.open(tar_path, 'w') as tar:
  49. tar.add(self.global_config.output, arcname='best_accuracy.pdparams')
  50. def update_config(self):
  51. """update training config
  52. """
  53. self.pdx_config.update_dataset(self.global_config.dataset_dir,
  54. "TSDataset")
  55. if self.train_config.input_len is not None:
  56. self.pdx_config.update_input_len(self.train_config.input_len)
  57. if self.train_config.time_col is not None:
  58. self.pdx_config.update_basic_info({
  59. 'time_col': self.train_config.time_col
  60. })
  61. if self.train_config.target_cols is not None:
  62. self.pdx_config.update_basic_info({
  63. 'target_cols': self.train_config.target_cols.split(',')
  64. })
  65. if self.train_config.freq is not None:
  66. try:
  67. self.train_config.freq = int(self.train_config.freq)
  68. except ValueError:
  69. pass
  70. self.pdx_config.update_basic_info({'freq': self.train_config.freq})
  71. if self.train_config.predict_len is not None:
  72. self.pdx_config.update_predict_len(self.train_config.predict_len)
  73. if self.train_config.patience is not None:
  74. self.pdx_config.update_patience(self.train_config.patience)
  75. if self.train_config.batch_size is not None:
  76. self.pdx_config.update_batch_size(self.train_config.batch_size)
  77. if self.train_config.learning_rate is not None:
  78. self.pdx_config.update_learning_rate(
  79. self.train_config.learning_rate)
  80. if self.train_config.epochs_iters is not None:
  81. self.pdx_config.update_epochs(self.train_config.epochs_iters)
  82. if self.global_config.output is not None:
  83. self.pdx_config.update_save_dir(self.global_config.output)
  84. def get_train_kwargs(self) -> dict:
  85. """get key-value arguments of model training function
  86. Returns:
  87. dict: the arguments of training function.
  88. """
  89. train_args = {"device": self.get_device()}
  90. if self.global_config.output is not None:
  91. train_args["save_dir"] = self.global_config.output
  92. return train_args
  93. class TSFCTrainDeamon(BaseTrainDeamon):
  94. """ TSFCTrainResultDemon """
  95. def get_watched_model(self):
  96. """ get the models needed to be watched """
  97. watched_models = []
  98. watched_models.append("best")
  99. return watched_models
  100. def update(self):
  101. """ update train result json """
  102. self.processing = True
  103. for i, result in enumerate(self.results):
  104. self.results[i] = self.update_result(result, self.train_outputs[i])
  105. self.save_json()
  106. self.processing = False
  107. def update_train_log(self, train_output):
  108. """ update train log """
  109. train_log_path = train_output / "train_ct.log"
  110. with open(train_log_path, 'w') as f:
  111. seconds = time.time()
  112. f.write('current training time: ' + time.strftime(
  113. "%Y-%m-%d %H:%M:%S", time.localtime(seconds)))
  114. f.close()
  115. return train_log_path
  116. def update_result(self, result, train_output):
  117. """ update every result """
  118. config = Path(train_output).joinpath("config.yaml")
  119. if not config.exists():
  120. return result
  121. result["config"] = config
  122. result["train_log"] = self.update_train_log(train_output)
  123. result["visualdl_log"] = self.update_vdl_log(train_output)
  124. result["label_dict"] = self.update_label_dict(train_output)
  125. self.update_models(result, train_output, "best")
  126. return result
  127. def update_models(self, result, train_output, model_key):
  128. """ update info of the models to be saved """
  129. pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
  130. if pdparams.exists():
  131. score = self.get_score(Path(train_output).joinpath("score.json"))
  132. result["models"][model_key] = {
  133. "score": "%.3f" % score,
  134. "pdparams": pdparams,
  135. "pdema": "",
  136. "pdopt": "",
  137. "pdstates": "",
  138. "inference_config": "",
  139. "pdmodel": "",
  140. "pdiparams": pdparams,
  141. "pdiparams.info": ""
  142. }
  143. def get_score(self, score_path):
  144. """ get the score by pdstates file """
  145. if not Path(score_path).exists():
  146. return 0
  147. return json.load(open(score_path))["metric"]
  148. def get_best_ckp_prefix(self):
  149. """ get the prefix of the best checkpoint file """
  150. pass
  151. def get_epoch_id_by_pdparams_prefix(self):
  152. """ get the epoch_id by pdparams file """
  153. pass
  154. def get_ith_ckp_prefix(self):
  155. """ get the prefix of the epoch_id checkpoint file """
  156. pass
  157. def get_the_pdema_suffix(self):
  158. """ get the suffix of pdema file """
  159. pass
  160. def get_the_pdopt_suffix(self):
  161. """ get the suffix of pdopt file """
  162. pass
  163. def get_the_pdparams_suffix(self):
  164. """ get the suffix of pdparams file """
  165. pass
  166. def get_the_pdstates_suffix(self):
  167. """ get the suffix of pdstates file """
  168. pass