# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import time
  17. from pathlib import Path
  18. import tarfile
  19. import paddle
  20. from ..base import BaseTrainer, BaseTrainDeamon
  21. from ...utils.config import AttrDict
  22. from .model_list import MODELS
  23. class TSADTrainer(BaseTrainer):
  24. """TS Anomaly Detection Model Trainer"""
  25. entities = MODELS
  26. def build_deamon(self, config: AttrDict) -> "TSADTrainDeamon":
  27. """build deamon thread for saving training outputs timely
  28. Args:
  29. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  30. Returns:
  31. TSADTrainDeamon: the training deamon thread object for saving training outputs timely.
  32. """
  33. return TSADTrainDeamon(config)
  34. def train(self):
  35. """firstly, update and dump train config, then train model"""
  36. # XXX: using super().train() instead when the train_hook() is supported.
  37. os.makedirs(self.global_config.output, exist_ok=True)
  38. self.update_config()
  39. self.dump_config()
  40. train_result = self.pdx_model.train(**self.get_train_kwargs())
  41. assert (
  42. train_result.returncode == 0
  43. ), f"Encountered an unexpected error({train_result.returncode}) in \
  44. training!"
  45. self.make_tar_file()
  46. self.deamon.stop()
  47. def make_tar_file(self):
  48. """make tar file to package the training outputs"""
  49. tar_path = Path(self.global_config.output) / "best_accuracy.pdparams.tar"
  50. with tarfile.open(tar_path, "w") as tar:
  51. tar.add(self.global_config.output, arcname="best_accuracy.pdparams")
  52. def update_config(self):
  53. """update training config"""
  54. self.pdx_config.update_dataset(self.global_config.dataset_dir, "TSADDataset")
  55. if self.train_config.input_len is not None:
  56. self.pdx_config.update_input_len(self.train_config.input_len)
  57. if self.train_config.time_col is not None:
  58. self.pdx_config.update_basic_info({"time_col": self.train_config.time_col})
  59. if self.train_config.feature_cols is not None:
  60. if isinstance(self.train_config.feature_cols, tuple):
  61. feature_cols = [str(item) for item in self.train_config.feature_cols]
  62. self.pdx_config.update_basic_info({"feature_cols": feature_cols})
  63. else:
  64. self.pdx_config.update_basic_info(
  65. {"feature_cols": self.train_config.feature_cols.split(",")}
  66. )
  67. if self.train_config.label_col is not None:
  68. self.pdx_config.update_basic_info(
  69. {"label_col": self.train_config.label_col}
  70. )
  71. if self.train_config.freq is not None:
  72. try:
  73. self.train_config.freq = int(self.train_config.freq)
  74. except ValueError:
  75. pass
  76. self.pdx_config.update_basic_info({"freq": self.train_config.freq})
  77. if self.train_config.batch_size is not None:
  78. self.pdx_config.update_batch_size(self.train_config.batch_size)
  79. if self.train_config.learning_rate is not None:
  80. self.pdx_config.update_learning_rate(self.train_config.learning_rate)
  81. if self.train_config.epochs_iters is not None:
  82. self.pdx_config.update_epochs(self.train_config.epochs_iters)
  83. if self.global_config.output is not None:
  84. self.pdx_config.update_save_dir(self.global_config.output)
  85. def get_train_kwargs(self) -> dict:
  86. """get key-value arguments of model training function
  87. Returns:
  88. dict: the arguments of training function.
  89. """
  90. train_args = {"device": self.get_device()}
  91. if self.global_config.output is not None:
  92. train_args["save_dir"] = self.global_config.output
  93. return train_args
  94. class TSADTrainDeamon(BaseTrainDeamon):
  95. """DetTrainResultDemon"""
  96. def get_watched_model(self):
  97. """get the models needed to be watched"""
  98. watched_models = []
  99. watched_models.append("best")
  100. return watched_models
  101. def update(self):
  102. """update train result json"""
  103. self.processing = True
  104. for i, result in enumerate(self.results):
  105. self.results[i] = self.update_result(result, self.train_outputs[i])
  106. self.save_json()
  107. self.processing = False
  108. def update_train_log(self, train_output):
  109. """update train log"""
  110. train_log_path = train_output / "train_ct.log"
  111. with open(train_log_path, "w") as f:
  112. seconds = time.time()
  113. f.write(
  114. "current training time: "
  115. + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(seconds))
  116. )
  117. f.close()
  118. return train_log_path
  119. def update_result(self, result, train_output):
  120. """update every result"""
  121. train_output = Path(train_output).resolve()
  122. config_path = Path(train_output).joinpath("config.yaml").resolve()
  123. if not config_path.exists():
  124. return result
  125. model_name = result["model_name"]
  126. if (
  127. model_name in self.config_recorder
  128. and self.config_recorder[model_name] != config_path
  129. ):
  130. result["models"] = self.init_model_pkg()
  131. result["config"] = config_path
  132. self.config_recorder[model_name] = config_path
  133. result["config"] = config_path
  134. result["train_log"] = self.update_train_log(train_output)
  135. result["visualdl_log"] = self.update_vdl_log(train_output)
  136. result["label_dict"] = self.update_label_dict(train_output)
  137. model = self.get_model(result["model_name"], config_path)
  138. self.update_models(result, model, train_output, "best")
  139. return result
  140. def update_models(self, result, model, train_output, model_key):
  141. """update info of the models to be saved"""
  142. pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
  143. if pdparams.exists():
  144. score = self.get_score(Path(train_output).joinpath("score.json"))
  145. result["models"][model_key] = {
  146. "score": "%.3f" % score,
  147. "pdparams": pdparams,
  148. "pdema": "",
  149. "pdopt": "",
  150. "pdstates": "",
  151. "inference_config": "",
  152. "pdmodel": "",
  153. "pdiparams": pdparams,
  154. "pdiparams.info": "",
  155. }
  156. self.update_inference_model(
  157. model,
  158. train_output,
  159. train_output.joinpath(f"inference"),
  160. result["models"][model_key],
  161. )
  162. def update_inference_model(
  163. self, model, weight_path, export_save_dir, result_the_model
  164. ):
  165. """update inference model"""
  166. export_save_dir.mkdir(parents=True, exist_ok=True)
  167. export_result = model.export(weight_path=weight_path, save_dir=export_save_dir)
  168. if export_result.returncode == 0:
  169. inference_config = export_save_dir.joinpath("inference.yml")
  170. if not inference_config.exists():
  171. inference_config = ""
  172. use_pir = (
  173. hasattr(paddle.framework, "use_pir_api")
  174. and paddle.framework.use_pir_api()
  175. )
  176. pdmodel = (
  177. export_save_dir.joinpath("inference.json")
  178. if use_pir
  179. else export_save_dir.joinpath("inference.pdmodel")
  180. )
  181. pdiparams = export_save_dir.joinpath("inference.pdiparams")
  182. pdiparams_info = (
  183. "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
  184. )
  185. else:
  186. inference_config = ""
  187. pdmodel = ""
  188. pdiparams = ""
  189. pdiparams_info = ""
  190. result_the_model["inference_config"] = inference_config
  191. result_the_model["pdmodel"] = pdmodel
  192. result_the_model["pdiparams"] = pdiparams
  193. result_the_model["pdiparams.info"] = pdiparams_info
  194. def get_score(self, score_path):
  195. """get the score by pdstates file"""
  196. if not Path(score_path).exists():
  197. return 0
  198. return json.load(open(score_path, "r"))["metric"]
  199. def get_best_ckp_prefix(self):
  200. """get the prefix of the best checkpoint file"""
  201. pass
  202. def get_epoch_id_by_pdparams_prefix(self):
  203. """get the epoch_id by pdparams file"""
  204. pass
  205. def get_ith_ckp_prefix(self):
  206. """get the prefix of the epoch_id checkpoint file"""
  207. pass
  208. def get_the_pdema_suffix(self):
  209. """get the suffix of pdema file"""
  210. pass
  211. def get_the_pdopt_suffix(self):
  212. """get the suffix of pdopt file"""
  213. pass
  214. def get_the_pdparams_suffix(self):
  215. """get the suffix of pdparams file"""
  216. pass
  217. def get_the_pdstates_suffix(self):
  218. """get the suffix of pdstates file"""
  219. pass