# trainer.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import glob
  16. from pathlib import Path
  17. import lazy_paddle as paddle
  18. from ..base import BaseTrainer, BaseTrainDeamon
  19. from ...utils.config import AttrDict
  20. from .model_list import MODELS
  21. class UadTrainer(BaseTrainer):
  22. """Uad Model Trainer"""
  23. entities = MODELS
  24. def build_deamon(self, config: AttrDict) -> "SegTrainDeamon":
  25. """build deamon thread for saving training outputs timely
  26. Args:
  27. config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
  28. Returns:
  29. SegTrainDeamon: the training deamon thread object for saving training outputs timely.
  30. """
  31. return SegTrainDeamon(config)
  32. def update_config(self):
  33. """update training config"""
  34. self.pdx_config.update_dataset(self.global_config.dataset_dir, "SegDataset")
  35. if self.train_config.num_classes is not None:
  36. self.pdx_config.update_num_classes(self.train_config.num_classes)
  37. if (
  38. self.train_config.pretrain_weight_path
  39. and self.train_config.pretrain_weight_path != ""
  40. ):
  41. self.pdx_config.update_pretrained_weights(
  42. self.train_config.pretrain_weight_path, is_backbone=True
  43. )
  44. def get_train_kwargs(self) -> dict:
  45. """get key-value arguments of model training function
  46. Returns:
  47. dict: the arguments of training function.
  48. """
  49. train_args = {"device": self.get_device()}
  50. # XXX:
  51. os.environ.pop("FLAGS_npu_jit_compile", None)
  52. if self.train_config.batch_size is not None:
  53. train_args["batch_size"] = self.train_config.batch_size
  54. if self.train_config.learning_rate is not None:
  55. train_args["learning_rate"] = self.train_config.learning_rate
  56. if self.train_config.epochs_iters is not None:
  57. train_args["epochs_iters"] = self.train_config.epochs_iters
  58. if (
  59. self.train_config.resume_path is not None
  60. and self.train_config.resume_path != ""
  61. ):
  62. train_args["resume_path"] = self.train_config.resume_path
  63. if self.global_config.output is not None:
  64. train_args["save_dir"] = self.global_config.output
  65. if self.train_config.log_interval:
  66. train_args["log_iters"] = self.train_config.log_interval
  67. if self.train_config.eval_interval:
  68. train_args["do_eval"] = True
  69. train_args["save_interval"] = self.train_config.eval_interval
  70. train_args["dy2st"] = self.train_config.get("dy2st", False)
  71. return train_args
