base.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. from __future__ import absolute_import
  15. import paddle.fluid as fluid
  16. import os
  17. import sys
  18. import numpy as np
  19. import time
  20. import math
  21. import yaml
  22. import copy
  23. import json
  24. import functools
  25. import paddlex.utils.logging as logging
  26. from paddlex.utils import seconds_to_hms
  27. from paddlex.utils.utils import EarlyStop
  28. import paddlex
  29. from collections import OrderedDict
  30. from os import path as osp
  31. from paddle.fluid.framework import Program
  32. from .utils.pretrain_weights import get_pretrain_weights
  33. def dict2str(dict_input):
  34. out = ''
  35. for k, v in dict_input.items():
  36. try:
  37. v = round(float(v), 6)
  38. except:
  39. pass
  40. out = out + '{}={}, '.format(k, v)
  41. return out.strip(', ')
  42. class BaseAPI:
  43. def __init__(self, model_type):
  44. self.model_type = model_type
  45. # 现有的CV模型都有这个属性,而这个属且也需要在eval时用到
  46. self.num_classes = None
  47. self.labels = None
  48. self.version = paddlex.__version__
  49. if paddlex.env_info['place'] == 'cpu':
  50. self.places = fluid.cpu_places()
  51. else:
  52. self.places = fluid.cuda_places()
  53. self.exe = fluid.Executor(self.places[0])
  54. self.train_prog = None
  55. self.test_prog = None
  56. self.parallel_train_prog = None
  57. self.train_inputs = None
  58. self.test_inputs = None
  59. self.train_outputs = None
  60. self.test_outputs = None
  61. self.train_data_loader = None
  62. self.eval_metrics = None
  63. # 若模型是从inference model加载进来的,无法调用训练接口进行训练
  64. self.trainable = True
  65. # 是否使用多卡间同步BatchNorm均值和方差
  66. self.sync_bn = False
  67. # 当前模型状态
  68. self.status = 'Normal'
  69. def _get_single_card_bs(self, batch_size):
  70. if batch_size % len(self.places) == 0:
  71. return int(batch_size // len(self.places))
  72. else:
  73. raise Exception("Please support correct batch_size, \
  74. which can be divided by available cards({}) in {}".
  75. format(paddlex.env_info['num'],
  76. paddlex.env_info['place']))
  77. def build_program(self):
  78. # 构建训练网络
  79. self.train_inputs, self.train_outputs = self.build_net(mode='train')
  80. self.train_prog = fluid.default_main_program()
  81. startup_prog = fluid.default_startup_program()
  82. # 构建预测网络
  83. self.test_prog = fluid.Program()
  84. with fluid.program_guard(self.test_prog, startup_prog):
  85. with fluid.unique_name.guard():
  86. self.test_inputs, self.test_outputs = self.build_net(
  87. mode='test')
  88. self.test_prog = self.test_prog.clone(for_test=True)
  89. def arrange_transforms(self, transforms, mode='train'):
  90. # 给transforms添加arrange操作
  91. if self.model_type == 'classifier':
  92. arrange_transform = paddlex.cls.transforms.ArrangeClassifier
  93. elif self.model_type == 'segmenter':
  94. arrange_transform = paddlex.seg.transforms.ArrangeSegmenter
  95. elif self.model_type == 'detector':
  96. arrange_name = 'Arrange{}'.format(self.__class__.__name__)
  97. arrange_transform = getattr(paddlex.det.transforms, arrange_name)
  98. else:
  99. raise Exception("Unrecognized model type: {}".format(
  100. self.model_type))
  101. if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
  102. transforms.transforms[-1] = arrange_transform(mode=mode)
  103. else:
  104. transforms.transforms.append(arrange_transform(mode=mode))
  105. def build_train_data_loader(self, dataset, batch_size):
  106. # 初始化data_loader
  107. if self.train_data_loader is None:
  108. self.train_data_loader = fluid.io.DataLoader.from_generator(
  109. feed_list=list(self.train_inputs.values()),
  110. capacity=64,
  111. use_double_buffer=True,
  112. iterable=True)
  113. batch_size_each_gpu = self._get_single_card_bs(batch_size)
  114. generator = dataset.generator(
  115. batch_size=batch_size_each_gpu, drop_last=True)
  116. self.train_data_loader.set_sample_list_generator(
  117. dataset.generator(batch_size=batch_size_each_gpu),
  118. places=self.places)
  119. def export_quant_model(self,
  120. dataset,
  121. save_dir,
  122. batch_size=1,
  123. batch_num=10,
  124. cache_dir="./temp"):
  125. self.arrange_transforms(transforms=dataset.transforms, mode='quant')
  126. dataset.num_samples = batch_size * batch_num
  127. try:
  128. from .slim.post_quantization import PaddleXPostTrainingQuantization
  129. except:
  130. raise Exception(
  131. "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
  132. )
  133. is_use_cache_file = True
  134. if cache_dir is None:
  135. is_use_cache_file = False
  136. post_training_quantization = PaddleXPostTrainingQuantization(
  137. executor=self.exe,
  138. dataset=dataset,
  139. program=self.test_prog,
  140. inputs=self.test_inputs,
  141. outputs=self.test_outputs,
  142. batch_size=batch_size,
  143. batch_nums=batch_num,
  144. scope=None,
  145. algo='KL',
  146. quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
  147. is_full_quantize=False,
  148. is_use_cache_file=is_use_cache_file,
  149. cache_dir=cache_dir)
  150. post_training_quantization.quantize()
  151. post_training_quantization.save_quantized_model(save_dir)
  152. model_info = self.get_model_info()
  153. model_info['status'] = 'Quant'
  154. # 保存模型输出的变量描述
  155. model_info['_ModelInputsOutputs'] = dict()
  156. model_info['_ModelInputsOutputs']['test_inputs'] = [
  157. [k, v.name] for k, v in self.test_inputs.items()
  158. ]
  159. model_info['_ModelInputsOutputs']['test_outputs'] = [
  160. [k, v.name] for k, v in self.test_outputs.items()
  161. ]
  162. with open(
  163. osp.join(save_dir, 'model.yml'), encoding='utf-8',
  164. mode='w') as f:
  165. yaml.dump(model_info, f)
  166. def net_initialize(self,
  167. startup_prog=None,
  168. pretrain_weights=None,
  169. fuse_bn=False,
  170. save_dir='.',
  171. sensitivities_file=None,
  172. eval_metric_loss=0.05):
  173. pretrain_dir = osp.join(save_dir, 'pretrain')
  174. if not os.path.isdir(pretrain_dir):
  175. if os.path.exists(pretrain_dir):
  176. os.remove(pretrain_dir)
  177. os.makedirs(pretrain_dir)
  178. if hasattr(self, 'backbone'):
  179. backbone = self.backbone
  180. else:
  181. backbone = self.__class__.__name__
  182. pretrain_weights = get_pretrain_weights(
  183. pretrain_weights, self.model_type, backbone, pretrain_dir)
  184. if startup_prog is None:
  185. startup_prog = fluid.default_startup_program()
  186. self.exe.run(startup_prog)
  187. if pretrain_weights is not None:
  188. logging.info(
  189. "Load pretrain weights from {}.".format(pretrain_weights))
  190. paddlex.utils.utils.load_pretrain_weights(
  191. self.exe, self.train_prog, pretrain_weights, fuse_bn)
  192. # 进行裁剪
  193. if sensitivities_file is not None:
  194. from .slim.prune_config import get_sensitivities
  195. sensitivities_file = get_sensitivities(sensitivities_file, self,
  196. save_dir)
  197. from .slim.prune import get_params_ratios, prune_program
  198. prune_params_ratios = get_params_ratios(
  199. sensitivities_file, eval_metric_loss=eval_metric_loss)
  200. prune_program(self, prune_params_ratios)
  201. self.status = 'Prune'
  202. def get_model_info(self):
  203. info = dict()
  204. info['version'] = paddlex.__version__
  205. info['Model'] = self.__class__.__name__
  206. info['_Attributes'] = {'model_type': self.model_type}
  207. if 'self' in self.init_params:
  208. del self.init_params['self']
  209. if '__class__' in self.init_params:
  210. del self.init_params['__class__']
  211. info['_init_params'] = self.init_params
  212. info['_Attributes']['num_classes'] = self.num_classes
  213. info['_Attributes']['labels'] = self.labels
  214. try:
  215. primary_metric_key = list(self.eval_metrics.keys())[0]
  216. primary_metric_value = float(self.eval_metrics[primary_metric_key])
  217. info['_Attributes']['eval_metrics'] = {
  218. primary_metric_key: primary_metric_value
  219. }
  220. except:
  221. pass
  222. if hasattr(self.test_transforms, 'to_rgb'):
  223. if self.test_transforms.to_rgb:
  224. info['TransformsMode'] = 'RGB'
  225. else:
  226. info['TransformsMode'] = 'BGR'
  227. if hasattr(self, 'test_transforms'):
  228. if self.test_transforms is not None:
  229. info['Transforms'] = list()
  230. for op in self.test_transforms.transforms:
  231. name = op.__class__.__name__
  232. attr = op.__dict__
  233. info['Transforms'].append({name: attr})
  234. return info
  235. def save_model(self, save_dir):
  236. if not osp.isdir(save_dir):
  237. if osp.exists(save_dir):
  238. os.remove(save_dir)
  239. os.makedirs(save_dir)
  240. fluid.save(self.train_prog, osp.join(save_dir, 'model'))
  241. model_info = self.get_model_info()
  242. model_info['status'] = self.status
  243. with open(
  244. osp.join(save_dir, 'model.yml'), encoding='utf-8',
  245. mode='w') as f:
  246. yaml.dump(model_info, f)
  247. # 评估结果保存
  248. if hasattr(self, 'eval_details'):
  249. with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
  250. json.dump(self.eval_details, f)
  251. if self.status == 'Prune':
  252. # 保存裁剪的shape
  253. shapes = {}
  254. for block in self.train_prog.blocks:
  255. for param in block.all_parameters():
  256. pd_var = fluid.global_scope().find_var(param.name)
  257. pd_param = pd_var.get_tensor()
  258. shapes[param.name] = np.array(pd_param).shape
  259. with open(
  260. osp.join(save_dir, 'prune.yml'), encoding='utf-8',
  261. mode='w') as f:
  262. yaml.dump(shapes, f)
  263. # 模型保存成功的标志
  264. open(osp.join(save_dir, '.success'), 'w').close()
  265. logging.info("Model saved in {}.".format(save_dir))
  266. def export_inference_model(self, save_dir):
  267. test_input_names = [
  268. var.name for var in list(self.test_inputs.values())
  269. ]
  270. test_outputs = list(self.test_outputs.values())
  271. if self.__class__.__name__ == 'MaskRCNN':
  272. from paddlex.utils.save import save_mask_inference_model
  273. save_mask_inference_model(
  274. dirname=save_dir,
  275. executor=self.exe,
  276. params_filename='__params__',
  277. feeded_var_names=test_input_names,
  278. target_vars=test_outputs,
  279. main_program=self.test_prog)
  280. else:
  281. fluid.io.save_inference_model(
  282. dirname=save_dir,
  283. executor=self.exe,
  284. params_filename='__params__',
  285. feeded_var_names=test_input_names,
  286. target_vars=test_outputs,
  287. main_program=self.test_prog)
  288. model_info = self.get_model_info()
  289. model_info['status'] = 'Infer'
  290. # 保存模型输出的变量描述
  291. model_info['_ModelInputsOutputs'] = dict()
  292. model_info['_ModelInputsOutputs']['test_inputs'] = [
  293. [k, v.name] for k, v in self.test_inputs.items()
  294. ]
  295. model_info['_ModelInputsOutputs']['test_outputs'] = [
  296. [k, v.name] for k, v in self.test_outputs.items()
  297. ]
  298. with open(
  299. osp.join(save_dir, 'model.yml'), encoding='utf-8',
  300. mode='w') as f:
  301. yaml.dump(model_info, f)
  302. # 模型保存成功的标志
  303. open(osp.join(save_dir, '.success'), 'w').close()
  304. logging.info(
  305. "Model for inference deploy saved in {}.".format(save_dir))
  306. def export_onnx_model(self, save_dir, onnx_name=None):
  307. from fluid.utils import op_io_info, init_name_prefix
  308. from onnx import helper, checker
  309. import fluid_onnx.ops as ops
  310. from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
  311. from debug.model_check import debug_model, Tracker
  312. place = fluid.CPUPlace()
  313. exe = fluid.Executor(place)
  314. inference_scope = fluid.global_scope()
  315. with fluid.scope_guard(inference_scope):
  316. test_input_names = [
  317. var.name for var in list(self.test_inputs.values())
  318. ]
  319. inputs_outputs_list = ["fetch", "feed"]
  320. weights, weights_value_info = [], []
  321. global_block = self.test_prog.global_block()
  322. for var_name in global_block.vars:
  323. var = global_block.var(var_name)
  324. if var_name not in test_input_names\
  325. and var.persistable:
  326. weight, val_info = paddle_onnx_weight(
  327. var=var, scope=inference_scope)
  328. weights.append(weight)
  329. weights_value_info.append(val_info)
  330. # Create inputs
  331. inputs = [
  332. paddle_variable_to_onnx_tensor(v, global_block)
  333. for v in test_input_names
  334. ]
  335. print("load the model parameter done.")
  336. onnx_nodes = []
  337. op_check_list = []
  338. op_trackers = []
  339. nms_first_index = -1
  340. nms_outputs = []
  341. for block in self.test_prog.blocks:
  342. for op in block.ops:
  343. if op.type in ops.node_maker:
  344. # TODO(kuke): deal with the corner case that vars in
  345. # different blocks have the same name
  346. node_proto = ops.node_maker[str(op.type)](
  347. operator=op, block=block)
  348. op_outputs = []
  349. last_node = None
  350. if isinstance(node_proto, tuple):
  351. onnx_nodes.extend(list(node_proto))
  352. last_node = list(node_proto)
  353. else:
  354. onnx_nodes.append(node_proto)
  355. last_node = [node_proto]
  356. tracker = Tracker(str(op.type), last_node)
  357. op_trackers.append(tracker)
  358. op_check_list.append(str(op.type))
  359. if op.type == "multiclass_nms" and nms_first_index < 0:
  360. nms_first_index = 0
  361. if nms_first_index >= 0:
  362. _, _, output_op = op_io_info(op)
  363. for output in output_op:
  364. nms_outputs.extend(output_op[output])
  365. else:
  366. if op.type not in ['feed', 'fetch']:
  367. op_check_list.append(op.type)
  368. print('The operator sets to run test case.')
  369. print(set(op_check_list))
  370. # Create outputs
  371. # Get the new names for outputs if they've been renamed in nodes' making
  372. renamed_outputs = op_io_info.get_all_renamed_outputs()
  373. test_outputs = list(self.test_outputs.values())
  374. test_outputs_names = [
  375. var.name for var in self.test_outputs.values()
  376. ]
  377. test_outputs_names = [
  378. name if name not in renamed_outputs else renamed_outputs[name]
  379. for name in test_outputs_names
  380. ]
  381. outputs = [
  382. paddle_variable_to_onnx_tensor(v, global_block)
  383. for v in test_outputs_names
  384. ]
  385. # Make graph
  386. onnx_name = 'test'
  387. onnx_graph = helper.make_graph(
  388. nodes=onnx_nodes,
  389. name=onnx_name,
  390. initializer=weights,
  391. inputs=inputs + weights_value_info,
  392. outputs=outputs)
  393. # Make model
  394. onnx_model = helper.make_model(
  395. onnx_graph, producer_name='PaddlePaddle')
  396. # Model check
  397. checker.check_model(onnx_model)
  398. # Print model
  399. #if to_print_model:
  400. # print("The converted model is:\n{}".format(onnx_model))
  401. # Save converted model
  402. if onnx_model is not None:
  403. try:
  404. onnx_model_file = osp.join(save_dir, onnx_name)
  405. with open(onnx_model_file, 'wb') as f:
  406. f.write(onnx_model.SerializeToString())
  407. print(
  408. "Saved converted model to path: %s" % onnx_model_file)
  409. except Exception as e:
  410. print(e)
  411. print(
  412. "Convert Failed! Please use the debug message to find error."
  413. )
  414. sys.exit(-1)
    def train_loop(self,
                   num_epochs,
                   train_dataset,
                   train_batch_size,
                   eval_dataset=None,
                   save_interval_epochs=1,
                   log_interval_steps=10,
                   save_dir='output',
                   use_vdl=False,
                   early_stop=False,
                   early_stop_patience=5):
        """Train for ``num_epochs`` epochs with periodic evaluation and saving.

        Relies on ``self.optimizer`` and ``self.evaluate``, which are not
        defined in this part of the class.

        Args:
            num_epochs: Number of epochs to iterate over.
            train_dataset: Training dataset; its ``transforms`` and
                ``num_samples`` are used and it feeds the data loader.
            train_batch_size: Total training batch size over all devices.
            eval_dataset: Optional dataset evaluated at every save interval.
            save_interval_epochs: Evaluate and save every this many epochs
                (and always after the last epoch).
            log_interval_steps: Log step metrics every this many steps.
            save_dir: Root directory for 'epoch_N', 'best_model' and
                'vdl_log' outputs.
            use_vdl: Record step and evaluation metrics with VisualDL.
            early_stop: Stop training when the ``EarlyStop`` helper, fed the
                primary evaluation metric, reports that training should end.
            early_stop_patience: Patience passed to ``EarlyStop``.
        """
        # A plain file occupying save_dir is removed before it is created.
        if not osp.isdir(save_dir):
            if osp.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)
        if use_vdl:
            from visualdl import LogWriter
            vdl_logdir = osp.join(save_dir, 'vdl_log')
        # Add the arrange op to the transforms.
        self.arrange_transforms(
            transforms=train_dataset.transforms, mode='train')
        # Build the train_data_loader.
        self.build_train_data_loader(
            dataset=train_dataset, batch_size=train_batch_size)
        if eval_dataset is not None:
            self.eval_transforms = eval_dataset.transforms
            self.test_transforms = copy.deepcopy(eval_dataset.transforms)
        # Fetch the learning rate as it changes during training.
        lr = self.optimizer._learning_rate
        if isinstance(lr, fluid.framework.Variable):
            self.train_outputs['lr'] = lr
        # Run training on multiple cards.
        if self.parallel_train_prog is None:
            build_strategy = fluid.compiler.BuildStrategy()
            build_strategy.fuse_all_optimizer_ops = False
            if paddlex.env_info['place'] != 'cpu' and len(self.places) > 1:
                build_strategy.sync_batch_norm = self.sync_bn
            exec_strategy = fluid.ExecutionStrategy()
            exec_strategy.num_iteration_per_drop_scope = 1
            self.parallel_train_prog = fluid.CompiledProgram(
                self.train_prog).with_data_parallel(
                    loss_name=self.train_outputs['loss'].name,
                    build_strategy=build_strategy,
                    exec_strategy=exec_strategy)
        total_num_steps = math.floor(
            train_dataset.num_samples / train_batch_size)
        num_steps = 0
        # Holds the most recent step durations (at most 20) for the ETA.
        time_stat = list()
        time_train_one_epoch = None
        time_eval_one_epoch = None
        total_num_steps_eval = 0
        # Total number of evaluations over the whole run.
        total_eval_times = math.ceil(num_epochs / save_interval_epochs)
        # Detection currently supports single-card evaluation only, so the
        # eval batch size is the training batch size divided by the number
        # of cards.
        eval_batch_size = train_batch_size
        if self.model_type == 'detector':
            eval_batch_size = self._get_single_card_bs(train_batch_size)
        if eval_dataset is not None:
            total_num_steps_eval = math.ceil(
                eval_dataset.num_samples / eval_batch_size)
        if use_vdl:
            # VisualDL component
            log_writer = LogWriter(vdl_logdir, sync_cycle=20)
            train_step_component = OrderedDict()
            eval_component = OrderedDict()
        thresh = 0.0001
        if early_stop:
            earlystop = EarlyStop(early_stop_patience, thresh)
        best_accuracy_key = ""
        best_accuracy = -1.0
        best_model_epoch = 1
        for i in range(num_epochs):
            records = list()
            step_start_time = time.time()
            epoch_start_time = time.time()
            for step, data in enumerate(self.train_data_loader()):
                outputs = self.exe.run(
                    self.parallel_train_prog,
                    feed=data,
                    fetch_list=list(self.train_outputs.values()))
                # Reduce every fetched output over axis 1.
                # NOTE(review): presumably the per-device axis - confirm.
                outputs_avg = np.mean(np.array(outputs), axis=1)
                records.append(outputs_avg)
                # Estimate the remaining training time.
                current_time = time.time()
                step_cost_time = current_time - step_start_time
                step_start_time = current_time
                if len(time_stat) < 20:
                    time_stat.append(step_cost_time)
                else:
                    time_stat[num_steps % 20] = step_cost_time
                # Log loss info every log_interval_steps.
                num_steps += 1
                if num_steps % log_interval_steps == 0:
                    step_metrics = OrderedDict(
                        zip(list(self.train_outputs.keys()), outputs_avg))
                    if use_vdl:
                        for k, v in step_metrics.items():
                            if k not in train_step_component.keys():
                                with log_writer.mode('Each_Step_while_Training'
                                                     ) as step_logger:
                                    train_step_component[
                                        k] = step_logger.scalar(
                                            'Training: {}'.format(k))
                            train_step_component[k].add_record(num_steps, v)
                    # Estimate the remaining time.
                    avg_step_time = np.mean(time_stat)
                    if time_train_one_epoch is not None:
                        eta = (num_epochs - i - 1) * time_train_one_epoch + (
                            total_num_steps - step - 1) * avg_step_time
                    else:
                        eta = ((num_epochs - i) * total_num_steps - step -
                               1) * avg_step_time
                    if time_eval_one_epoch is not None:
                        eval_eta = (total_eval_times - i //
                                    save_interval_epochs) * time_eval_one_epoch
                    else:
                        # No evaluation timed yet: approximate with the
                        # average training step time.
                        eval_eta = (
                            total_eval_times - i // save_interval_epochs
                        ) * total_num_steps_eval * avg_step_time
                    eta_str = seconds_to_hms(eta + eval_eta)
                    logging.info(
                        "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
                        .format(i + 1, num_epochs, step + 1, total_num_steps,
                                dict2str(step_metrics), round(
                                    avg_step_time, 2), eta_str))
            train_metrics = OrderedDict(
                zip(list(self.train_outputs.keys()), np.mean(records, axis=0)))
            logging.info('[TRAIN] Epoch {} finished, {} .'.format(
                i + 1, dict2str(train_metrics)))
            time_train_one_epoch = time.time() - epoch_start_time
            epoch_start_time = time.time()
            # Every save_interval_epochs, evaluate on the validation set and
            # save the model.
            eval_epoch_start_time = time.time()
            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                if not osp.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                if eval_dataset is not None:
                    self.eval_metrics, self.eval_details = self.evaluate(
                        eval_dataset=eval_dataset,
                        batch_size=eval_batch_size,
                        epoch_id=i + 1,
                        return_details=True)
                    logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                        i + 1, dict2str(self.eval_metrics)))
                    # Save the best model; the first metric is the primary
                    # one.
                    best_accuracy_key = list(self.eval_metrics.keys())[0]
                    current_accuracy = self.eval_metrics[best_accuracy_key]
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        best_model_epoch = i + 1
                        best_model_dir = osp.join(save_dir, "best_model")
                        self.save_model(save_dir=best_model_dir)
                    if use_vdl:
                        for k, v in self.eval_metrics.items():
                            # Lists and multi-element arrays are not
                            # recorded as scalars.
                            if isinstance(v, list):
                                continue
                            if isinstance(v, np.ndarray):
                                if v.size > 1:
                                    continue
                            if k not in eval_component:
                                with log_writer.mode('Each_Epoch_on_Eval_Data'
                                                     ) as eval_logger:
                                    eval_component[k] = eval_logger.scalar(
                                        'Evaluation: {}'.format(k))
                            eval_component[k].add_record(i + 1, v)
                self.save_model(save_dir=current_save_dir)
                time_eval_one_epoch = time.time() - eval_epoch_start_time
                eval_epoch_start_time = time.time()
                logging.info(
                    'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
                    .format(best_model_epoch, best_accuracy_key,
                            best_accuracy))
                # current_accuracy is only bound when eval_dataset was
                # evaluated above, which this guard ensures.
                if eval_dataset is not None and early_stop:
                    if earlystop(current_accuracy):
                        break