trainer.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import copy
import time

import numpy as np
import typing
from PIL import Image, ImageOps

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec
from paddlex.ppdet.optimizer import ModelEMA

from paddlex.ppdet.core.workspace import create
from paddlex.ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from paddlex.ppdet.utils.visualizer import visualize_results, save_result
from paddlex.ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
from paddlex.ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from paddlex.ppdet.data.source.sniper_coco import SniperCOCODataSet
from paddlex.ppdet.data.source.category import get_categories
from paddlex.ppdet.utils import stats
from paddlex.ppdet.utils import profiler

from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
from .export_utils import _dump_infer_config, _prune_input_spec

from paddlex.ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

__all__ = ['Trainer']

MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']


class Trainer(object):
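    """
    Trainer builds the dataset, data loader, model, optimizer, callbacks
    and metrics from a unified ppdet config, and drives training,
    evaluation, inference and model export.

    Args:
        cfg (dict): config built from a ppdet YAML config file.
        mode (str): one of 'train', 'eval' or 'test'.
    """
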
    def __init__(self, cfg, mode='train'):
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
            "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False

        # build data loader
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
        else:
            self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT does not need training on the MOT dataset.')
            sys.exit(1)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]
            # JDE only supports single-class MOT for now.

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict
            # FairMOT supports both single-class and multi-class MOT.

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # normalize params for deploy
        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                use_thres_step=True,
                cycle_epoch=cycle_epoch)

        # EvalDataset is built with a BatchSampler to evaluate on a single device
        # TODO: multi-device evaluation
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            reader_name = '{}Reader'.format(self.mode.capitalize())
            # If the metric is VOC, collate_batch must be set to False.
            if cfg.metric == 'VOC':
                cfg[reader_name]['collate_batch'] = False
            self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                              self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initialize default callbacks
        self._init_callbacks()

        # initialize default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None

    def _init_metrics(self, validate=False):
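        # Metrics are only needed when evaluating: in 'test' mode, or in
        # 'train' mode without validation, self._metrics stays empty.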
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # Pass clsid2catid info to the metric instance to avoid loading
            # the annotation file multiple times.
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # When validating during training, the annotation file should be
            # taken from EvalDataset instead of self.dataset (the train dataset).
            anno_file = self.dataset.get_anno()
            dataset = self.dataset
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
                dataset = eval_dataset

            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # Pass clsid2catid info to the metric instance to avoid loading
            # the annotation file multiple times.
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # When validating during training, the annotation file should be
            # taken from EvalDataset instead of self.dataset (the train dataset).
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    clsid2catid=clsid2catid,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'VOC':
            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise)
            ]
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Metric type {} is not supported".format(
                self.cfg.metric))
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                "callbacks should be instances of a subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                "metrics should be instances of a subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights(self, weights):
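        # Skip loading when the model was passed in via cfg with weights
        # already loaded; otherwise load pretrained weights and start from
        # epoch 0.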
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        logger.debug("Load weights {} to start training".format(weights))

    def load_weights_sde(self, det_weights, reid_weights):
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
            load_weight(self.model.reid, reid_weights)
        else:
            load_weight(self.model.reid, reid_weights)

    def resume_weights(self, weights):
        # support resuming weights for Distill models
        if hasattr(self.model, 'student_model'):
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer)
        else:
            self.start_epoch = load_weight(self.model, weights, self.optimizer)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))

    def train(self, validate=False):
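        # Run the full training loop; when `validate` is True, evaluation is
        # performed every `snapshot_epoch` epochs and after the last epoch.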
        assert self.mode == 'train', "Model not in 'train' mode"
        Init_mark = False

        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                   self.cfg.use_gpu and self._nranks > 1)
        if sync_bn:
            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        model = self.model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)

        # initialize fp16 (AMP) training
        if self.cfg.get('fp16', False):
            scaler = amp.GradScaler(
                enable=self.cfg.use_gpu, init_loss_scaling=1024)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.cfg.get('fp16', False):
                    with amp.auto_cast(enable=self.cfg.use_gpu):
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']

                    # model backward
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    # model forward
                    outputs = model(data)
                    loss = outputs['loss']
                    # model backward
                    loss.backward()
                    self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update(self.model)
                iter_tic = time.time()

            # apply EMA weights to the model
            if self.use_ema:
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            self._compose_callback.on_epoch_end(self.status)

            if validate and (self._nranks < 2 or self._local_rank == 0) \
                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
                             or epoch_id == self.end_epoch - 1):
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    # If the metric is VOC, collate_batch must be set to False.
                    if self.cfg.metric == 'VOC':
                        self.cfg['EvalReader']['collate_batch'] = False
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # If validation during training is enabled, metrics should be
                # re-initialized; Init_mark makes sure this code runs only once.
                if validate and Init_mark == False:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()
                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            # restore the original (non-EMA) weights on the model
            if self.use_ema:
                self.model.set_dict(weight)

        self._compose_callback.on_train_end(self.status)

    def _eval_with_loader(self, loader):
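        # Run one full pass of `loader` through the model, feeding every
        # batch to the registered metrics, then accumulate and log them.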
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # multi-scale inputs: all inputs share the same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metrics and log them
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states, as evaluation may be performed multiple times
        self._reset_metrics()

    def evaluate(self):
        with paddle.no_grad():
            self._eval_with_loader(self.loader)

    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_txt=False):
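        # Run inference on `images`, visualize the results above
        # `draw_threshold`, and save the rendered images (and optional .txt
        # results) under `output_dir`.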
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # run inference
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)
        results = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            for key in ['im_shape', 'scale_factor', 'im_id']:
                if isinstance(data, typing.Sequence):
                    outs[key] = data[0][key]
                else:
                    outs[key] = data[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)

        # sniper
        if type(self.dataset) == SniperCOCODataSet:
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for outs in results:
            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']

            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_path = imid2path[int(im_id)]
                image = Image.open(image_path).convert('RGB')
                image = ImageOps.exif_transpose(image)
                self.status['original_image'] = np.array(image.copy())

                end = start + bbox_num[i]
                bbox_res = batch_res['bbox'][start:end] \
                    if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                    if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                    if 'segm' in batch_res else None
                keypoint_res = batch_res['keypoint'][start:end] \
                    if 'keypoint' in batch_res else None
                image = visualize_results(
                    image, bbox_res, mask_res, segm_res, keypoint_res,
                    int(im_id), catid2name, draw_threshold)
                self.status['result_image'] = np.array(image.copy())
                if self._compose_callback:
                    self._compose_callback.on_step_end(self.status)
                # save image with detections
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results saved in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                if save_txt:
                    save_path = os.path.splitext(save_name)[0] + '.txt'
                    # use a separate dict so the outer `results` list being
                    # iterated is not clobbered
                    txt_results = {}
                    txt_results["im_id"] = im_id
                    if bbox_res:
                        txt_results["bbox_res"] = bbox_res
                    if keypoint_res:
                        txt_results["keypoint_res"] = keypoint_res
                    save_result(save_path, txt_results, catid2name,
                                draw_threshold)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get the save image name from the source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
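        # Dump infer_cfg.yml into `save_dir` and build the static-graph input
        # spec (image, im_shape, scale_factor) used for export.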
        image_shape = None
        im_shape = [None, 2]
        scale_factor = [None, 2]
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        if len(image_shape) == 3:
            image_shape = [None] + image_shape
        else:
            im_shape = [image_shape[0], 2]
            scale_factor = [image_shape[0], 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)

        # save the inference config
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'),
                           image_shape, self.model)

        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })

        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save will prune
            # it against the input spec, so prune the input spec here and save
            # with the pruned spec.
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: hard-coded; remove when pruning input_spec is supported.
        if self.cfg.architecture == 'PicoDet':
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec

    def export(self, output_dir='output_inference'):
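        # Convert the model to a static graph (dy2st) and save the inference
        # model plus its config under `output_dir/<model_name>`.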
        self.model.eval()
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save the model
        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Exported model saved in {}".format(save_dir))

    def post_quant(self, output_dir='output_inference'):
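        # Run a few batches through the model to collect activation
        # statistics, then save a post-training quantized inference model.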
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx, data in enumerate(self.loader):
            self.model(data)
            if idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support pruning input_spec
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Exported Post-Quant model saved in {}".format(save_dir))

    def _flops(self, loader):
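        # Estimate model FLOPs with paddleslim on a single batch from
        # `loader`; logs a warning and returns if paddleslim is unavailable.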
        self.model.eval()
        try:
            import paddleslim
        except Exception:
            logger.warning(
                'Unable to calculate FLOPs; please install paddleslim, e.g. `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops as flops
        input_data = None
        for data in loader:
            input_data = data
            break

        input_spec = [{
            "image": input_data['image'][0].unsqueeze(0),
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        flops = flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            flops, input_data['image'][0].unsqueeze(0).shape))