trainer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import time
  17. import tarfile
  18. from pathlib import Path
  19. import paddle
  20. from ..base import BaseTrainer, BaseTrainDeamon
  21. from ...utils.config import AttrDict
  22. from .support_models import SUPPORT_MODELS
  23. class TSCLSTrainer(BaseTrainer):
  24. """ TS Classification Model Trainer """
  25. support_models = SUPPORT_MODELS
  26. def build_deamon(self, config: AttrDict) -> "TSCLSTrainDeamon":
  27. """build deamon thread for saving training outputs timely
  28. Args:
  29. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  30. Returns:
  31. TSCLSTrainDeamon: the training deamon thread object for saving training outputs timely.
  32. """
  33. return TSCLSTrainDeamon(config)
  34. def train(self):
  35. """firstly, update and dump train config, then train model
  36. """
  37. self.update_config()
  38. self.dump_config()
  39. train_result = self.pdx_model.train(**self.get_train_kwargs())
  40. assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
  41. training!"
  42. self.make_tar_file()
  43. def make_tar_file(self):
  44. """make tar file to package the training outputs
  45. """
  46. tar_path = Path(
  47. self.global_config.output) / "best_accuracy.pdparams.tar"
  48. with tarfile.open(tar_path, 'w') as tar:
  49. tar.add(self.global_config.output, arcname='best_accuracy.pdparams')
  50. def update_config(self):
  51. """update training config
  52. """
  53. self.pdx_config.update_dataset(self.global_config.dataset_dir,
  54. "TSCLSDataset")
  55. if self.train_config.time_col is not None:
  56. self.pdx_config.update_basic_info({
  57. 'time_col': self.train_config.time_col
  58. })
  59. if self.train_config.target_cols is not None:
  60. self.pdx_config.update_basic_info({
  61. 'target_cols': self.train_config.target_cols.split(',')
  62. })
  63. if self.train_config.group_id is not None:
  64. self.pdx_config.update_basic_info({
  65. 'group_id': self.train_config.group_id
  66. })
  67. if self.train_config.static_cov_cols is not None:
  68. self.pdx_config.update_basic_info({
  69. 'static_cov_cols': self.train_config.static_cov_cols
  70. })
  71. if self.train_config.freq is not None:
  72. try:
  73. self.train_config.freq = int(self.train_config.freq)
  74. except ValueError:
  75. pass
  76. self.pdx_config.update_basic_info({'freq': self.train_config.freq})
  77. if self.train_config.batch_size is not None:
  78. self.pdx_config.update_batch_size(self.train_config.batch_size)
  79. if self.train_config.learning_rate is not None:
  80. self.pdx_config.update_learning_rate(
  81. self.train_config.learning_rate)
  82. if self.train_config.epochs_iters is not None:
  83. self.pdx_config.update_epochs(self.train_config.epochs_iters)
  84. if self.global_config.output is not None:
  85. self.pdx_config.update_save_dir(self.global_config.output)
  86. def get_train_kwargs(self) -> dict:
  87. """get key-value arguments of model training function
  88. Returns:
  89. dict: the arguments of training function.
  90. """
  91. train_args = {"device": self.get_device()}
  92. if self.global_config.output is not None:
  93. train_args["save_dir"] = self.global_config.output
  94. return train_args
  95. class TSCLSTrainDeamon(BaseTrainDeamon):
  96. """ TSCLSTrainResultDemon """
  97. def get_watched_model(self):
  98. """ get the models needed to be watched """
  99. watched_models = []
  100. watched_models.append("best")
  101. return watched_models
  102. def update(self):
  103. """ update train result json """
  104. self.processing = True
  105. for i, result in enumerate(self.results):
  106. self.results[i] = self.update_result(result, self.train_outputs[i])
  107. self.save_json()
  108. self.processing = False
  109. def update_train_log(self, train_output):
  110. """ update train log """
  111. train_log_path = train_output / "train_ct.log"
  112. with open(train_log_path, 'w') as f:
  113. seconds = time.time()
  114. f.write('current training time: ' + time.strftime(
  115. "%Y-%m-%d %H:%M:%S", time.localtime(seconds)))
  116. f.close()
  117. return train_log_path
  118. def update_result(self, result, train_output):
  119. """ update every result """
  120. config = Path(train_output).joinpath("config.yaml")
  121. if not config.exists():
  122. return result
  123. result["config"] = config
  124. result["train_log"] = self.update_train_log(train_output)
  125. result["visualdl_log"] = self.update_vdl_log(train_output)
  126. result["label_dict"] = self.update_label_dict(train_output)
  127. self.update_models(result, train_output, "best")
  128. return result
  129. def update_models(self, result, train_output, model_key):
  130. """ update info of the models to be saved """
  131. pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
  132. if pdparams.exists():
  133. score = self.get_score(Path(train_output).joinpath("score.json"))
  134. result["models"][model_key] = {
  135. "score": "%.3f" % score,
  136. "pdparams": pdparams,
  137. "pdema": "",
  138. "pdopt": "",
  139. "pdstates": "",
  140. "inference_config": "",
  141. "pdmodel": "",
  142. "pdiparams": pdparams,
  143. "pdiparams.info": ""
  144. }
  145. def get_score(self, score_path):
  146. """ get the score by pdstates file """
  147. if not Path(score_path).exists():
  148. return 0
  149. return json.load(open(score_path))["metric"]
  150. def get_best_ckp_prefix(self):
  151. """ get the prefix of the best checkpoint file """
  152. pass
  153. def get_epoch_id_by_pdparams_prefix(self):
  154. """ get the epoch_id by pdparams file """
  155. pass
  156. def get_ith_ckp_prefix(self):
  157. """ get the prefix of the epoch_id checkpoint file """
  158. pass
  159. def get_the_pdema_suffix(self):
  160. """ get the suffix of pdema file """
  161. pass
  162. def get_the_pdopt_suffix(self):
  163. """ get the suffix of pdopt file """
  164. pass
  165. def get_the_pdparams_suffix(self):
  166. """ get the suffix of pdparams file """
  167. pass
  168. def get_the_pdstates_suffix(self):
  169. """ get the suffix of pdstates file """
  170. pass