trainer.py

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import time
import tarfile
from pathlib import Path

import paddle

from ..base import BaseTrainer, BaseTrainDeamon
from ...utils.config import AttrDict
from .model_list import MODELS


class TSCLSTrainer(BaseTrainer):
    """TS Classification Model Trainer"""

    entities = MODELS

    def build_deamon(self, config: AttrDict) -> "TSCLSTrainDeamon":
        """build the daemon thread that saves training outputs in a timely manner

        Args:
            config (AttrDict): PaddleX pipeline config, which is loaded from the pipeline yaml file.

        Returns:
            TSCLSTrainDeamon: the training daemon thread object that saves training outputs in a timely manner.
        """
        return TSCLSTrainDeamon(config)

    def train(self):
        """first update and dump the train config, then train the model"""
        # XXX: use super().train() instead once train_hook() is supported.
        os.makedirs(self.global_config.output, exist_ok=True)
        self.update_config()
        self.dump_config()
        train_result = self.pdx_model.train(**self.get_train_kwargs())
        assert (
            train_result.returncode == 0
        ), f"Encountered an unexpected error ({train_result.returncode}) in training!"
        self.make_tar_file()
        self.deamon.stop()

    def make_tar_file(self):
        """make tar file to package the training outputs"""
        tar_path = Path(self.global_config.output) / "best_accuracy.pdparams.tar"
        with tarfile.open(tar_path, "w") as tar:
            tar.add(self.global_config.output, arcname="best_accuracy.pdparams")
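    # Note (descriptive, inferred from the code above): make_tar_file() packages the
    # entire output directory into <output>/best_accuracy.pdparams.tar under the
    # arcname "best_accuracy.pdparams"; the daemon's update_models() later looks for
    # a tarball with this same file name inside its train_output directory.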

    def update_config(self):
        """update training config"""
        self.pdx_config.update_dataset(self.global_config.dataset_dir, "TSCLSDataset")
        if self.train_config.time_col is not None:
            self.pdx_config.update_basic_info({"time_col": self.train_config.time_col})
        if self.train_config.target_cols is not None:
            self.pdx_config.update_basic_info(
                {"target_cols": self.train_config.target_cols.split(",")}
            )
        if self.train_config.group_id is not None:
            self.pdx_config.update_basic_info({"group_id": self.train_config.group_id})
        if self.train_config.static_cov_cols is not None:
            self.pdx_config.update_basic_info(
                {"static_cov_cols": self.train_config.static_cov_cols}
            )
        if self.train_config.freq is not None:
            try:
                self.train_config.freq = int(self.train_config.freq)
            except ValueError:
                pass
            self.pdx_config.update_basic_info({"freq": self.train_config.freq})
        if self.train_config.batch_size is not None:
            self.pdx_config.update_batch_size(self.train_config.batch_size)
        if self.train_config.learning_rate is not None:
            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
        if self.train_config.epochs_iters is not None:
            self.pdx_config.update_epochs(self.train_config.epochs_iters)
        if self.global_config.output is not None:
            self.pdx_config.update_save_dir(self.global_config.output)
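    # Illustrative sketch only (values are hypothetical; field names are exactly the
    # ones read above): the Train section of the pipeline yaml that feeds
    # self.train_config might look like this, with target_cols given as a
    # comma-separated string and freq as either an integer or a frequency string
    # (the int() conversion above is skipped on ValueError):
    #
    #   Train:
    #     time_col: date
    #     target_cols: dim_0,dim_1
    #     group_id: group_id
    #     static_cov_cols: null
    #     freq: 1
    #     batch_size: 16
    #     learning_rate: 0.001
    #     epochs_iters: 10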

    def get_train_kwargs(self) -> dict:
        """get the key-value arguments of the model training function

        Returns:
            dict: the arguments of the training function.
        """
        train_args = {"device": self.get_device()}
        if self.global_config.output is not None:
            train_args["save_dir"] = self.global_config.output
        return train_args
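    # For example (illustrative only; the device string format depends on
    # BaseTrainer.get_device()), get_train_kwargs() might return
    # {"device": "gpu:0", "save_dir": "output"}.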


class TSCLSTrainDeamon(BaseTrainDeamon):
    """TSCLSTrainDeamon"""

    def get_watched_model(self):
        """get the models needed to be watched"""
        watched_models = []
        watched_models.append("best")
        return watched_models

    def update(self):
        """update train result json"""
        self.processing = True
        for i, result in enumerate(self.results):
            self.results[i] = self.update_result(result, self.train_outputs[i])
        self.save_json()
        self.processing = False

    def update_train_log(self, train_output):
        """update train log"""
        train_log_path = train_output / "train_ct.log"
        with open(train_log_path, "w") as f:
            seconds = time.time()
            f.write(
                "current training time: "
                + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(seconds))
            )
        return train_log_path
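    # The resulting train_ct.log contains a single line such as (timestamp is
    # illustrative): "current training time: 2024-01-01 12:00:00".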

    def update_result(self, result, train_output):
        """update every result"""
        train_output = Path(train_output).resolve()
        config_path = Path(train_output).joinpath("config.yaml").resolve()
        if not config_path.exists():
            return result
        model_name = result["model_name"]
        if (
            model_name in self.config_recorder
            and self.config_recorder[model_name] != config_path
        ):
            result["models"] = self.init_model_pkg()
        result["config"] = config_path
        self.config_recorder[model_name] = config_path
        result["config"] = config_path
        result["train_log"] = self.update_train_log(train_output)
        result["visualdl_log"] = self.update_vdl_log(train_output)
        result["label_dict"] = self.update_label_dict(train_output)
        model = self.get_model(result["model_name"], config_path)
        self.update_models(result, model, train_output, "best")
        return result

    def update_models(self, result, model, train_output, model_key):
        """update info of the models to be saved"""
        pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
        if pdparams.exists():
            score = self.get_score(Path(train_output).joinpath("score.json"))
            result["models"][model_key] = {
                "score": "%.3f" % score,
                "pdparams": pdparams,
                "pdema": "",
                "pdopt": "",
                "pdstates": "",
                "inference_config": "",
                "pdmodel": "",
                "pdiparams": pdparams,
                "pdiparams.info": "",
            }
            self.update_inference_model(
                model,
                train_output,
                train_output.joinpath("inference"),
                result["models"][model_key],
            )

    def update_inference_model(
        self, model, weight_path, export_save_dir, result_the_model
    ):
        """update inference model"""
        export_save_dir.mkdir(parents=True, exist_ok=True)
        export_result = model.export(weight_path=weight_path, save_dir=export_save_dir)
        if export_result.returncode == 0:
            inference_config = export_save_dir.joinpath("inference.yml")
            if not inference_config.exists():
                inference_config = ""
            use_pir = (
                hasattr(paddle.framework, "use_pir_api")
                and paddle.framework.use_pir_api()
            )
            pdmodel = (
                export_save_dir.joinpath("inference.json")
                if use_pir
                else export_save_dir.joinpath("inference.pdmodel")
            )
            pdiparams = export_save_dir.joinpath("inference.pdiparams")
            pdiparams_info = (
                "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
            )
        else:
            inference_config = ""
            pdmodel = ""
            pdiparams = ""
            pdiparams_info = ""
        result_the_model["inference_config"] = inference_config
        result_the_model["pdmodel"] = pdmodel
        result_the_model["pdiparams"] = pdiparams
        result_the_model["pdiparams.info"] = pdiparams_info
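    # Note on the PIR branch above (descriptive): when the installed Paddle exposes
    # and enables the PIR API, the exported program is inference.json and no
    # .pdiparams.info file is produced; otherwise the legacy inference.pdmodel /
    # inference.pdiparams.info layout is recorded instead.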

    def get_score(self, score_path):
        """get the score from the score json file"""
        if not Path(score_path).exists():
            return 0
        with open(score_path, "r") as f:
            return json.load(f)["metric"]
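    # Assumed format of the score.json read above (only the "metric" key comes from
    # the code; the value is illustrative): {"metric": 0.95}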

    def get_best_ckp_prefix(self):
        """get the prefix of the best checkpoint file"""
        pass

    def get_epoch_id_by_pdparams_prefix(self):
        """get the epoch_id by pdparams file"""
        pass

    def get_ith_ckp_prefix(self):
        """get the prefix of the epoch_id checkpoint file"""
        pass

    def get_the_pdema_suffix(self):
        """get the suffix of pdema file"""
        pass

    def get_the_pdopt_suffix(self):
        """get the suffix of pdopt file"""
        pass

    def get_the_pdparams_suffix(self):
        """get the suffix of pdparams file"""
        pass

    def get_the_pdstates_suffix(self):
        """get the suffix of pdstates file"""
        pass
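

# A minimal usage sketch (hypothetical, not part of this module): it assumes `config`
# is the parsed PaddleX pipeline config (an AttrDict with the Global/Train sections
# consumed above) and that the BaseTrainer constructor accepts it directly.
#
#     trainer = TSCLSTrainer(config)  # config: AttrDict loaded from a pipeline yaml
#     trainer.train()                 # updates/dumps config, trains, tars outputs, stops the daemon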