train_deamon.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import functools
import json
import os
import sys
import threading
import time
import traceback
from abc import ABC, abstractmethod
from pathlib import Path

import lazy_paddle as paddle

from ..build_model import build_model
from ....utils import logging
from ....utils.file_interface import write_json_file
  26. def try_except_decorator(func):
  27. """try-except"""
  28. def wrap(self, *args, **kwargs):
  29. try:
  30. func(self, *args, **kwargs)
  31. except Exception as e:
  32. exc_type, exc_value, exc_tb = sys.exc_info()
  33. self.save_json()
  34. traceback.print_exception(exc_type, exc_value, exc_tb)
  35. finally:
  36. self.processing = False
  37. return wrap
class BaseTrainDeamon(ABC):
    """Daemon that watches a training run and keeps ``train_result.json`` fresh.

    A background thread periodically scans the training output directory,
    records logs, checkpoints and exported inference models, and serializes
    everything to JSON. Subclasses implement the framework-specific
    file-naming hooks (checkpoint prefixes/suffixes, score extraction).
    NOTE: the "Deamon" spelling is kept as-is for API compatibility.
    """

    # Seconds between two scans of the training output directory.
    update_interval = 600
    # Number of most-recent checkpoints ("last_1" .. "last_k") to track.
    last_k = 5
    def __init__(self, config):
        """Initialize state from the config and start the watcher thread.

        Args:
            config: parsed tool configuration; ``config.Global`` holds the
                model name and output directory. ``Benchmark.disable_deamon``
                (optional) turns the background thread off.
        """
        self.global_config = config.Global
        self.disable_deamon = config.get("Benchmark", {}).get("disable_deamon", False)
        self.init_pre_hook()
        self.output = self.global_config.output
        self.train_outputs = self.get_train_outputs()
        self.save_paths = self.get_save_paths()
        self.results = self.init_train_result()
        # Write the initial (empty) result file immediately so consumers can
        # poll it before the first update cycle completes.
        self.save_json()
        self.models = {}  # cache: model_name -> model instance used for export
        self.init_post_hook()
        self.config_recorder = {}  # model_name -> last seen config.yaml path
        self.model_recorder = {}  # "<output>_<model_key>" -> last handled pdparams
        self.processing = False  # True while update() runs; polled by stop()
        self.start()
  58. def init_train_result(self):
  59. """init train result structure"""
  60. model_names = self.init_model_names()
  61. configs = self.init_configs()
  62. train_log = self.init_train_log()
  63. vdl = self.init_vdl_log()
  64. results = []
  65. for i, model_name in enumerate(model_names):
  66. results.append(
  67. {
  68. "model_name": model_name,
  69. "done_flag": False,
  70. "config": configs[i],
  71. "label_dict": "",
  72. "train_log": train_log,
  73. "visualdl_log": vdl,
  74. "models": self.init_model_pkg(),
  75. }
  76. )
  77. return results
    def get_save_names(self):
        """Return the file name(s) of the result JSON, one per watched model."""
        return ["train_result.json"]

    def get_train_outputs(self):
        """Return the training output directory for each watched model."""
        return [Path(self.output)]

    def init_model_names(self):
        """Return the names of the models to watch (default: the global model)."""
        return [self.global_config.model]

    def get_save_paths(self):
        """Return the full path of each result JSON under the output directory."""
        return [Path(self.output, save_name) for save_name in self.get_save_names()]
    def init_configs(self):
        """Return the initial (empty) ``config`` field, one per model."""
        return [""] * len(self.init_model_names())

    def init_train_log(self):
        """Return the initial (empty) train-log field."""
        return ""

    def init_vdl_log(self):
        """Return the initial (empty) VisualDL-log field."""
        return ""
  99. def init_model_pkg(self):
  100. """get model package"""
  101. init_content = self.init_model_content()
  102. model_pkg = {}
  103. for pkg in self.get_watched_model():
  104. model_pkg[pkg] = init_content
  105. return model_pkg
  106. def normlize_path(self, dict_obj, relative_to):
  107. """normlize path to string type path relative to the output"""
  108. for key in dict_obj:
  109. if isinstance(dict_obj[key], dict):
  110. self.normlize_path(dict_obj[key], relative_to)
  111. if isinstance(dict_obj[key], Path):
  112. dict_obj[key] = (
  113. dict_obj[key]
  114. .resolve()
  115. .relative_to(relative_to.resolve())
  116. .as_posix()
  117. )
  118. def save_json(self):
  119. """save result to json"""
  120. for i, result in enumerate(self.results):
  121. self.save_paths[i].parent.mkdir(parents=True, exist_ok=True)
  122. self.normlize_path(result, relative_to=self.save_paths[i].parent)
  123. write_json_file(result, self.save_paths[i], indent=2)
    def start(self):
        """Spawn the watcher thread (no-op when the daemon is disabled)."""
        self.exit = False
        self.thread = threading.Thread(target=self.run)
        # Daemon thread: must not keep the interpreter alive at shutdown.
        self.thread.daemon = True
        if not self.disable_deamon:
            self.thread.start()
  131. def stop_hook(self):
  132. """hook befor stop"""
  133. for result in self.results:
  134. result["done_flag"] = True
  135. self.update()
  136. def stop(self):
  137. """stop self"""
  138. self.exit = True
  139. while True:
  140. if not self.processing:
  141. self.stop_hook()
  142. break
  143. time.sleep(60)
    def run(self):
        """Watcher-thread main loop: update periodically until exit is set."""
        while not self.exit:
            self.update()
            # Re-check after the (possibly long) update so stop() does not
            # have to wait out a whole update_interval.
            if self.exit:
                break
            time.sleep(self.update_interval)
  151. def update_train_log(self, train_output):
  152. """update train log"""
  153. train_log_path = train_output / "train.log"
  154. if train_log_path.exists():
  155. return train_log_path
  156. def update_vdl_log(self, train_output):
  157. """update visualdl log"""
  158. vdl_path = list(train_output.glob("vdlrecords*log"))
  159. if len(vdl_path) >= 1:
  160. return vdl_path[0]
  161. def update_label_dict(self, train_output):
  162. """update label dict"""
  163. dict_path = train_output.joinpath("label_dict.txt")
  164. if not dict_path.exists():
  165. return ""
  166. return dict_path
  167. @try_except_decorator
  168. def update(self):
  169. """update train result json"""
  170. self.processing = True
  171. for i in range(len(self.results)):
  172. self.results[i] = self.update_result(self.results[i], self.train_outputs[i])
  173. self.save_json()
  174. self.processing = False
  175. def get_model(self, model_name, config_path):
  176. """initialize the model"""
  177. if model_name not in self.models:
  178. config, model = build_model(
  179. model_name,
  180. # using CPU to export model
  181. device="cpu",
  182. config_path=config_path,
  183. )
  184. self.models[model_name] = model
  185. return self.models[model_name]
  186. def get_watched_model(self):
  187. """get the models needed to be watched"""
  188. watched_models = [f"last_{i}" for i in range(1, self.last_k + 1)]
  189. watched_models.append("best")
  190. return watched_models
  191. def init_model_content(self):
  192. """get model content structure"""
  193. return {
  194. "score": "",
  195. "pdparams": "",
  196. "pdema": "",
  197. "pdopt": "",
  198. "pdstates": "",
  199. "inference_config": "",
  200. "pdmodel": "",
  201. "pdiparams": "",
  202. "pdiparams.info": "",
  203. }
    def update_result(self, result, train_output):
        """Refresh a single result dict by scanning ``train_output``.

        Args:
            result: the result dict to update in place (also returned).
            train_output: directory holding config.yaml, logs and checkpoints.

        Returns:
            The updated result dict. Returned unchanged when no
            ``config.yaml`` has been written yet (training not started).
        """
        train_output = Path(train_output).resolve()
        config_path = train_output.joinpath("config.yaml").resolve()
        if not config_path.exists():
            return result
        model_name = result["model_name"]
        # A changed config means a brand-new training run: drop every
        # previously recorded checkpoint entry.
        if (
            model_name in self.config_recorder
            and self.config_recorder[model_name] != config_path
        ):
            result["models"] = self.init_model_pkg()
        result["config"] = config_path
        self.config_recorder[model_name] = config_path
        result["train_log"] = self.update_train_log(train_output)
        result["visualdl_log"] = self.update_vdl_log(train_output)
        result["label_dict"] = self.update_label_dict(train_output)
        model = self.get_model(result["model_name"], config_path)
        # Collect every epoch checkpoint matching "<ith-prefix>.<params-suffix>".
        params_path_list = list(
            train_output.glob(
                ".".join(
                    [self.get_ith_ckp_prefix("[0-9]*"), self.get_the_pdparams_suffix()]
                )
            )
        )
        epoch_ids = []
        for params_path in params_path_list:
            epoch_id = self.get_epoch_id_by_pdparams_prefix(params_path.stem)
            epoch_ids.append(epoch_id)
        epoch_ids.sort()
        # TODO(gaotingquan): how to avoid that the latest ckp files is being saved
        # epoch_ids = epoch_ids[:-1]
        # Record the k most recent checkpoints as last_1 (newest) .. last_k;
        # stop early when fewer than k epochs exist.
        for i in range(1, self.last_k + 1):
            if len(epoch_ids) < i:
                break
            self.update_models(
                result,
                model,
                train_output,
                f"last_{i}",
                self.get_ith_ckp_prefix(epoch_ids[-i]),
            )
        self.update_models(
            result, model, train_output, "best", self.get_best_ckp_prefix()
        )
        return result
    def update_models(self, result, model, train_output, model_key, ckp_prefix):
        """Update ``result["models"][model_key]`` from the checkpoint files.

        Looks for ``<ckp_prefix>.<suffix>`` companion files (params / ema /
        opt / states) under ``train_output`` and re-exports the inference
        model. A missing companion file is recorded as "". Does nothing when
        the params file itself is absent.
        """
        pdparams = train_output.joinpath(
            ".".join([ckp_prefix, self.get_the_pdparams_suffix()])
        )
        if pdparams.exists():
            recorder_key = f"{train_output.name}_{model_key}"
            # Skip re-export when this exact checkpoint was already handled;
            # the "best" entry is never skipped.
            if (
                model_key != "best"
                and recorder_key in self.model_recorder
                and self.model_recorder[recorder_key] == pdparams
            ):
                return
            self.model_recorder[recorder_key] = pdparams
            pdema = ""
            pdema_suffix = self.get_the_pdema_suffix()
            if pdema_suffix:
                pdema = pdparams.parent.joinpath(".".join([ckp_prefix, pdema_suffix]))
                if not pdema.exists():
                    pdema = ""
            pdopt = ""
            pdopt_suffix = self.get_the_pdopt_suffix()
            if pdopt_suffix:
                pdopt = pdparams.parent.joinpath(".".join([ckp_prefix, pdopt_suffix]))
                if not pdopt.exists():
                    pdopt = ""
            pdstates = ""
            pdstates_suffix = self.get_the_pdstates_suffix()
            if pdstates_suffix:
                pdstates = pdparams.parent.joinpath(
                    ".".join([ckp_prefix, pdstates_suffix])
                )
                if not pdstates.exists():
                    pdstates = ""
            # NOTE(review): when pdstates is "", Path("").resolve() yields the
            # CWD, so get_score() receives a directory path — presumably the
            # implementations tolerate this; confirm against subclasses.
            score = self.get_score(Path(pdstates).resolve().as_posix())
            result["models"][model_key] = {
                "score": score,
                "pdparams": pdparams,
                "pdema": pdema,
                "pdopt": pdopt,
                "pdstates": pdstates,
            }
            self.update_inference_model(
                model,
                pdparams,
                train_output.joinpath(f"{ckp_prefix}"),
                result["models"][model_key],
            )
    def update_inference_model(
        self, model, weight_path, export_save_dir, result_the_model
    ):
        """Export the inference model for ``weight_path`` and record its files.

        On export failure every file field is set to "". The recorded file
        names depend on whether Paddle runs with the PIR API: PIR exports
        ``inference.json`` and no ``pdiparams.info``; the legacy mode exports
        ``inference.pdmodel`` plus ``inference.pdiparams.info``.
        """
        export_save_dir.mkdir(parents=True, exist_ok=True)
        # model.export is assumed to return a subprocess-style object with
        # a ``returncode`` attribute (0 on success).
        export_result = model.export(
            weight_path=str(weight_path), save_dir=export_save_dir
        )
        if export_result.returncode == 0:
            inference_config = export_save_dir.joinpath("inference.yml")
            if not inference_config.exists():
                inference_config = ""
            use_pir = (
                hasattr(paddle.framework, "use_pir_api")
                and paddle.framework.use_pir_api()
            )
            pdmodel = (
                export_save_dir.joinpath("inference.json")
                if use_pir
                else export_save_dir.joinpath("inference.pdmodel")
            )
            pdiparams = export_save_dir.joinpath("inference.pdiparams")
            pdiparams_info = (
                "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
            )
        else:
            inference_config = ""
            pdmodel = ""
            pdiparams = ""
            pdiparams_info = ""
        result_the_model["inference_config"] = inference_config
        result_the_model["pdmodel"] = pdmodel
        result_the_model["pdiparams"] = pdiparams
        result_the_model["pdiparams.info"] = pdiparams_info
    def init_pre_hook(self):
        """Hook called before initialization; subclasses may override."""
        pass

    def init_post_hook(self):
        """Hook called after initialization; subclasses may override."""
        pass
    @abstractmethod
    def get_the_pdparams_suffix(self):
        """Return the file suffix of params checkpoint files (e.g. "pdparams")."""
        raise NotImplementedError

    @abstractmethod
    def get_the_pdema_suffix(self):
        """Return the suffix of EMA weight files, or "" when not produced."""
        raise NotImplementedError

    @abstractmethod
    def get_the_pdopt_suffix(self):
        """Return the suffix of optimizer-state files, or "" when not produced."""
        raise NotImplementedError

    @abstractmethod
    def get_the_pdstates_suffix(self):
        """Return the suffix of metric/state files, or "" when not produced."""
        raise NotImplementedError

    @abstractmethod
    def get_ith_ckp_prefix(self, epoch_id):
        """Return the checkpoint file prefix for the given epoch id (may be a glob)."""
        raise NotImplementedError

    @abstractmethod
    def get_best_ckp_prefix(self):
        """Return the file prefix of the best checkpoint."""
        raise NotImplementedError

    @abstractmethod
    def get_score(self, pdstates_path):
        """Return the evaluation score parsed from the given pdstates file path."""
        raise NotImplementedError

    @abstractmethod
    def get_epoch_id_by_pdparams_prefix(self, pdparams_prefix):
        """Return the epoch id encoded in a pdparams file-name prefix."""
        raise NotImplementedError