class SegTrainDeamon(BaseTrainDeamon):
    """Training-output daemon for the uad trainer.

    NOTE(review): despite living in the uad trainer module, the naming and
    checkpoint layout here ("iter_*/model", "best_model/model", mIoU score)
    follow the segmentation trainer — confirm this matches the actual uad
    training output directory structure.
    """

    # Number of most recent checkpoints to record as "last_1", "last_2", ...
    last_k = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_the_pdparams_suffix(self):
        """get the suffix of pdparams file"""
        return "pdparams"

    def get_the_pdema_suffix(self):
        """get the suffix of pdema file"""
        return "pdema"

    def get_the_pdopt_suffix(self):
        """get the suffix of pdopt file"""
        return "pdopt"

    def get_the_pdstates_suffix(self):
        """get the suffix of pdstates file"""
        return "pdstates"

    def get_ith_ckp_prefix(self, epoch_id):
        """Get the relative path prefix of the checkpoint saved at `epoch_id`.

        `epoch_id` may also be a glob fragment (e.g. "[0-9]*"); see
        `update_result`, which uses the result as a glob pattern.
        """
        return f"iter_{epoch_id}/model"

    def get_best_ckp_prefix(self):
        """Get the relative path prefix of the best checkpoint file."""
        return "best_model/model"

    def get_score(self, pdstates_path):
        """Read the evaluation score (mIoU) from a pdstates file.

        Returns 0 when the file does not exist (e.g. no eval has run yet).
        """
        if not Path(pdstates_path).exists():
            return 0
        return paddle.load(pdstates_path)["mIoU"]

    def get_epoch_id_by_pdparams_prefix(self, pdparams_dir):
        """Extract the iteration id from a checkpoint params path.

        The checkpoint's parent directory is named "iter_<id>" (see
        `get_ith_ckp_prefix`), so the trailing underscore-separated token
        of the parent name is the integer id.
        """
        return int(pdparams_dir.parent.name.split("_")[-1])

    def update_result(self, result, train_output):
        """Refresh `result` with the current state of `train_output`.

        Records config/vdl-log/label-dict paths, then registers the last_k
        most recent checkpoints and the best checkpoint via `update_models`.
        Returns `result` unchanged if no config.yaml exists yet.
        """
        train_output = Path(train_output).resolve()
        config_path = train_output.joinpath("config.yaml").resolve()
        if not config_path.exists():
            return result

        model_name = result["model_name"]
        # A changed config for the same model invalidates previously collected
        # model entries, so reset the models package.
        if (
            model_name in self.config_recorder
            and self.config_recorder[model_name] != config_path
        ):
            result["models"] = self.init_model_pkg()
        result["config"] = config_path
        self.config_recorder[model_name] = config_path

        result["visualdl_log"] = self.update_vdl_log(train_output)
        result["label_dict"] = self.update_label_dict(train_output)

        model = self.get_model(result["model_name"], config_path)

        # Glob for all saved checkpoints: "iter_[0-9]*/model.pdparams".
        params_path_list = list(
            train_output.glob(
                ".".join(
                    [self.get_ith_ckp_prefix("[0-9]*"), self.get_the_pdparams_suffix()]
                )
            )
        )
        iter_ids = []
        for params_path in params_path_list:
            iter_id = self.get_epoch_id_by_pdparams_prefix(params_path)
            iter_ids.append(iter_id)
        iter_ids.sort()
        # TODO(gaotingquan): how to avoid that the latest ckp files is being saved
        # epoch_ids = epoch_ids[:-1]
        # Register the last_k newest checkpoints, newest first ("last_1" is newest).
        for i in range(1, self.last_k + 1):
            if len(iter_ids) < i:
                break
            self.update_models(
                result,
                model,
                train_output,
                f"last_{i}",
                self.get_ith_ckp_prefix(iter_ids[-i]),
            )
        self.update_models(
            result, model, train_output, "best", self.get_best_ckp_prefix()
        )
        return result

    def update_models(self, result, model, train_output, model_key, ckp_prefix):
        """Record checkpoint files for `model_key` into result["models"].

        Skips work when a non-"best" checkpoint has already been recorded at
        the same path (via `model_recorder`); sibling pdema/pdopt/pdstates
        files use "" as a sentinel when absent.
        """
        pdparams = train_output.joinpath(
            ".".join([ckp_prefix, self.get_the_pdparams_suffix()])
        )
        if pdparams.exists():
            recorder_key = f"{train_output.name}_{model_key}"
            # "best" is always refreshed because the file at best_model/model
            # can change contents without changing its path.
            if (
                model_key != "best"
                and recorder_key in self.model_recorder
                and self.model_recorder[recorder_key] == pdparams
            ):
                return

            self.model_recorder[recorder_key] = pdparams

            pdema = ""
            pdema_suffix = self.get_the_pdema_suffix()
            if pdema_suffix:
                # parents[1] is train_output (ckp_prefix contains one dir level).
                pdema = pdparams.parents[1].joinpath(
                    ".".join([ckp_prefix, pdema_suffix])
                )
                if not pdema.exists():
                    pdema = ""

            pdopt = ""
            pdopt_suffix = self.get_the_pdopt_suffix()
            if pdopt_suffix:
                pdopt = pdparams.parents[1].joinpath(
                    ".".join([ckp_prefix, pdopt_suffix])
                )
                if not pdopt.exists():
                    pdopt = ""

            pdstates = ""
            pdstates_suffix = self.get_the_pdstates_suffix()
            if pdstates_suffix:
                pdstates = pdparams.parents[1].joinpath(
                    ".".join([ckp_prefix, pdstates_suffix])
                )
                if not pdstates.exists():
                    pdstates = ""
            # NOTE(review): when pdstates == "", Path("").resolve() points at
            # the CWD, which exists as a directory, so get_score would attempt
            # paddle.load on a directory — verify the missing-pdstates case.
            score = self.get_score(Path(pdstates).resolve().as_posix())

            result["models"][model_key] = {
                "score": score,
                "pdparams": pdparams,
                "pdema": pdema,
                "pdopt": pdopt,
                "pdstates": pdstates,
            }

            self.update_inference_model(
                model,
                pdparams,
                train_output.joinpath(f"{ckp_prefix}"),
                result["models"][model_key],
            )