# trainer.py
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import time
  17. import tarfile
  18. from pathlib import Path
  19. import paddle
  20. from ..base import BaseTrainer, BaseTrainDeamon
  21. from ...utils.config import AttrDict
  22. from .model_list import MODELS
  23. class TSCLSTrainer(BaseTrainer):
  24. """ TS Classification Model Trainer """
  25. entities = MODELS
  26. def build_deamon(self, config: AttrDict) -> "TSCLSTrainDeamon":
  27. """build deamon thread for saving training outputs timely
  28. Args:
  29. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  30. Returns:
  31. TSCLSTrainDeamon: the training deamon thread object for saving training outputs timely.
  32. """
  33. return TSCLSTrainDeamon(config)
  34. def train(self):
  35. """firstly, update and dump train config, then train model
  36. """
  37. # XXX: using super().train() instead when the train_hook() is supported.
  38. os.makedirs(self.global_config.output, exist_ok=True)
  39. self.update_config()
  40. self.dump_config()
  41. train_result = self.pdx_model.train(**self.get_train_kwargs())
  42. assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
  43. training!"
  44. self.make_tar_file()
  45. self.deamon.stop()
  46. def make_tar_file(self):
  47. """make tar file to package the training outputs
  48. """
  49. tar_path = Path(
  50. self.global_config.output) / "best_accuracy.pdparams.tar"
  51. with tarfile.open(tar_path, 'w') as tar:
  52. tar.add(self.global_config.output, arcname='best_accuracy.pdparams')
  53. def update_config(self):
  54. """update training config
  55. """
  56. self.pdx_config.update_dataset(self.global_config.dataset_dir,
  57. "TSCLSDataset")
  58. if self.train_config.time_col is not None:
  59. self.pdx_config.update_basic_info({
  60. 'time_col': self.train_config.time_col
  61. })
  62. if self.train_config.target_cols is not None:
  63. self.pdx_config.update_basic_info({
  64. 'target_cols': self.train_config.target_cols.split(',')
  65. })
  66. if self.train_config.group_id is not None:
  67. self.pdx_config.update_basic_info({
  68. 'group_id': self.train_config.group_id
  69. })
  70. if self.train_config.static_cov_cols is not None:
  71. self.pdx_config.update_basic_info({
  72. 'static_cov_cols': self.train_config.static_cov_cols
  73. })
  74. if self.train_config.freq is not None:
  75. try:
  76. self.train_config.freq = int(self.train_config.freq)
  77. except ValueError:
  78. pass
  79. self.pdx_config.update_basic_info({'freq': self.train_config.freq})
  80. if self.train_config.batch_size is not None:
  81. self.pdx_config.update_batch_size(self.train_config.batch_size)
  82. if self.train_config.learning_rate is not None:
  83. self.pdx_config.update_learning_rate(
  84. self.train_config.learning_rate)
  85. if self.train_config.epochs_iters is not None:
  86. self.pdx_config.update_epochs(self.train_config.epochs_iters)
  87. if self.global_config.output is not None:
  88. self.pdx_config.update_save_dir(self.global_config.output)
  89. def get_train_kwargs(self) -> dict:
  90. """get key-value arguments of model training function
  91. Returns:
  92. dict: the arguments of training function.
  93. """
  94. train_args = {"device": self.get_device()}
  95. if self.global_config.output is not None:
  96. train_args["save_dir"] = self.global_config.output
  97. return train_args
  98. class TSCLSTrainDeamon(BaseTrainDeamon):
  99. """ TSCLSTrainResultDemon """
  100. def get_watched_model(self):
  101. """ get the models needed to be watched """
  102. watched_models = []
  103. watched_models.append("best")
  104. return watched_models
  105. def update(self):
  106. """ update train result json """
  107. self.processing = True
  108. for i, result in enumerate(self.results):
  109. self.results[i] = self.update_result(result, self.train_outputs[i])
  110. self.save_json()
  111. self.processing = False
  112. def update_train_log(self, train_output):
  113. """ update train log """
  114. train_log_path = train_output / "train_ct.log"
  115. with open(train_log_path, 'w') as f:
  116. seconds = time.time()
  117. f.write('current training time: ' + time.strftime(
  118. "%Y-%m-%d %H:%M:%S", time.localtime(seconds)))
  119. f.close()
  120. return train_log_path
  121. def update_result(self, result, train_output):
  122. """ update every result """
  123. config = Path(train_output).joinpath("config.yaml")
  124. if not config.exists():
  125. return result
  126. result["config"] = config
  127. result["train_log"] = self.update_train_log(train_output)
  128. result["visualdl_log"] = self.update_vdl_log(train_output)
  129. result["label_dict"] = self.update_label_dict(train_output)
  130. self.update_models(result, train_output, "best")
  131. return result
  132. def update_models(self, result, train_output, model_key):
  133. """ update info of the models to be saved """
  134. pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
  135. if pdparams.exists():
  136. score = self.get_score(Path(train_output).joinpath("score.json"))
  137. result["models"][model_key] = {
  138. "score": "%.3f" % score,
  139. "pdparams": pdparams,
  140. "pdema": "",
  141. "pdopt": "",
  142. "pdstates": "",
  143. "inference_config": "",
  144. "pdmodel": "",
  145. "pdiparams": pdparams,
  146. "pdiparams.info": ""
  147. }
  148. def get_score(self, score_path):
  149. """ get the score by pdstates file """
  150. if not Path(score_path).exists():
  151. return 0
  152. return json.load(open(score_path))["metric"]
  153. def get_best_ckp_prefix(self):
  154. """ get the prefix of the best checkpoint file """
  155. pass
  156. def get_epoch_id_by_pdparams_prefix(self):
  157. """ get the epoch_id by pdparams file """
  158. pass
  159. def get_ith_ckp_prefix(self):
  160. """ get the prefix of the epoch_id checkpoint file """
  161. pass
  162. def get_the_pdema_suffix(self):
  163. """ get the suffix of pdema file """
  164. pass
  165. def get_the_pdopt_suffix(self):
  166. """ get the suffix of pdopt file """
  167. pass
  168. def get_the_pdparams_suffix(self):
  169. """ get the suffix of pdparams file """
  170. pass
  171. def get_the_pdstates_suffix(self):
  172. """ get the suffix of pdstates file """
  173. pass