train_deamon.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
import functools
import json
import os
import sys
import threading
import time
import traceback
from abc import ABC, abstractmethod
from pathlib import Path

from ..build_model import build_model
from ....utils.file_interface import write_json_file
from ....utils import logging
  22. def try_except_decorator(func):
  23. """ try-except """
  24. def wrap(self, *args, **kwargs):
  25. try:
  26. func(self, *args, **kwargs)
  27. except Exception as e:
  28. exc_type, exc_value, exc_tb = sys.exc_info()
  29. self.save_json()
  30. traceback.print_exception(exc_type, exc_value, exc_tb)
  31. finally:
  32. self.processing = False
  33. return wrap
  34. class BaseTrainDeamon(ABC):
  35. """ BaseTrainResultDemon """
  36. update_interval = 600
  37. last_k = 5
  38. def __init__(self, global_config):
  39. """ init """
  40. self.global_config = global_config
  41. self.init_pre_hook()
  42. self.output = global_config.output
  43. self.train_outputs = self.get_train_outputs()
  44. self.save_paths = self.get_save_paths()
  45. self.results = self.init_train_result()
  46. self.save_json()
  47. self.models = {}
  48. self.init_post_hook()
  49. self.config_recorder = {}
  50. self.model_recorder = {}
  51. self.processing = False
  52. self.start()
  53. def init_train_result(self):
  54. """ init train result structure """
  55. model_names = self.init_model_names()
  56. configs = self.init_configs()
  57. train_log = self.init_train_log()
  58. vdl = self.init_vdl_log()
  59. results = []
  60. for i, model_name in enumerate(model_names):
  61. results.append({
  62. "model_name": model_name,
  63. "done_flag": False,
  64. "config": configs[i],
  65. "label_dict": "",
  66. "train_log": train_log,
  67. "visualdl_log": vdl,
  68. "models": self.init_model_pkg()
  69. })
  70. return results
  71. def get_save_names(self):
  72. """ get names to save """
  73. return ["train_result.json"]
  74. def get_train_outputs(self):
  75. """ get training outputs dir """
  76. return [Path(self.output)]
  77. def init_model_names(self):
  78. """ get models name """
  79. return [self.global_config.model]
  80. def get_save_paths(self):
  81. """ get the path to save train_result.json """
  82. return [
  83. Path(self.output, save_name) for save_name in self.get_save_names()
  84. ]
  85. def init_configs(self):
  86. """ get the init value of config field in result """
  87. return [""] * len(self.init_model_names())
  88. def init_train_log(self):
  89. """ get train log """
  90. return ""
  91. def init_vdl_log(self):
  92. """ get visualdl log """
  93. return ""
  94. def init_model_pkg(self):
  95. """ get model package """
  96. init_content = self.init_model_content()
  97. model_pkg = {}
  98. for pkg in self.get_watched_model():
  99. model_pkg[pkg] = init_content
  100. return model_pkg
  101. def normlize_path(self, dict_obj, relative_to):
  102. """ normlize path to string type path relative to the output """
  103. for key in dict_obj:
  104. if isinstance(dict_obj[key], dict):
  105. self.normlize_path(dict_obj[key], relative_to)
  106. if isinstance(dict_obj[key], Path):
  107. dict_obj[key] = dict_obj[key].resolve().relative_to(
  108. relative_to.resolve()).as_posix()
  109. def save_json(self):
  110. """ save result to json """
  111. for i, result in enumerate(self.results):
  112. self.save_paths[i].parent.mkdir(parents=True, exist_ok=True)
  113. self.normlize_path(result, relative_to=self.save_paths[i].parent)
  114. write_json_file(result, self.save_paths[i], indent=2)
  115. def start(self):
  116. """ start deamon thread """
  117. self.exit = False
  118. self.thread = threading.Thread(target=self.run)
  119. self.thread.daemon = True
  120. self.thread.start()
  121. def stop_hook(self):
  122. """ hook befor stop """
  123. for result in self.results:
  124. result["done_flag"] = True
  125. self.update()
  126. def stop(self):
  127. """ stop self """
  128. self.exit = True
  129. while True:
  130. if not self.processing:
  131. self.stop_hook()
  132. break
  133. time.sleep(60)
  134. def run(self):
  135. """ main function """
  136. while not self.exit:
  137. self.update()
  138. if self.exit:
  139. break
  140. time.sleep(self.update_interval)
  141. def update_train_log(self, train_output):
  142. """ update train log """
  143. train_log_path = train_output / "train.log"
  144. if train_log_path.exists():
  145. return train_log_path
  146. def update_vdl_log(self, train_output):
  147. """ update visualdl log """
  148. vdl_path = list(train_output.glob("vdlrecords*log"))
  149. if len(vdl_path) >= 1:
  150. return vdl_path[0]
  151. def update_label_dict(self, train_output):
  152. """ update label dict """
  153. dict_path = train_output.joinpath("label_dict.txt")
  154. if not dict_path.exists():
  155. return ""
  156. return dict_path
  157. @try_except_decorator
  158. def update(self):
  159. """ update train result json """
  160. self.processing = True
  161. for i in range(len(self.results)):
  162. self.results[i] = self.update_result(self.results[i],
  163. self.train_outputs[i])
  164. self.save_json()
  165. self.processing = False
  166. def get_model(self, model_name, config_path):
  167. """ initialize the model """
  168. if model_name not in self.models:
  169. config, model = build_model(
  170. model_name,
  171. device=self.global_config.device,
  172. config_path=config_path)
  173. self.models[model_name] = model
  174. return self.models[model_name]
  175. def get_watched_model(self):
  176. """ get the models needed to be watched """
  177. watched_models = [f"last_{i}" for i in range(1, self.last_k + 1)]
  178. watched_models.append("best")
  179. return watched_models
  180. def init_model_content(self):
  181. """ get model content structure """
  182. return {
  183. "score": "",
  184. "pdparams": "",
  185. "pdema": "",
  186. "pdopt": "",
  187. "pdstates": "",
  188. "inference_config": "",
  189. "pdmodel": "",
  190. "pdiparams": "",
  191. "pdiparams.info": ""
  192. }
  193. def update_result(self, result, train_output):
  194. """ update every result """
  195. train_output = Path(train_output).resolve()
  196. config_path = train_output.joinpath("config.yaml").resolve()
  197. if not config_path.exists():
  198. return result
  199. model_name = result["model_name"]
  200. if model_name in self.config_recorder and self.config_recorder[
  201. model_name] != config_path:
  202. result["models"] = self.init_model_pkg()
  203. result["config"] = config_path
  204. self.config_recorder[model_name] = config_path
  205. result["train_log"] = self.update_train_log(train_output)
  206. result["visualdl_log"] = self.update_vdl_log(train_output)
  207. result["label_dict"] = self.update_label_dict(train_output)
  208. model = self.get_model(result["model_name"], config_path)
  209. params_path_list = list(
  210. train_output.glob(".".join([
  211. self.get_ith_ckp_prefix("[0-9]*"), self.get_the_pdparams_suffix(
  212. )
  213. ])))
  214. epoch_ids = []
  215. for params_path in params_path_list:
  216. epoch_id = self.get_epoch_id_by_pdparams_prefix(params_path.stem)
  217. epoch_ids.append(epoch_id)
  218. epoch_ids.sort()
  219. # TODO(gaotingquan): how to avoid that the latest ckp files is being saved
  220. # epoch_ids = epoch_ids[:-1]
  221. for i in range(1, self.last_k + 1):
  222. if len(epoch_ids) < i:
  223. break
  224. self.update_models(result, model, train_output, f"last_{i}",
  225. self.get_ith_ckp_prefix(epoch_ids[-i]))
  226. self.update_models(result, model, train_output, "best",
  227. self.get_best_ckp_prefix())
  228. return result
  229. def update_models(self, result, model, train_output, model_key, ckp_prefix):
  230. """ update info of the models to be saved """
  231. pdparams = train_output.joinpath(".".join(
  232. [ckp_prefix, self.get_the_pdparams_suffix()]))
  233. if pdparams.exists():
  234. recorder_key = f"{train_output.name}_{model_key}"
  235. if model_key != "best" and recorder_key in self.model_recorder and self.model_recorder[
  236. recorder_key] == pdparams:
  237. return
  238. self.model_recorder[recorder_key] = pdparams
  239. pdema = ""
  240. pdema_suffix = self.get_the_pdema_suffix()
  241. if pdema_suffix:
  242. pdema = pdparams.parent.joinpath(".".join(
  243. [ckp_prefix, pdema_suffix]))
  244. if not pdema.exists():
  245. pdema = ""
  246. pdopt = ""
  247. pdopt_suffix = self.get_the_pdopt_suffix()
  248. if pdopt_suffix:
  249. pdopt = pdparams.parent.joinpath(".".join(
  250. [ckp_prefix, pdopt_suffix]))
  251. if not pdopt.exists():
  252. pdopt = ""
  253. pdstates = ""
  254. pdstates_suffix = self.get_the_pdstates_suffix()
  255. if pdstates_suffix:
  256. pdstates = pdparams.parent.joinpath(".".join(
  257. [ckp_prefix, pdstates_suffix]))
  258. if not pdstates.exists():
  259. pdstates = ""
  260. score = self.get_score(Path(pdstates).resolve().as_posix())
  261. result["models"][model_key] = {
  262. "score": score,
  263. "pdparams": pdparams,
  264. "pdema": pdema,
  265. "pdopt": pdopt,
  266. "pdstates": pdstates
  267. }
  268. self.update_inference_model(model, pdparams,
  269. train_output.joinpath(f"{ckp_prefix}"),
  270. result["models"][model_key])
  271. def update_inference_model(self, model, weight_path, export_save_dir,
  272. result_the_model):
  273. """ update inference model """
  274. export_save_dir.mkdir(parents=True, exist_ok=True)
  275. export_result = model.export(
  276. weight_path=weight_path, save_dir=export_save_dir)
  277. if export_result.returncode == 0:
  278. inference_config = export_save_dir.joinpath("inference.yml")
  279. if not inference_config.exists():
  280. inference_config = ""
  281. pdmodel = export_save_dir.joinpath("inference.pdmodel")
  282. pdiparams = export_save_dir.joinpath("inference.pdiparams")
  283. pdiparams_info = export_save_dir.joinpath(
  284. "inference.pdiparams.info")
  285. else:
  286. inference_config = ""
  287. pdmodel = ""
  288. pdiparams = ""
  289. pdiparams_info = ""
  290. result_the_model["inference_config"] = inference_config
  291. result_the_model["pdmodel"] = pdmodel
  292. result_the_model["pdiparams"] = pdiparams
  293. result_the_model["pdiparams.info"] = pdiparams_info
  294. def init_pre_hook(self):
  295. """ hook func that would be called befor init """
  296. pass
  297. def init_post_hook(self):
  298. """ hook func that would be called after init """
  299. pass
  300. @abstractmethod
  301. def get_the_pdparams_suffix(self):
  302. """ get the suffix of pdparams file """
  303. raise NotImplementedError
  304. @abstractmethod
  305. def get_the_pdema_suffix(self):
  306. """ get the suffix of pdema file """
  307. raise NotImplementedError
  308. @abstractmethod
  309. def get_the_pdopt_suffix(self):
  310. """ get the suffix of pdopt file """
  311. raise NotImplementedError
  312. @abstractmethod
  313. def get_the_pdstates_suffix(self):
  314. """ get the suffix of pdstates file """
  315. raise NotImplementedError
  316. @abstractmethod
  317. def get_ith_ckp_prefix(self, epoch_id):
  318. """ get the prefix of the epoch_id checkpoint file """
  319. raise NotImplementedError
  320. @abstractmethod
  321. def get_best_ckp_prefix(self):
  322. """ get the prefix of the best checkpoint file """
  323. raise NotImplementedError
  324. @abstractmethod
  325. def get_score(self, pdstates_path):
  326. """ get the score by pdstates file """
  327. raise NotImplementedError
  328. @abstractmethod
  329. def get_epoch_id_by_pdparams_prefix(self, pdparams_prefix):
  330. """ get the epoch_id by pdparams file """
  331. raise NotImplementedError