# trainer.py
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import copy
  20. import time
  21. import numpy as np
  22. from PIL import Image
  23. import paddle
  24. import paddle.distributed as dist
  25. from paddle.distributed import fleet
  26. from paddle import amp
  27. from paddle.static import InputSpec
  28. from paddlex.ppdet.optimizer import ModelEMA
  29. from paddlex.ppdet.core.workspace import create
  30. from paddlex.ppdet.utils.checkpoint import load_weight, load_pretrain_weight
  31. from paddlex.ppdet.utils.visualizer import visualize_results, save_result
  32. from paddlex.ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
  33. from paddlex.ppdet.metrics import RBoxMetric, JDEDetMetric
  34. from paddlex.ppdet.data.source.category import get_categories
  35. from paddlex.ppdet.utils import stats
  36. from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
  37. from .export_utils import _dump_infer_config
  38. from paddlex.ppdet.utils.logger import setup_logger
  39. logger = setup_logger('ppdet.engine')
  40. __all__ = ['Trainer']
  41. MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']
  42. class Trainer(object):
    def __init__(self, cfg, mode='train'):
        """
        Build all training/eval/test machinery from a ppdet config.

        Args:
            cfg: ppdet config object (supports both attribute and dict-style
                access). Expected keys depend on mode, e.g. '{Mode}Dataset',
                '{Mode}Reader', 'architecture', 'worker_num'.
            mode (str): one of 'train', 'eval', 'test' (case-insensitive).

        Side effects: creates self.dataset, self.loader (train/eval only),
        self.model, optional self.ema, self.optimizer (train only), and
        initializes callbacks and metrics.
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
            "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader: MOT architectures use dedicated MOT datasets
        # for eval/test; everything else uses '{Mode}Dataset'
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        # DeepSORT only combines pretrained detector + ReID parts; it is
        # never trained end-to-end here, so bail out early
        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        # the embedding heads need the identity count, which is only known
        # after the MOT dataset has been constructed
        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identifiers'] = self.dataset.total_identities

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identifiers'] = self.dataset.total_identities

        # build model: a pre-built model in cfg means weights are already
        # loaded (is_loaded_weights short-circuits load_weights later)
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy
        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode (LR schedule needs steps per epoch)
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(
                self.lr, self.model.parameters())

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        # shared mutable dict read by callbacks (epoch/step ids, timings, ...)
        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()
  108. def _init_callbacks(self):
  109. if self.mode == 'train':
  110. self._callbacks = [LogPrinter(self), Checkpointer(self)]
  111. if self.cfg.get('use_vdl', False):
  112. self._callbacks.append(VisualDLWriter(self))
  113. self._compose_callback = ComposeCallback(self._callbacks)
  114. elif self.mode == 'eval':
  115. self._callbacks = [LogPrinter(self)]
  116. if self.cfg.metric == 'WiderFace':
  117. self._callbacks.append(WiferFaceEval(self))
  118. self._compose_callback = ComposeCallback(self._callbacks)
  119. elif self.mode == 'test' and self.cfg.get('use_vdl', False):
  120. self._callbacks = [VisualDLWriter(self)]
  121. self._compose_callback = ComposeCallback(self._callbacks)
  122. else:
  123. self._callbacks = []
  124. self._compose_callback = None
    def _init_metrics(self, validate=False):
        """
        Create self._metrics according to cfg.metric.

        Args:
            validate (bool): when True in train mode, metrics are built for
                validation-during-training (annotation file taken from
                EvalDataset rather than the training dataset).

        Metrics are empty in test mode, and in train mode unless validating.
        Unknown metric types log a warning and leave self._metrics empty.
        """
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
        if self.cfg.metric == 'COCO':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
            self._metrics = [
                COCOMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    IouType=IouType,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'RBOX':
            # rotated-box variant; mirrors the COCO branch above
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to metric instance to avoid multiple loading
            # annotation file
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # when do validation in train, annotation file should be get from
            # EvalReader instead of self.dataset(which is TrainReader)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise)
            ]
        elif self.cfg.metric == 'WiderFace':
            # multi-scale evaluation defaults to on for WiderFace
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            # keypoint metrics always read annotations from EvalDataset
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Metric not support for metric type {}".format(
                self.cfg.metric))
            self._metrics = []
  233. def _reset_metrics(self):
  234. for metric in self._metrics:
  235. metric.reset()
  236. def register_callbacks(self, callbacks):
  237. callbacks = [c for c in list(callbacks) if c is not None]
  238. for c in callbacks:
  239. assert isinstance(c, Callback), \
  240. "metrics shoule be instances of subclass of Metric"
  241. self._callbacks.extend(callbacks)
  242. self._compose_callback = ComposeCallback(self._callbacks)
  243. def register_metrics(self, metrics):
  244. metrics = [m for m in list(metrics) if m is not None]
  245. for m in metrics:
  246. assert isinstance(m, Metric), \
  247. "metrics shoule be instances of subclass of Metric"
  248. self._metrics.extend(metrics)
  249. def load_weights(self, weights):
  250. if self.is_loaded_weights:
  251. return
  252. self.start_epoch = 0
  253. load_pretrain_weight(self.model, weights)
  254. logger.debug("Load weights {} to start training".format(weights))
  255. def load_weights_sde(self, det_weights, reid_weights):
  256. if self.model.detector:
  257. load_weight(self.model.detector, det_weights)
  258. load_weight(self.model.reid, reid_weights)
  259. else:
  260. load_weight(self.model.reid, reid_weights)
  261. def resume_weights(self, weights):
  262. # support Distill resume weights
  263. if hasattr(self.model, 'student_model'):
  264. self.start_epoch = load_weight(self.model.student_model, weights,
  265. self.optimizer)
  266. else:
  267. self.start_epoch = load_weight(self.model, weights, self.optimizer)
  268. logger.debug("Resume weights of epoch {}".format(self.start_epoch))
    def train(self, validate=False):
        """
        Run the full training loop.

        Args:
            validate (bool): when True, evaluation is run on rank 0 every
                'snapshot_epoch' epochs (and on the final epoch), rebuilding
                the eval loader and metrics lazily on first use.

        Supports fleet / DataParallel distribution, AMP (fp16), and EMA
        weight averaging (EMA weights are applied for the epoch-end
        callbacks/eval and the raw weights restored afterwards).
        """
        assert self.mode == 'train', "Model not in 'train' mode"
        Init_mark = False

        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initial fp16
        if self.cfg.get('fp16', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            self._flops(self.loader)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                self._compose_callback.on_step_begin(self.status)
                if self.cfg.get('fp16', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                # read LR before stepping the schedule so the logged value
                # matches the rate actually used for this step
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                # only rank 0 aggregates training statistics
                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply ema weight on model (raw weights kept aside for restore)
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())

            self._compose_callback.on_epoch_end(self.status)

            if validate and (self._nranks < 2 or self._local_rank == 0) \
                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                             or epoch_id == self.end_epoch - 1):
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be re-init
                # Init_mark makes sure this code will only execute once
                if validate and Init_mark == False:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore origin weight on model
            if self.use_ema:
                self.model.set_dict(weight)
    def _eval_with_loader(self, loader):
        """
        Run one evaluation pass over `loader`, feeding every batch through
        the model and the registered metrics, then accumulate, log and
        reset the metrics.

        Updates self.status with 'sample_num' and 'cost_time' and drives
        the epoch/step callbacks in eval mode.
        """
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # count images, not batches: one im_id per image in the batch
            sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states for metric may performed multiple times
        self._reset_metrics()
  394. def evaluate(self):
  395. with paddle.no_grad():
  396. self._eval_with_loader(self.loader)
    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
        """
        Run inference on `images`, draw detections and save the rendered
        images (and optionally .txt result files) under `output_dir`.

        Args:
            images: image paths accepted by the test dataset's set_images.
            draw_threshold (float): score threshold for visualization.
            output_dir (str): directory for rendered outputs.
            save_txt (bool): also dump per-image results as text files.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            # carry the input metadata through so get_infer_results can map
            # predictions back to original image coordinates/ids
            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = data[key]
            # detach everything to numpy for postprocessing
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()

            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            # batch results are flat; bbox_num[i] boxes belong to image i,
            # so walk the flat arrays with a moving [start:end) window
            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                    if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                    if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                    if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                    if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    results = {}
                    results["im_id"] = im_id
                    if bbox_res:
                        results["bbox_res"] = bbox_res
                    if keypoint_res:
                        results["keypoint_res"] = keypoint_res
                    save_result(save_path, results, catid2name, draw_threshold)
                start = end
  459. def _get_save_image_name(self, output_dir, image_path):
  460. """
  461. Get save image name from source image path.
  462. """
  463. if not os.path.exists(output_dir):
  464. os.makedirs(output_dir)
  465. image_name = os.path.split(image_path)[-1]
  466. name, ext = os.path.splitext(image_name)
  467. return os.path.join(output_dir, "{}".format(name)) + ext
    def export(self, output_dir='output_inference'):
        """
        Export the model for deployment: dump an inference config YAML and
        a static-graph model (via dy2st) into output_dir/<config name>/.

        Args:
            output_dir (str): root directory for the exported model.
        """
        self.model.eval()
        # model directory is named after the config file's basename
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        image_shape = None
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[3, -1, -1] as default (dynamic H/W)
        if image_shape is None:
            image_shape = [3, -1, -1]

        # switch the model into deploy mode where supported
        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)
        if hasattr(self.cfg, 'lite_deploy'):
            self.model.lite_deploy = self.cfg.lite_deploy

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'),
                           image_shape, self.model)

        input_spec = [{
            "image": InputSpec(
                shape=[None] + image_shape, name='image'),
            "im_shape": InputSpec(
                shape=[None, 2], name='im_shape'),
            "scale_factor": InputSpec(
                shape=[None, 2], name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            # DeepSORT additionally consumes cropped person patches
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })

        static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
        # NOTE: dy2st do not pruned program, but jit.save will prune program
        # input spec, prune input spec here and save with pruned input spec
        pruned_input_spec = self._prune_input_spec(
            input_spec, static_model.forward.main_program,
            static_model.forward.outputs)

        # dy2st and save model
        # NOTE(review): if 'slim' is present but 'slim_type' is not, the
        # lookup below raises KeyError — confirm slim configs always set it
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
    def _prune_input_spec(self, input_spec, program, targets):
        """
        Return the subset of input_spec entries that survive program pruning.

        Args:
            input_spec (list[dict]): spec dicts passed to jit.to_static.
            program: the static main program produced by dy2st.
            targets: output variables the pruned program must keep.

        Returns:
            list[dict]: one dict with only the specs whose variables still
            exist in the pruned program.
        """
        # try to prune static program to figure out pruned input spec
        # so we perform following operations in static mode
        paddle.enable_static()
        pruned_input_spec = [{}]
        program = program.clone()
        program = program._prune(targets=targets)
        global_block = program.global_block()
        for name, spec in input_spec[0].items():
            try:
                # var() raises when the input was pruned away; keeping the
                # spec only on success is the intended best-effort filter
                v = global_block.var(name)
                pruned_input_spec[0][name] = spec
            except Exception:
                pass
        paddle.disable_static()
        return pruned_input_spec
  543. def _flops(self, loader):
  544. self.model.eval()
  545. try:
  546. import paddleslim
  547. except Exception as e:
  548. logger.warning(
  549. 'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
  550. )
  551. return
  552. from paddleslim.analysis import dygraph_flops as flops
  553. input_data = None
  554. for data in loader:
  555. input_data = data
  556. break
  557. input_spec = [{
  558. "image": input_data['image'][0].unsqueeze(0),
  559. "im_shape": input_data['im_shape'][0].unsqueeze(0),
  560. "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
  561. }]
  562. flops = flops(self.model, input_spec) / (1000**3)
  563. logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
  564. flops, input_data['image'][0].unsqueeze(0).shape))