base.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
import os
import sys
import numpy as np
import time
import math
import yaml
import copy
import json
import functools
import multiprocessing as mp
import paddlex.utils.logging as logging
from paddlex.utils import seconds_to_hms
from paddlex.utils.utils import EarlyStop
from paddlex.cv.transforms import arrange_transforms
import paddlex
from collections import OrderedDict
from os import path as osp
from paddle.fluid.framework import Program
from .utils.pretrain_weights import get_pretrain_weights


def dict2str(dict_input):
    out = ''
    for k, v in dict_input.items():
        try:
            v = round(float(v), 6)
        except:
            pass
        out = out + '{}={}, '.format(k, v)
    return out.strip(', ')
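
# Illustrative behavior of dict2str: values castable to float are rounded to
# 6 decimals, anything else is kept verbatim. For example:
#     dict2str({'loss': 0.123456789, 'lr': 0.01})  ->  'loss=0.123457, lr=0.01'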


class BaseAPI:
    def __init__(self, model_type):
        self.model_type = model_type
        # All existing CV models have this attribute; it is also needed at
        # eval time
        self.num_classes = None
        self.labels = None
        self.version = paddlex.__version__
        if paddlex.env_info['place'] == 'cpu':
            self.places = fluid.cpu_places()
        else:
            self.places = fluid.cuda_places()
        self.exe = fluid.Executor(self.places[0])
        self.train_prog = None
        self.test_prog = None
        self.parallel_train_prog = None
        self.train_inputs = None
        self.test_inputs = None
        self.train_outputs = None
        self.test_outputs = None
        self.train_data_loader = None
        self.eval_metrics = None
        # A model loaded from an inference model cannot call the training
        # interface
        self.trainable = True
        # Whether to synchronize BatchNorm mean and variance across cards
        self.sync_bn = False
        # Current model status
        self.status = 'Normal'
        # Number of completed epochs; used as the starting epoch when
        # resuming training
        self.completed_epochs = 0
        self.scope = fluid.global_scope()
        # Thread pool used at prediction time to process input data in
        # parallel, one image per task; mainly used by the batch_predict
        # interface
        thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
        self.thread_pool = mp.pool.ThreadPool(thread_num)

    def reset_thread_pool(self, thread_num):
        self.thread_pool.close()
        self.thread_pool.join()
        self.thread_pool = mp.pool.ThreadPool(thread_num)
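
    # Illustrative usage (assumes `model` is an instance of any BaseAPI
    # subclass): shrink or grow the preprocessing pool before batch_predict:
    #     model.reset_thread_pool(4)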

    def _get_single_card_bs(self, batch_size):
        if batch_size % len(self.places) == 0:
            return int(batch_size // len(self.places))
        else:
            raise Exception("Please set a batch_size that can be evenly "
                            "divided by the available cards({}) in {}".format(
                                paddlex.env_info['num'],
                                paddlex.env_info['place']))
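
    # Illustrative: with 4 CUDA places and batch_size=8 this returns 2;
    # batch_size=6 would raise, since 6 is not divisible by 4.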

    def build_program(self):
        if hasattr(paddlex, 'model_built') and paddlex.model_built:
            logging.error(
                "Function model.train() can only be called once in your code.")
        paddlex.model_built = True
        # Build the training network
        self.train_inputs, self.train_outputs = self.build_net(mode='train')
        self.train_prog = fluid.default_main_program()
        startup_prog = fluid.default_startup_program()
        # Build the inference network
        self.test_prog = fluid.Program()
        with fluid.program_guard(self.test_prog, startup_prog):
            with fluid.unique_name.guard():
                self.test_inputs, self.test_outputs = self.build_net(
                    mode='test')
        self.test_prog = self.test_prog.clone(for_test=True)
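
    # Note: building the test network under the same startup program and a
    # fresh unique_name guard gives its layers the same parameter names as
    # the training network, so train_prog and test_prog share weights and
    # evaluation always sees the parameters being trained.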

    def build_train_data_loader(self, dataset, batch_size):
        # Initialize the data_loader
        if self.train_data_loader is None:
            self.train_data_loader = fluid.io.DataLoader.from_generator(
                feed_list=list(self.train_inputs.values()),
                capacity=64,
                use_double_buffer=True,
                iterable=True)
        batch_size_each_gpu = self._get_single_card_bs(batch_size)
        generator = dataset.generator(
            batch_size=batch_size_each_gpu, drop_last=True)
        self.train_data_loader.set_sample_list_generator(
            generator, places=self.places)
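
    # Illustrative: training on 4 GPUs with batch_size=16 makes the generator
    # yield batches of 4; the DataLoader then feeds one such batch to each
    # place per step, so the effective global batch size stays 16.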

    def export_quant_model(self,
                           dataset,
                           save_dir,
                           batch_size=1,
                           batch_num=10,
                           cache_dir="./temp"):
        input_channel = getattr(self, 'input_channel', 3)
        arrange_transforms(
            model_type=self.model_type,
            class_name=self.__class__.__name__,
            transforms=dataset.transforms,
            mode='quant',
            input_channel=input_channel)
        dataset.num_samples = batch_size * batch_num
        import paddle
        version = paddle.__version__.strip().split('.')
        if version[0] == '2' or (version[0] == '0' and
                                 hasattr(paddle, 'enable_static')):
            from .slim.post_quantization import PaddleXPostTrainingQuantizationV2 as PaddleXPostTrainingQuantization
        else:
            from .slim.post_quantization import PaddleXPostTrainingQuantization
        is_use_cache_file = True
        if cache_dir is None:
            is_use_cache_file = False
        quant_prog = self.test_prog.clone(for_test=True)
        post_training_quantization = PaddleXPostTrainingQuantization(
            executor=self.exe,
            dataset=dataset,
            program=quant_prog,
            inputs=self.test_inputs,
            outputs=self.test_outputs,
            batch_size=batch_size,
            batch_nums=batch_num,
            scope=self.scope,
            algo='KL',
            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
            is_full_quantize=False,
            is_use_cache_file=is_use_cache_file,
            cache_dir=cache_dir)
        post_training_quantization.quantize()
        post_training_quantization.save_quantized_model(save_dir)
        model_info = self.get_model_info()
        model_info['status'] = 'Quant'
        # Save descriptions of the model's input and output variables
        model_info['_ModelInputsOutputs'] = dict()
        model_info['_ModelInputsOutputs']['test_inputs'] = [
            [k, v.name] for k, v in self.test_inputs.items()
        ]
        model_info['_ModelInputsOutputs']['test_outputs'] = [
            [k, v.name] for k, v in self.test_outputs.items()
        ]
        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(model_info, f)
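
    # A minimal sketch of post-training quantization (assumes a trained
    # model and an eval dataset; the paths are hypothetical):
    #     model = paddlex.load_model('output/best_model')
    #     model.export_quant_model(
    #         dataset=eval_dataset, save_dir='./quant_model',
    #         batch_size=1, batch_num=10)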

    def net_initialize(self,
                       startup_prog=None,
                       pretrain_weights=None,
                       fuse_bn=False,
                       save_dir='.',
                       sensitivities_file=None,
                       eval_metric_loss=0.05,
                       resume_checkpoint=None):
        if not resume_checkpoint:
            pretrain_dir = osp.join(save_dir, 'pretrain')
            if not os.path.isdir(pretrain_dir):
                if os.path.exists(pretrain_dir):
                    os.remove(pretrain_dir)
                os.makedirs(pretrain_dir)
            if pretrain_weights is not None and not os.path.exists(
                    pretrain_weights):
                if self.model_type == 'classifier':
                    if pretrain_weights not in ['IMAGENET', 'BAIDU10W']:
                        logging.warning(
                            "Path of pretrain_weights('{}') does not exist!".
                            format(pretrain_weights))
                        logging.warning(
                            "pretrain_weights will be forced to 'IMAGENET'. If you don't want to use pretrained weights, set pretrain_weights=None."
                        )
                        pretrain_weights = 'IMAGENET'
                elif self.model_type == 'detector':
                    if pretrain_weights not in ['IMAGENET', 'COCO']:
                        logging.warning(
                            "Path of pretrain_weights('{}') does not exist!".
                            format(pretrain_weights))
                        logging.warning(
                            "pretrain_weights will be forced to 'IMAGENET'. If you don't want to use pretrained weights, set pretrain_weights=None."
                        )
                        pretrain_weights = 'IMAGENET'
                elif self.model_type == 'segmenter':
                    if pretrain_weights not in [
                            'IMAGENET', 'COCO', 'CITYSCAPES'
                    ]:
                        logging.warning(
                            "Path of pretrain_weights('{}') does not exist!".
                            format(pretrain_weights))
                        logging.warning(
                            "pretrain_weights will be forced to 'IMAGENET'. If you don't want to use pretrained weights, set pretrain_weights=None."
                        )
                        pretrain_weights = 'IMAGENET'
            if hasattr(self, 'backbone'):
                backbone = self.backbone
            else:
                backbone = self.__class__.__name__
                if backbone == "HRNet":
                    backbone = backbone + "_W{}".format(self.width)
            class_name = self.__class__.__name__
            pretrain_weights = get_pretrain_weights(
                pretrain_weights, class_name, backbone, pretrain_dir)
        if startup_prog is None:
            startup_prog = fluid.default_startup_program()
        self.exe.run(startup_prog)
        if resume_checkpoint:
            logging.info(
                "Resume checkpoint from {}.".format(resume_checkpoint),
                use_color=True)
            paddlex.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, resume_checkpoint, resume=True)
            if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
                raise Exception("There is no model.yml in {}".format(
                    resume_checkpoint))
            with open(osp.join(resume_checkpoint, "model.yml")) as f:
                info = yaml.load(f.read(), Loader=yaml.Loader)
                self.completed_epochs = info['completed_epochs']
        elif pretrain_weights is not None:
            logging.info(
                "Load pretrain weights from {}.".format(pretrain_weights),
                use_color=True)
            paddlex.utils.utils.load_pretrain_weights(
                self.exe, self.train_prog, pretrain_weights, fuse_bn)
        # Prune the model
        if sensitivities_file is not None:
            import paddleslim
            from .slim.prune_config import get_sensitivities
            sensitivities_file = get_sensitivities(sensitivities_file, self,
                                                   save_dir)
            from .slim.prune import get_params_ratios, prune_program
            logging.info(
                "Start to prune program with eval_metric_loss = {}".format(
                    eval_metric_loss),
                use_color=True)
            origin_flops = paddleslim.analysis.flops(self.test_prog)
            prune_params_ratios = get_params_ratios(
                sensitivities_file, eval_metric_loss=eval_metric_loss)
            prune_program(self, prune_params_ratios)
            current_flops = paddleslim.analysis.flops(self.test_prog)
            remaining_ratio = current_flops / origin_flops
            logging.info(
                "Finish prune program, before FLOPs:{}, after prune FLOPs:{}, remaining ratio:{}"
                .format(origin_flops, current_flops, remaining_ratio),
                use_color=True)
            self.status = 'Prune'
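
    # A rough sketch of how subclasses typically drive initialization from
    # their train() methods (argument values are illustrative):
    #     self.net_initialize(
    #         startup_prog=fluid.default_startup_program(),
    #         pretrain_weights='IMAGENET',
    #         save_dir='output',
    #         sensitivities_file=None,  # pass a file path to enable pruning
    #         resume_checkpoint=None)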

    def get_model_info(self):
        info = dict()
        info['version'] = paddlex.__version__
        info['Model'] = self.__class__.__name__
        info['_Attributes'] = {'model_type': self.model_type}
        if 'self' in self.init_params:
            del self.init_params['self']
        if '__class__' in self.init_params:
            del self.init_params['__class__']
        if 'model_name' in self.init_params:
            del self.init_params['model_name']
        info['_init_params'] = self.init_params
        info['_Attributes']['num_classes'] = self.num_classes
        info['_Attributes']['labels'] = self.labels
        info['_Attributes']['fixed_input_shape'] = self.fixed_input_shape
        try:
            primary_metric_key = list(self.eval_metrics.keys())[0]
            primary_metric_value = float(self.eval_metrics[primary_metric_key])
            info['_Attributes']['eval_metrics'] = {
                primary_metric_key: primary_metric_value
            }
        except:
            pass
        if hasattr(self, 'test_transforms'):
            if hasattr(self.test_transforms, 'to_rgb'):
                if self.test_transforms.to_rgb:
                    info['TransformsMode'] = 'RGB'
                else:
                    info['TransformsMode'] = 'BGR'
            if self.test_transforms is not None:
                info['Transforms'] = list()
                for op in self.test_transforms.transforms:
                    name = op.__class__.__name__
                    if name.startswith('Arrange'):
                        continue
                    attr = op.__dict__
                    info['Transforms'].append({name: attr})
        info['completed_epochs'] = self.completed_epochs
        return info
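
    # The dict returned above is what save_model() and export_inference_model()
    # dump to model.yml (a 'status' key is added by the callers). Rough shape,
    # with illustrative values:
    #     {'version': '1.3.x', 'Model': 'ResNet50',
    #      '_Attributes': {'model_type': 'classifier', 'num_classes': 2,
    #                      'labels': [...], 'fixed_input_shape': None,
    #                      'eval_metrics': {'acc': 0.95}},
    #      '_init_params': {...}, 'Transforms': [...], 'completed_epochs': 0}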

    def save_model(self, save_dir):
        if not osp.isdir(save_dir):
            if osp.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)
        if self.train_prog is not None:
            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
        else:
            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
        model_info = self.get_model_info()
        model_info['status'] = self.status
        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(model_info, f)
        # Save evaluation results
        if hasattr(self, 'eval_details'):
            with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
                json.dump(self.eval_details, f)
        if self.status == 'Prune':
            # Save the shapes of the pruned parameters
            shapes = {}
            for block in self.train_prog.blocks:
                for param in block.all_parameters():
                    pd_var = fluid.global_scope().find_var(param.name)
                    pd_param = pd_var.get_tensor()
                    shapes[param.name] = np.array(pd_param).shape
            with open(
                    osp.join(save_dir, 'prune.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(shapes, f)
        # Marker file indicating the model was saved successfully
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("Model saved in {}.".format(save_dir))
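
    # After a successful call, save_dir should contain the weights written by
    # fluid.save ('model.pdparams', 'model.pdopt', 'model.pdmodel'), plus
    # 'model.yml', optionally 'eval_details.json' and 'prune.yml', and the
    # '.success' marker file.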

    def export_inference_model(self, save_dir):
        test_input_names = [
            var.name for var in list(self.test_inputs.values())
        ]
        test_outputs = list(self.test_outputs.values())
        save_prog = self.test_prog.clone(for_test=True)
        with fluid.scope_guard(self.scope):
            fluid.io.save_inference_model(
                dirname=save_dir,
                executor=self.exe,
                params_filename='__params__',
                feeded_var_names=test_input_names,
                target_vars=test_outputs,
                main_program=save_prog)
        model_info = self.get_model_info()
        model_info['status'] = 'Infer'
        # Save descriptions of the model's input and output variables
        model_info['_ModelInputsOutputs'] = dict()
        model_info['_ModelInputsOutputs']['test_inputs'] = [
            [k, v.name] for k, v in self.test_inputs.items()
        ]
        model_info['_ModelInputsOutputs']['test_outputs'] = [
            [k, v.name] for k, v in self.test_outputs.items()
        ]
        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(model_info, f)
        # Marker file indicating the model was saved successfully
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("Model for inference deploy saved in {}.".format(
            save_dir))
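
    # A minimal deployment sketch (paths are hypothetical):
    #     model = paddlex.load_model('output/best_model')
    #     model.export_inference_model(save_dir='./inference_model')
    # The exported directory can then be loaded with paddlex.deploy.Predictor.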

    def train_loop(self,
                   num_epochs,
                   train_dataset,
                   train_batch_size,
                   eval_dataset=None,
                   save_interval_epochs=1,
                   log_interval_steps=10,
                   save_dir='output',
                   use_vdl=False,
                   early_stop=False,
                   early_stop_patience=5):
        if train_dataset.num_samples < train_batch_size:
            raise Exception(
                'The size of the training dataset must be larger than the batch size.')
        if not osp.isdir(save_dir):
            if osp.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)
        if use_vdl:
            from visualdl import LogWriter
            vdl_logdir = osp.join(save_dir, 'vdl_log')
        # Append the Arrange operation to the transforms
        input_channel = getattr(self, 'input_channel', 3)
        arrange_transforms(
            model_type=self.model_type,
            class_name=self.__class__.__name__,
            transforms=train_dataset.transforms,
            mode='train',
            input_channel=input_channel)
        # Build the train_data_loader
        self.build_train_data_loader(
            dataset=train_dataset, batch_size=train_batch_size)
        if eval_dataset is not None:
            self.eval_transforms = eval_dataset.transforms
            self.test_transforms = copy.deepcopy(eval_dataset.transforms)
        # Fetch the learning rate, which may change during training
        lr = self.optimizer._learning_rate
        if isinstance(lr, fluid.framework.Variable):
            self.train_outputs['lr'] = lr
        # Run training on multiple cards
        if self.parallel_train_prog is None:
            build_strategy = fluid.compiler.BuildStrategy()
            build_strategy.fuse_all_optimizer_ops = False
            if paddlex.env_info['place'] != 'cpu' and len(self.places) > 1:
                build_strategy.sync_batch_norm = self.sync_bn
            exec_strategy = fluid.ExecutionStrategy()
            exec_strategy.num_iteration_per_drop_scope = 1
            self.parallel_train_prog = fluid.CompiledProgram(
                self.train_prog).with_data_parallel(
                    loss_name=self.train_outputs['loss'].name,
                    build_strategy=build_strategy,
                    exec_strategy=exec_strategy)
        total_num_steps = math.floor(train_dataset.num_samples /
                                     train_batch_size)
        num_steps = 0
        time_stat = list()
        time_train_one_epoch = None
        time_eval_one_epoch = None
        total_num_steps_eval = 0
        # Total number of evaluations over the whole run
        total_eval_times = math.ceil(num_epochs / save_interval_epochs)
        # Detection currently supports single-card evaluation only, so the
        # evaluation batch size is the training batch size divided by the
        # number of cards.
        eval_batch_size = train_batch_size
        if self.model_type == 'detector':
            eval_batch_size = self._get_single_card_bs(train_batch_size)
        if eval_dataset is not None:
            total_num_steps_eval = math.ceil(eval_dataset.num_samples /
                                             eval_batch_size)
        if use_vdl:
            # VisualDL component
            log_writer = LogWriter(vdl_logdir)
        thresh = 0.0001
        if early_stop:
            earlystop = EarlyStop(early_stop_patience, thresh)
        best_accuracy_key = ""
        best_accuracy = -1.0
        best_model_epoch = -1
        start_epoch = self.completed_epochs
        # task_id: currently assigned by the PaddleX GUI; used to tag
        # VisualDL logs with the owning task id
        task_id = getattr(paddlex, "task_id", "")
        for i in range(start_epoch, num_epochs):
            records = list()
            step_start_time = time.time()
            epoch_start_time = time.time()
            for step, data in enumerate(self.train_data_loader()):
                outputs = self.exe.run(
                    self.parallel_train_prog,
                    feed=data,
                    fetch_list=list(self.train_outputs.values()))
                outputs_avg = np.mean(np.array(outputs), axis=1)
                records.append(outputs_avg)
                # Track step times to estimate the remaining training time
                current_time = time.time()
                step_cost_time = current_time - step_start_time
                step_start_time = current_time
                if len(time_stat) < 20:
                    time_stat.append(step_cost_time)
                else:
                    time_stat[num_steps % 20] = step_cost_time
                # Log loss information every log_interval_steps
                num_steps += 1
                if num_steps % log_interval_steps == 0:
                    step_metrics = OrderedDict(
                        zip(list(self.train_outputs.keys()), outputs_avg))
                    if use_vdl:
                        for k, v in step_metrics.items():
                            log_writer.add_scalar(
                                '{}-Metrics/Training(Step): {}'.format(
                                    task_id, k), v, num_steps)
                    # Estimate the remaining time
                    avg_step_time = np.mean(time_stat)
                    if time_train_one_epoch is not None:
                        eta = (num_epochs - i - 1) * time_train_one_epoch + (
                            total_num_steps - step - 1) * avg_step_time
                    else:
                        eta = ((num_epochs - i) * total_num_steps - step - 1
                               ) * avg_step_time
                    if time_eval_one_epoch is not None:
                        eval_eta = (
                            total_eval_times - i // save_interval_epochs
                        ) * time_eval_one_epoch
                    else:
                        eval_eta = (
                            total_eval_times - i // save_interval_epochs
                        ) * total_num_steps_eval * avg_step_time
                    eta_str = seconds_to_hms(eta + eval_eta)
                    logging.info(
                        "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
                        .format(i + 1, num_epochs, step + 1, total_num_steps,
                                dict2str(step_metrics),
                                round(avg_step_time, 2), eta_str))
            train_metrics = OrderedDict(
                zip(list(self.train_outputs.keys()),
                    np.mean(records, axis=0)))
            logging.info('[TRAIN] Epoch {} finished, {} .'.format(
                i + 1, dict2str(train_metrics)))
            time_train_one_epoch = time.time() - epoch_start_time
            epoch_start_time = time.time()
            # Every save_interval_epochs, evaluate on the validation set and
            # save the model
            self.completed_epochs += 1
            eval_epoch_start_time = time.time()
            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                if not osp.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                if getattr(self, 'use_ema', False):
                    self.exe.run(self.ema.apply_program)
                if eval_dataset is not None and eval_dataset.num_samples > 0:
                    self.eval_metrics, self.eval_details = self.evaluate(
                        eval_dataset=eval_dataset,
                        batch_size=eval_batch_size,
                        epoch_id=i + 1,
                        return_details=True)
                    logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                        i + 1, dict2str(self.eval_metrics)))
                    # Keep the best model
                    best_accuracy_key = list(self.eval_metrics.keys())[0]
                    current_accuracy = self.eval_metrics[best_accuracy_key]
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        best_model_epoch = i + 1
                        best_model_dir = osp.join(save_dir, "best_model")
                        self.save_model(save_dir=best_model_dir)
                    if use_vdl:
                        for k, v in self.eval_metrics.items():
                            if isinstance(v, list):
                                continue
                            if isinstance(v, np.ndarray):
                                if v.size > 1:
                                    continue
                            log_writer.add_scalar(
                                "{}-Metrics/Eval(Epoch): {}".format(
                                    task_id, k), v, i + 1)
                self.save_model(save_dir=current_save_dir)
                if getattr(self, 'use_ema', False):
                    self.exe.run(self.ema.restore_program)
                time_eval_one_epoch = time.time() - eval_epoch_start_time
                eval_epoch_start_time = time.time()
                if best_model_epoch > 0:
                    logging.info(
                        'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
                        .format(best_model_epoch, best_accuracy_key,
                                best_accuracy))
                if eval_dataset is not None and early_stop:
                    if earlystop(current_accuracy):
                        break
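
# A minimal end-to-end sketch: subclasses (e.g. paddlex.cls.ResNet50) wire
# their train() methods to the loop above. The dataset paths and transforms
# below are hypothetical:
#     import paddlex as pdx
#     from paddlex.cls import transforms
#     train_transforms = transforms.Compose([
#         transforms.RandomCrop(crop_size=224), transforms.Normalize()])
#     train_dataset = pdx.datasets.ImageNet(
#         data_dir='data', file_list='data/train_list.txt',
#         label_list='data/labels.txt', transforms=train_transforms)
#     model = pdx.cls.ResNet50(num_classes=len(train_dataset.labels))
#     model.train(num_epochs=10, train_dataset=train_dataset,
#                 train_batch_size=32, save_interval_epochs=1,
#                 save_dir='output', use_vdl=True)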