# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import copy
import time
import numpy as np
import typing
from PIL import Image, ImageOps

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec

from paddlex.ppdet.optimizer import ModelEMA
from paddlex.ppdet.core.workspace import create
from paddlex.ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from paddlex.ppdet.utils.visualizer import visualize_results, save_result
from paddlex.ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
from paddlex.ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from paddlex.ppdet.data.source.sniper_coco import SniperCOCODataSet
from paddlex.ppdet.data.source.category import get_categories
from paddlex.ppdet.utils import stats
from paddlex.ppdet.utils import profiler

from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
from .export_utils import _dump_infer_config, _prune_input_spec

from paddlex.ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

__all__ = ['Trainer']

MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']


class Trainer(object):
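    """Engine that drives training, evaluation, inference and export of a
    ppdet model described by `cfg`.

    Args:
        cfg (dict): parsed ppdet configuration.
        mode (str): one of 'train', 'eval' or 'test'.
    """
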
    def __init__(self, cfg, mode='train'):
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
            "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT does not need training on an MOT dataset.')
            sys.exit(1)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]
            # JDE only supports single-class MOT for now.

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict
            # FairMOT supports both single-class and multi-class MOT.

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalization params for deploy
        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset is built with a BatchSampler to evaluate on a single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num, self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None

    def _init_metrics(self, validate=False):
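        """Build `self._metrics` according to `cfg.metric`. In 'test' mode,
        and in 'train' mode without validation, no metrics are needed and the
        list is left empty.
        """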
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to the metric instance to avoid loading
            # the annotation file multiple times
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # when validating during training, the annotation file should come
            # from the EvalDataset instead of self.dataset (the TrainDataset)
            anno_file = self.dataset.get_anno()
            dataset = self.dataset
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
                dataset = eval_dataset

            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to the metric instance to avoid loading
            # the annotation file multiple times
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # when validating during training, the annotation file should come
            # from the EvalDataset instead of self.dataset (the TrainDataset)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise)
            ]
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Unsupported metric type {}".format(
                self.cfg.metric))
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                "callbacks should be instances of a subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                "metrics should be instances of a subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights(self, weights):
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        logger.debug("Load weights {} to start training".format(weights))

    def load_weights_sde(self, det_weights, reid_weights):
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
            load_weight(self.model.reid, reid_weights)
        else:
            load_weight(self.model.reid, reid_weights)

    def resume_weights(self, weights):
        # support resuming weights of a Distill model
        if hasattr(self.model, 'student_model'):
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer)
        else:
            self.start_epoch = load_weight(self.model, weights, self.optimizer)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))

    def train(self, validate=False):
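        """Run the training loop from `start_epoch` to `cfg.epoch`. When
        `validate` is True, evaluate on EvalDataset every `snapshot_epoch`
        epochs (and after the final epoch) on the master device.
        """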
        assert self.mode == 'train', "Model not in 'train' mode"
        Init_mark = False

        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initial fp16
        if self.cfg.get('fp16', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            self._flops(self.loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.cfg.get('fp16', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply ema weights on the model
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            self._compose_callback.on_epoch_end(self.status)

            if validate and (self._nranks < 2 or self._local_rank == 0) \
                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                             or epoch_id == self.end_epoch - 1):
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be
                # re-initialized; Init_mark makes sure this only executes once
                if validate and not Init_mark:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore original weights on the model
            if self.use_ema:
                self.model.set_dict(weight)

        self._compose_callback.on_train_end(self.status)

    def _eval_with_loader(self, loader):
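        """Run a full evaluation pass over `loader`, updating every registered
        metric, logging the accumulated results, and resetting metric states.
        """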
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # multi-scale inputs: all inputs have the same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metrics and log them
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states, since metrics may be evaluated multiple times
        self._reset_metrics()

    def evaluate(self):
        with paddle.no_grad():
            self._eval_with_loader(self.loader)

    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
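        """Run inference on `images`, draw detections whose score exceeds
        `draw_threshold`, and save the visualized images (plus optional .txt
        results) under `output_dir`.
        """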
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            self._flops(loader)
        results = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            for key in ['im_shape', 'scale_factor', 'im_id']:
                if isinstance(data, typing.Sequence):
                    outs[key] = data[0][key]
                else:
                    outs[key] = data[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)

        # sniper
        if isinstance(self.dataset, SniperCOCODataSet):
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                    if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                    if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                    if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                    if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detections
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results saved in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # use a separate dict so the outer `results` list is not
                    # shadowed while we are still iterating over it
                    txt_results = {}
                    txt_results["im_id"] = im_id
                    if bbox_res:
                        txt_results["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_results["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_results, catid2name,
                                draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
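        """Dump infer_cfg.yml into `save_dir` and build the InputSpec list for
        export. With `prune_input=True`, also convert the model to a static
        graph and prune the input spec down to the inputs the program uses.
        """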
        image_shape = None
        im_shape = [None, 2]
        scale_factor = [None, 2]
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        if len(image_shape) == 3:
            image_shape = [None] + image_shape
        else:
            im_shape = [image_shape[0], 2]
            scale_factor = [image_shape[0], 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'),
                           image_shape, self.model)

        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })

        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save prunes it
            # against the input spec, so prune the input spec here and save
            # with the pruned spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: Hard code, delete it when pruning input_spec is supported.
        if self.cfg.architecture == 'PicoDet':
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
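        """Export the model (plus infer_cfg.yml) for deployment under
        `output_dir/<model_name>`; QAT slim models are saved through the slim
        quantized-model saver instead of paddle.jit.save.
        """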
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Exported model saved in {}".format(save_dir))

    def post_quant(self, output_dir='output_inference'):
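        """Run roughly `quant_batch_num` batches (default 10) through the
        model to collect calibration statistics, then save the post-training-
        quantized model under `output_dir/<model_name>`.
        """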
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx, data in enumerate(self.loader):
            self.model(data)
            if idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support prune input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Exported Post-Quant model saved in {}".format(save_dir))

    def _flops(self, loader):
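        """Estimate the model's FLOPs with paddleslim on one batch drawn from
        `loader`; logs a warning and returns if paddleslim is not installed.
        """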
        self.model.eval()
        try:
            import paddleslim  # only used to check availability
        except Exception:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops as flops
        input_data = None
        for data in loader:
            input_data = data
            break

        input_spec = [{
            "image": input_data['image'][0].unsqueeze(0),
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        flops_g = flops(self.model, input_spec) / (1000**3)
        logger.info("Model FLOPs : {:.6f}G. (image shape is {})".format(
            flops_g, input_data['image'][0].unsqueeze(0).shape))