# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import collections
import copy
import os
import os.path as osp

import six
import numpy as np
import paddle
from paddle.static import InputSpec

import ppdet
from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner

import paddlex
import paddlex.utils.logging as logging
from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
from paddlex.cv.transforms import arrange_transforms
from .base import BaseModel
from .utils.det_metrics import VOCMetric, COCOMetric
from .utils.ema import ExponentialMovingAverage
from paddlex.utils.checkpoint import det_pretrain_weights_dict

__all__ = [
    "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
]


class BaseDetector(BaseModel):
    def __init__(self, model_name, num_classes=80, **params):
        self.init_params.update(locals())
        super(BaseDetector, self).__init__('detector')
        if not hasattr(ppdet.modeling, model_name):
            raise Exception("ERROR: There's no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.labels = None
        self.net = self.build_net(**params)

    def build_net(self, **params):
        with paddle.utils.unique_name.guard():
            net = ppdet.modeling.__dict__[self.model_name](**params)
        return net

    def _fix_transforms_shape(self, image_shape):
        raise NotImplementedError("_fix_transforms_shape: not implemented!")

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [None, 3] + image_shape
            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
                raise Exception(
                    "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                    format(image_shape[-2:]))
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image', dtype='float32'),
            "im_shape": InputSpec(
                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
            "scale_factor": InputSpec(
                shape=[image_shape[0], 2],
                name='scale_factor',
                dtype='float32')
        }]
        return input_spec

    def _get_backbone(self, backbone_name, **params):
        backbone = getattr(ppdet.modeling, backbone_name)(**params)
        return backbone

    def run(self, net, inputs, mode):
        net_out = net(inputs)
        if mode in ['train', 'eval']:
            outputs = net_out
        else:
            for key in ['im_shape', 'scale_factor']:
                net_out[key] = inputs[key]
            outputs = dict()
            for key in net_out:
                outputs[key] = net_out[key].numpy()
        return outputs

    def default_optimizer(self, parameters, learning_rate, warmup_steps,
                          warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
                          num_steps_each_epoch):
        boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
        values = [(lr_decay_gamma**i) * learning_rate
                  for i in range(len(lr_decay_epochs) + 1)]
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=boundaries, values=values)
        if warmup_steps > 0:
            if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters should satisfy: "
                    "warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch",
                    exit=False)
                logging.error(
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
                    exit=False)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=scheduler,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)
        optimizer = paddle.optimizer.Momentum(
            scheduler,
            momentum=.9,
            weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
            parameters=parameters)
        return optimizer
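
    # A minimal sketch (not part of the original file) of the schedule that
    # default_optimizer builds, assuming hypothetical values learning_rate=0.001,
    # lr_decay_gamma=0.1, lr_decay_epochs=(216, 243), num_steps_each_epoch=100:
    #   step [0, 21600)     -> lr = 0.001
    #   step [21600, 24300) -> lr = 0.0001
    #   step [24300, ...)   -> lr = 0.00001
    # If warmup_steps > 0, LinearWarmup additionally ramps the learning rate
    # from warmup_start_lr up to learning_rate over the first warmup_steps
    # iterations before the piecewise decay applies.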

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True):
        """
        Train the model.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process. Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional): Optimizer used for training.
                If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights(str or None, optional): None or name/path of pretrained weights.
                If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
            learning_rate(float, optional): Learning rate for training. Defaults to .001.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            metric({'VOC', 'COCO', None}, optional): Evaluation metric.
                If None, the metric is determined according to the dataset format. Defaults to None.
            use_ema(bool, optional): Whether to use the exponential moving average strategy. Defaults to False.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
        """
        if train_dataset.__class__.__name__ == 'VOCDetection':
            train_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif train_dataset.__class__.__name__ == 'CocoDetection':
            if self.__class__.__name__ == 'MaskRCNN':
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }

        if metric is None:
            if eval_dataset.__class__.__name__ == 'VOCDetection':
                self.metric = 'voc'
            elif eval_dataset.__class__.__name__ == 'CocoDetection':
                self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported, please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()

        self.labels = train_dataset.labels
        self.num_max_boxes = train_dataset.num_max_boxes
        train_dataset.batch_transforms = self._compose_batch_transform(
            train_dataset.transforms, mode='train')

        # Build the optimizer if none is given.
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            self.optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch)
        else:
            self.optimizer = optimizer

        # Initialize network weights.
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                pretrain_weights = det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])][0]
                logging.warning("Pretrain_weights is forcibly set to '{}'. "
                                "If you don't want to use pretrained weights, "
                                "set pretrain_weights to None.".format(
                                    pretrain_weights))
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrained weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        self.net_initialize(
            pretrain_weights=pretrain_weights, save_dir=pretrained_dir)

        if use_ema:
            ema = ExponentialMovingAverage(
                decay=.9998, model=self.net, use_thres_step=True)
        else:
            ema = None

        # Start the training loop.
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            ema=ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
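
    # A minimal usage sketch for train() (dataset paths are hypothetical; the
    # paddlex.datasets / paddlex.transforms calls follow the PaddleX 2.x API):
    #
    #   import paddlex as pdx
    #   from paddlex import transforms as T
    #
    #   train_transforms = T.Compose([T.RandomHorizontalFlip(), T.Normalize()])
    #   train_dataset = pdx.datasets.VOCDetection(
    #       data_dir='my_dataset',
    #       file_list='my_dataset/train_list.txt',
    #       label_list='my_dataset/labels.txt',
    #       transforms=train_transforms)
    #   model = pdx.det.YOLOv3(num_classes=len(train_dataset.labels))
    #   model.train(num_epochs=270, train_dataset=train_dataset,
    #               train_batch_size=8, learning_rate=0.001, save_dir='output')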

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=64,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=10,
                          save_dir='output',
                          learning_rate=.00001,
                          warmup_steps=0,
                          warmup_start_lr=0.0,
                          lr_decay_epochs=(216, 243),
                          lr_decay_gamma=0.1,
                          metric=None,
                          use_ema=False,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional): Evaluation dataset.
                If None, the model will not be evaluated during the training process. Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional): Optimizer used for training.
                If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate(float, optional): Learning rate for training. Defaults to .00001.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            metric({'VOC', 'COCO', None}, optional): Evaluation metric.
                If None, the metric is determined according to the dataset format. Defaults to None.
            use_ema(bool, optional): Whether to use the exponential moving average strategy. Defaults to False.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            quant_config(dict or None, optional): Quantization configuration. If None, a default rule-of-thumb
                configuration is used. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)

    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 metric=None,
                 return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset(paddlex.dataset): Evaluation dataset.
            batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
            metric({'VOC', 'COCO', None}, optional): Evaluation metric.
                If None, the metric is determined according to the dataset format. Defaults to None.
            return_details(bool, optional): Whether to return evaluation details. Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs such as
            {"mAP(0.50, 11point)": `mean average precision`}.
        """
        if metric is None:
            if not hasattr(self, 'metric'):
                if eval_dataset.__class__.__name__ == 'VOCDetection':
                    self.metric = 'voc'
                elif eval_dataset.__class__.__name__ == 'CocoDetection':
                    self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported, please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()

        if self.metric == 'voc':
            eval_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif self.metric == 'coco':
            if self.__class__.__name__ == 'MaskRCNN':
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }

        eval_dataset.batch_transforms = self._compose_batch_transform(
            eval_dataset.transforms, mode='eval')
        arrange_transforms(
            model_type=self.model_type,
            transforms=eval_dataset.transforms,
            mode='eval')

        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been set up yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        if batch_size > 1:
            logging.warning(
                "Detector only supports single-card evaluation with batch_size=1, "
                "so batch_size is forcibly set to 1.")
            batch_size = 1

        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset, batch_size=batch_size, mode='eval')
            is_bbox_normalized = False
            if eval_dataset.batch_transforms is not None:
                is_bbox_normalized = any(
                    isinstance(t, _NormalizeBox)
                    for t in eval_dataset.batch_transforms.batch_transforms)
            if self.metric == 'voc':
                eval_metric = VOCMetric(
                    labels=eval_dataset.labels,
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    is_bbox_normalized=is_bbox_normalized,
                    classwise=False)
            else:
                eval_metric = COCOMetric(
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    classwise=False)
            scores = collections.OrderedDict()
            logging.info(
                "Start to evaluate(total_samples={}, total_steps={})...".
                format(eval_dataset.num_samples, eval_dataset.num_samples))
            with paddle.no_grad():
                for step, data in enumerate(self.eval_data_loader):
                    outputs = self.run(self.net, data, 'eval')
                    eval_metric.update(data, outputs)
                eval_metric.accumulate()
                self.eval_details = eval_metric.details
                scores.update(eval_metric.get())
                eval_metric.reset()
            if return_details:
                return scores, self.eval_details
            return scores
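
    # Evaluation sketch (assumes `model` and an eval dataset built as in the
    # train() example above; the keys of the returned OrderedDict follow the
    # docstring, e.g. "mAP(0.50, 11point)" for the VOC metric):
    #
    #   scores = model.evaluate(eval_dataset, batch_size=1)
    #   print(scores)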

    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file(list, str or np.ndarray): Image path or decoded image data
                in BGR format, or a list of either, in which case all images are
                predicted as a mini-batch.
            transforms(paddlex.transforms.Compose or None, optional): Transforms for the inputs.
                If None, the transforms for the evaluation process are used. Defaults to None.

        Returns:
            If img_file is a string or np.ndarray, the result is a list of dicts with key-value pairs:
            {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
            If img_file is a list, the result is a list of such lists, one per image. Fields:
                category_id(int): the predicted category ID
                category(str): category name
                bbox(list): bounding box in [x, y, w, h] format
                score(float): confidence
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise Exception("transforms need to be defined, now is None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_samples = self._preprocess(images, transforms)
        self.net.eval()
        outputs = self.run(self.net, batch_samples, 'test')
        prediction = self._postprocess(outputs)
        if isinstance(img_file, (str, np.ndarray)):
            prediction = prediction[0]
        return prediction
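
    # Inference sketch (image path hypothetical; the result fields match the
    # docstring above):
    #
    #   result = model.predict('test.jpg')
    #   for det in result:
    #       print(det['category'], det['score'], det['bbox'])  # bbox is [x, y, w, h]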

    def _preprocess(self, images, transforms):
        arrange_transforms(
            model_type=self.model_type, transforms=transforms, mode='test')
        batch_samples = list()
        for im in images:
            sample = {'image': im}
            batch_samples.append(transforms(sample))
        batch_transforms = self._compose_batch_transform(transforms, 'test')
        batch_samples = batch_transforms(batch_samples)
        for k, v in batch_samples.items():
            batch_samples[k] = paddle.to_tensor(v)
        return batch_samples

    def _postprocess(self, batch_pred):
        infer_result = {}
        if 'bbox' in batch_pred:
            bboxes = batch_pred['bbox']
            bbox_nums = batch_pred['bbox_num']
            det_res = []
            k = 0
            for i in range(len(bbox_nums)):
                det_nums = bbox_nums[i]
                for j in range(det_nums):
                    dt = bboxes[k]
                    k = k + 1
                    num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
                    if int(num_id) < 0:
                        continue
                    category = self.labels[int(num_id)]
                    w = xmax - xmin
                    h = ymax - ymin
                    bbox = [xmin, ymin, w, h]
                    dt_res = {
                        'category_id': int(num_id),
                        'category': category,
                        'bbox': bbox,
                        'score': score
                    }
                    det_res.append(dt_res)
            infer_result['bbox'] = det_res

        if 'mask' in batch_pred:
            # Imported lazily so pycocotools is only required for mask models.
            import pycocotools.mask as mask_util
            masks = batch_pred['mask']
            bboxes = batch_pred['bbox']
            mask_nums = batch_pred['bbox_num']
            seg_res = []
            k = 0
            for i in range(len(mask_nums)):
                det_nums = mask_nums[i]
                for j in range(det_nums):
                    mask = masks[k].astype(np.uint8)
                    score = float(bboxes[k][1])
                    label = int(bboxes[k][0])
                    k = k + 1
                    if label == -1:
                        continue
                    category = self.labels[int(label)]
                    rle = mask_util.encode(
                        np.array(
                            mask[:, :, None], order="F", dtype="uint8"))[0]
                    if six.PY3:
                        if 'counts' in rle:
                            rle['counts'] = rle['counts'].decode("utf8")
                    sg_res = {
                        'category_id': int(label) + 1,
                        'category': category,
                        'segmentation': rle,
                        'score': score
                    }
                    seg_res.append(sg_res)
            infer_result['mask'] = seg_res

        # Regroup the flat per-detection results into per-image lists.
        bbox_num = batch_pred['bbox_num']
        results = []
        start = 0
        for num in bbox_num:
            end = start + num
            curr_res = infer_result['bbox'][start:end]
            if 'mask' in infer_result:
                mask_res = infer_result['mask'][start:end]
                for box, mask in zip(curr_res, mask_res):
                    box.update(mask)
            results.append(curr_res)
            start = end
        return results
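
    # Sketch of the regrouping above for a hypothetical batch of two images
    # with bbox_num = [2, 1]: infer_result['bbox'] holds three flat detections,
    # and results becomes [[det0, det1], [det2]], one list per input image.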


class YOLOv3(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV1',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 ignore_threshold=0.7,
                 nms_score_threshold=0.01,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 label_smooth=False):
        self.init_params = locals()
        if backbone not in [
                'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
                'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
        ]:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', "
                "'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
                format(backbone))
        if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
            norm_type = 'sync_bn'
        else:
            norm_type = 'bn'
        self.backbone_name = backbone
        if 'MobileNetV1' in backbone:
            norm_type = 'bn'
            backbone = self._get_backbone('MobileNet', norm_type=norm_type)
        elif 'MobileNetV3' in backbone:
            backbone = self._get_backbone(
                'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
        elif backbone == 'ResNet50_vd_dcn':
            backbone = self._get_backbone(
                'ResNet',
                norm_type=norm_type,
                variant='d',
                return_idx=[1, 2, 3],
                dcn_v2_stages=[3],
                freeze_at=-1,
                freeze_norm=False)
        elif backbone == 'ResNet34':
            backbone = self._get_backbone(
                'ResNet',
                depth=34,
                norm_type=norm_type,
                return_idx=[1, 2, 3],
                freeze_at=-1,
                freeze_norm=False,
                norm_decay=0.)
        else:
            backbone = self._get_backbone('DarkNet', norm_type=norm_type)
        neck = ppdet.modeling.YOLOv3FPN(
            norm_type=norm_type,
            in_channels=[i.channels for i in backbone.out_shape])
        loss = ppdet.modeling.YOLOv3Loss(
            num_classes=num_classes,
            ignore_thresh=ignore_threshold,
            label_smooth=label_smooth)
        yolo_head = ppdet.modeling.YOLOv3Head(
            in_channels=[i.channels for i in neck.out_shape],
            anchors=anchors,
            anchor_masks=anchor_masks,
            num_classes=num_classes,
            loss=loss)
        post_process = ppdet.modeling.BBoxPostProcess(
            decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
            nms=ppdet.modeling.MultiClassNMS(
                score_threshold=nms_score_threshold,
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                nms_threshold=nms_iou_threshold))
        params = {
            'backbone': backbone,
            'neck': neck,
            'yolo_head': yolo_head,
            'post_process': post_process
        }
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks

    def _compose_batch_transform(self, transforms, mode='train'):
        if mode == 'train':
            default_batch_transforms = [
                _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
                _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
                _Gt2YoloTarget(
                    anchor_masks=self.anchor_masks,
                    anchors=self.anchors,
                    downsample_ratios=getattr(self, 'downsample_ratios',
                                              [32, 16, 8]),
                    num_classes=self.num_classes)
            ]
        else:
            default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
        if mode == 'eval' and self.metric == 'voc':
            collate_batch = False
        else:
            collate_batch = True

        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))

        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        has_resize_op = True
                        resize_op_idx = idx
                    if name == 'Normalize':
                        normalize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx,
                        Resize(
                            target_size=image_shape, interp='CUBIC'))
                else:
                    self.test_transforms.transforms[
                        resize_op_idx].target_size = image_shape
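

# A minimal construction sketch (values hypothetical): anchors are given in
# pixels on the network input scale, and anchor_masks index into the anchors
# list, one group per output level from coarsest to finest.
#
#   model = YOLOv3(num_classes=20, backbone='DarkNet53',
#                  nms_score_threshold=0.005, nms_iou_threshold=0.45)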


class FasterRCNN(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50',
                 with_fpn=True,
                 with_dcn=False,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 anchor_sizes=[[32], [64], [128], [256], [512]],
                 keep_top_k=100,
                 nms_threshold=0.5,
                 score_threshold=0.05,
                 fpn_num_channels=256,
                 rpn_batch_size_per_im=256,
                 rpn_fg_fraction=0.5,
                 test_pre_nms_top_n=None,
                 test_post_nms_top_n=1000):
        self.init_params = locals()
        if backbone not in [
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
                'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
        ]:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
        self.backbone_name = backbone
        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
        if backbone == 'HRNet_W18':
            if not with_fpn:
                logging.warning(
                    "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                    format(backbone))
                with_fpn = True
            if with_dcn:
                logging.warning(
                    "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                    format(backbone))
            backbone = self._get_backbone(
                'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
        elif backbone == 'ResNet50_vd_ssld':
            if not with_fpn:
                logging.warning(
                    "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                    format(backbone))
                with_fpn = True
            backbone = self._get_backbone(
                'ResNet',
                variant='d',
                norm_type='bn',
                freeze_at=0,
                return_idx=[0, 1, 2, 3],
                num_stages=4,
                lr_mult_list=[0.05, 0.05, 0.1, 0.15],
                dcn_v2_stages=dcn_v2_stages)
        elif 'ResNet50' in backbone:
            if with_fpn:
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d' if '_vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if with_dcn:
                    logging.warning(
                        "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                        format(backbone))
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d' if '_vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[2],
                    num_stages=3)
        elif 'ResNet34' in backbone:
            if not with_fpn:
                logging.warning(
                    "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                    format(backbone))
                with_fpn = True
            backbone = self._get_backbone(
                'ResNet',
                depth=34,
                variant='d' if 'vd' in backbone else 'b',
                norm_type='bn',
                freeze_at=0,
                return_idx=[0, 1, 2, 3],
                num_stages=4,
                dcn_v2_stages=dcn_v2_stages)
        else:
            if not with_fpn:
                logging.warning(
                    "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                    format(backbone))
                with_fpn = True
            backbone = self._get_backbone(
                'ResNet',
                depth=101,
                variant='d' if 'vd' in backbone else 'b',
                norm_type='bn',
                freeze_at=0,
                return_idx=[0, 1, 2, 3],
                num_stages=4,
                dcn_v2_stages=dcn_v2_stages)

        rpn_in_channel = backbone.out_shape[0].channels

        if with_fpn:
            self.backbone_name = self.backbone_name + '_fpn'
            if 'HRNet' in self.backbone_name:
                neck = ppdet.modeling.HRFPN(
                    in_channels=[i.channels for i in backbone.out_shape],
                    out_channel=fpn_num_channels,
                    spatial_scales=[
                        1.0 / i.stride for i in backbone.out_shape
                    ],
                    share_conv=False)
            else:
                neck = ppdet.modeling.FPN(
                    in_channels=[i.channels for i in backbone.out_shape],
                    out_channel=fpn_num_channels,
                    spatial_scales=[
                        1.0 / i.stride for i in backbone.out_shape
                    ])
            rpn_in_channel = neck.out_shape[0].channels
            anchor_generator_cfg = {
                'aspect_ratios': aspect_ratios,
                'anchor_sizes': anchor_sizes,
                'strides': [4, 8, 16, 32, 64]
            }
            train_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 2000,
                'post_nms_top_n': 1000,
                'topk_after_collect': True
            }
            test_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 1000
                if test_pre_nms_top_n is None else test_pre_nms_top_n,
                'post_nms_top_n': test_post_nms_top_n
            }
            head = ppdet.modeling.TwoFCHead(
                in_channel=neck.out_shape[0].channels, out_channel=1024)
            roi_extractor_cfg = {
                'resolution': 7,
                'spatial_scale': [1. / i.stride for i in neck.out_shape],
                'sampling_ratio': 0,
                'aligned': True
            }
            with_pool = False
        else:
            neck = None
            anchor_generator_cfg = {
                'aspect_ratios': aspect_ratios,
                'anchor_sizes': anchor_sizes,
                'strides': [16]
            }
            train_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 12000,
                'post_nms_top_n': 2000,
                'topk_after_collect': False
            }
            test_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 6000
                if test_pre_nms_top_n is None else test_pre_nms_top_n,
                'post_nms_top_n': test_post_nms_top_n
            }
            head = ppdet.modeling.Res5Head()
            roi_extractor_cfg = {
                'resolution': 14,
                'spatial_scale': [1. / i.stride for i in backbone.out_shape],
                'sampling_ratio': 0,
                'aligned': True
            }
            with_pool = True

        rpn_target_assign_cfg = {
            'batch_size_per_im': rpn_batch_size_per_im,
            'fg_fraction': rpn_fg_fraction,
            'negative_overlap': .3,
            'positive_overlap': .7,
            'use_random': True
        }
        rpn_head = ppdet.modeling.RPNHead(
            anchor_generator=anchor_generator_cfg,
            rpn_target_assign=rpn_target_assign_cfg,
            train_proposal=train_proposal_cfg,
            test_proposal=test_proposal_cfg,
            in_channel=rpn_in_channel)
        bbox_assigner = BBoxAssigner(num_classes=num_classes)
        bbox_head = ppdet.modeling.BBoxHead(
            head=head,
            in_channel=head.out_shape[0].channels,
            roi_extractor=roi_extractor_cfg,
            with_pool=with_pool,
            bbox_assigner=bbox_assigner,
            num_classes=num_classes)
        bbox_post_process = ppdet.modeling.BBoxPostProcess(
            num_classes=num_classes,
            decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
            nms=ppdet.modeling.MultiClassNMS(
                score_threshold=score_threshold,
                keep_top_k=keep_top_k,
                nms_threshold=nms_threshold))
        params = {
            'backbone': backbone,
            'neck': neck,
            'rpn_head': rpn_head,
            'bbox_head': bbox_head,
            'bbox_post_process': bbox_post_process
        }
        self.with_fpn = with_fpn
        super(FasterRCNN, self).__init__(
            model_name='FasterRCNN', num_classes=num_classes, **params)

    def _compose_batch_transform(self, transforms, mode='train'):
        if mode == 'train':
            default_batch_transforms = [
                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
            ]
            collate_batch = False
        else:
            default_batch_transforms = [
                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
            ]
            collate_batch = True

        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))

        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'ResizeByShort':
                        has_resize_op = True
                        resize_op_idx = idx
                    if name == 'Normalize':
                        normalize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx,
                        Resize(
                            target_size=image_shape,
                            keep_ratio=True,
                            interp='CUBIC'))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC')
                self.test_transforms.transforms.append(
                    Padding(im_padding_value=[0., 0., 0.]))
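

# A minimal construction sketch (values hypothetical): with with_fpn=True,
# batches are padded to a stride of 32 and a five-level anchor generator is
# used, so anchor_sizes must provide one size group per FPN level.
#
#   model = FasterRCNN(num_classes=20, backbone='ResNet50_vd', with_fpn=True,
#                      test_pre_nms_top_n=None, test_post_nms_top_n=1000)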


class PPYOLO(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=None,
                 anchor_masks=None,
                 use_coord_conv=True,
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45):
        self.init_params = locals()
        if backbone not in [
                'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
                'MobileNetV3_small'
        ]:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
                format(backbone))
        self.backbone_name = backbone
        if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
            norm_type = 'sync_bn'
        else:
            norm_type = 'bn'
        if anchors is None and anchor_masks is None:
            if 'MobileNetV3' in backbone:
                anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
                           [120, 195], [254, 235]]
                anchor_masks = [[3, 4, 5], [0, 1, 2]]
            elif backbone == 'ResNet50_vd_dcn':
                anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                           [59, 119], [116, 90], [156, 198], [373, 326]]
                anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
            else:
                anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
                           [344, 319]]
                anchor_masks = [[3, 4, 5], [0, 1, 2]]
        elif anchors is None or anchor_masks is None:
            raise ValueError("Please define both anchors and anchor_masks.")
        if backbone == 'ResNet50_vd_dcn':
            backbone = self._get_backbone(
                'ResNet',
                variant='d',
                norm_type=norm_type,
                return_idx=[1, 2, 3],
                dcn_v2_stages=[3],
                freeze_at=-1,
                freeze_norm=False,
                norm_decay=0.)
            downsample_ratios = [32, 16, 8]
        elif backbone == 'ResNet18_vd':
            backbone = self._get_backbone(
                'ResNet',
                depth=18,
                variant='d',
                norm_type=norm_type,
                return_idx=[2, 3],
                freeze_at=-1,
                freeze_norm=False,
                norm_decay=0.)
            # ResNet18_vd exposes two feature levels (return_idx=[2, 3]), so
            # the downsample ratios must match the two anchor_masks groups.
            downsample_ratios = [32, 16]
        elif backbone == 'MobileNetV3_large':
            backbone = self._get_backbone(
                'MobileNetV3',
                model_name='large',
                norm_type=norm_type,
                scale=1,
                with_extra_blocks=False,
                extra_block_filters=[],
                feature_maps=[13, 16])
            downsample_ratios = [32, 16]
        elif backbone == 'MobileNetV3_small':
            backbone = self._get_backbone(
                'MobileNetV3',
                model_name='small',
                norm_type=norm_type,
                scale=1,
                with_extra_blocks=False,
                extra_block_filters=[],
                feature_maps=[9, 12])
            downsample_ratios = [32, 16]
        neck = ppdet.modeling.PPYOLOFPN(
            norm_type=norm_type,
            in_channels=[i.channels for i in backbone.out_shape],
            coord_conv=use_coord_conv,
            drop_block=use_drop_block,
            spp=use_spp,
            conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
                                 self.backbone_name == 'ResNet18_vd') else 2)
        loss = ppdet.modeling.YOLOv3Loss(
            num_classes=num_classes,
            ignore_thresh=ignore_threshold,
            downsample=downsample_ratios,
            label_smooth=label_smooth,
            scale_x_y=scale_x_y,
            iou_loss=ppdet.modeling.IouLoss(
                loss_weight=2.5, loss_square=True) if use_iou_loss else None,
            iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
            if use_iou_aware else None)
        yolo_head = ppdet.modeling.YOLOv3Head(
            in_channels=[i.channels for i in neck.out_shape],
            anchors=anchors,
            anchor_masks=anchor_masks,
            num_classes=num_classes,
            loss=loss,
            iou_aware=use_iou_aware)
        if use_matrix_nms:
            nms = ppdet.modeling.MatrixNMS(
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                post_threshold=.05
                if 'MobileNetV3' in self.backbone_name else .01,
                nms_top_k=nms_topk,
                background_label=-1)
        else:
            nms = ppdet.modeling.MultiClassNMS(
                score_threshold=nms_score_threshold,
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                nms_threshold=nms_iou_threshold)
        post_process = ppdet.modeling.BBoxPostProcess(
            decode=ppdet.modeling.YOLOBox(
                num_classes=num_classes,
                conf_thresh=.005
                if 'MobileNetV3' in self.backbone_name else .01,
                scale_x_y=scale_x_y),
            nms=nms)
        params = {
            'backbone': backbone,
            'neck': neck,
            'yolo_head': yolo_head,
            'post_process': post_process
        }
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.model_name = 'PPYOLO'


class PPYOLOTiny(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV3',
                 anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
                          [60, 170], [220, 125], [128, 222], [264, 266]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=False,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.5,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=False,
                 nms_score_threshold=0.005,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45):
        self.init_params = locals()
        if backbone != 'MobileNetV3':
            logging.warning(
                "PPYOLOTiny only supports MobileNetV3 as backbone. "
                "Backbone is forcibly set to MobileNetV3.")
        self.backbone_name = 'MobileNetV3'
        if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
            norm_type = 'sync_bn'
        else:
            norm_type = 'bn'
        backbone = self._get_backbone(
            'MobileNetV3',
            model_name='large',
            norm_type=norm_type,
            scale=.5,
            with_extra_blocks=False,
            extra_block_filters=[],
            feature_maps=[7, 13, 16])
        downsample_ratios = [32, 16, 8]
        neck = ppdet.modeling.PPYOLOTinyFPN(
            detection_block_channels=[160, 128, 96],
            in_channels=[i.channels for i in backbone.out_shape],
            spp=use_spp,
            drop_block=use_drop_block)
        loss = ppdet.modeling.YOLOv3Loss(
            num_classes=num_classes,
            ignore_thresh=ignore_threshold,
            downsample=downsample_ratios,
            label_smooth=label_smooth,
            scale_x_y=scale_x_y,
            iou_loss=ppdet.modeling.IouLoss(
                loss_weight=2.5, loss_square=True) if use_iou_loss else None,
            iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
            if use_iou_aware else None)
        yolo_head = ppdet.modeling.YOLOv3Head(
            in_channels=[i.channels for i in neck.out_shape],
            anchors=anchors,
            anchor_masks=anchor_masks,
            num_classes=num_classes,
            loss=loss,
            iou_aware=use_iou_aware)
        if use_matrix_nms:
            nms = ppdet.modeling.MatrixNMS(
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                post_threshold=.05,
                nms_top_k=nms_topk,
                background_label=-1)
        else:
            nms = ppdet.modeling.MultiClassNMS(
                score_threshold=nms_score_threshold,
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                nms_threshold=nms_iou_threshold)
        post_process = ppdet.modeling.BBoxPostProcess(
            decode=ppdet.modeling.YOLOBox(
                num_classes=num_classes,
                conf_thresh=.005,
                downsample_ratio=32,
                clip_bbox=True,
                scale_x_y=scale_x_y),
            nms=nms)
        params = {
            'backbone': backbone,
            'neck': neck,
            'yolo_head': yolo_head,
            'post_process': post_process
        }
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.model_name = 'PPYOLOTiny'


class PPYOLOv2(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45):
        self.init_params = locals()
        if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
        self.backbone_name = backbone
        if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
            norm_type = 'sync_bn'
        else:
            norm_type = 'bn'
        if backbone == 'ResNet50_vd_dcn':
            backbone = self._get_backbone(
                'ResNet',
                variant='d',
                norm_type=norm_type,
                return_idx=[1, 2, 3],
                dcn_v2_stages=[3],
                freeze_at=-1,
                freeze_norm=False,
                norm_decay=0.)
            downsample_ratios = [32, 16, 8]
        elif backbone == 'ResNet101_vd_dcn':
            backbone = self._get_backbone(
                'ResNet',
                depth=101,
                variant='d',
                norm_type=norm_type,
                return_idx=[1, 2, 3],
                dcn_v2_stages=[3],
                freeze_at=-1,
                freeze_norm=False,
                norm_decay=0.)
            downsample_ratios = [32, 16, 8]
        neck = ppdet.modeling.PPYOLOPAN(
            norm_type=norm_type,
            in_channels=[i.channels for i in backbone.out_shape],
            drop_block=use_drop_block,
            block_size=3,
            keep_prob=.9,
            spp=use_spp)
        loss = ppdet.modeling.YOLOv3Loss(
            num_classes=num_classes,
            ignore_thresh=ignore_threshold,
            downsample=downsample_ratios,
            label_smooth=label_smooth,
            scale_x_y=scale_x_y,
            iou_loss=ppdet.modeling.IouLoss(
                loss_weight=2.5, loss_square=True) if use_iou_loss else None,
            iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
            if use_iou_aware else None)
        yolo_head = ppdet.modeling.YOLOv3Head(
            in_channels=[i.channels for i in neck.out_shape],
            anchors=anchors,
            anchor_masks=anchor_masks,
            num_classes=num_classes,
            loss=loss,
            iou_aware=use_iou_aware,
            iou_aware_factor=.5)
        if use_matrix_nms:
            nms = ppdet.modeling.MatrixNMS(
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                post_threshold=.01,
                nms_top_k=nms_topk,
                background_label=-1)
        else:
            nms = ppdet.modeling.MultiClassNMS(
                score_threshold=nms_score_threshold,
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                nms_threshold=nms_iou_threshold)
        post_process = ppdet.modeling.BBoxPostProcess(
            decode=ppdet.modeling.YOLOBox(
                num_classes=num_classes,
                conf_thresh=.01,
                downsample_ratio=32,
                clip_bbox=True,
                scale_x_y=scale_x_y),
            nms=nms)
        params = {
            'backbone': backbone,
            'neck': neck,
            'yolo_head': yolo_head,
            'post_process': post_process
        }
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.model_name = 'PPYOLOv2'

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [None, 3] + image_shape
            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
                raise Exception(
                    "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                    format(image_shape[-2:]))
            self._fix_transforms_shape(image_shape[-2:])
        else:
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '[None, 3, 608, 608]. Please check that the image shape after '
                'transforms is [3, 608, 608]; if not, fixed_input_shape should '
                'be specified manually.'.format(self.__class__.__name__))
            image_shape = [None, 3, 608, 608]
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image', dtype='float32'),
            "im_shape": InputSpec(
                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
            "scale_factor": InputSpec(
                shape=[image_shape[0], 2],
                name='scale_factor',
                dtype='float32')
        }]
        return input_spec
  1315. class MaskRCNN(BaseDetector):
  1316. def __init__(self,
  1317. num_classes=80,
  1318. backbone='ResNet50_vd',
  1319. with_fpn=True,
  1320. with_dcn=False,
  1321. aspect_ratios=[0.5, 1.0, 2.0],
  1322. anchor_sizes=[[32], [64], [128], [256], [512]],
  1323. keep_top_k=100,
  1324. nms_threshold=0.5,
  1325. score_threshold=0.05,
  1326. fpn_num_channels=256,
  1327. rpn_batch_size_per_im=256,
  1328. rpn_fg_fraction=0.5,
  1329. test_pre_nms_top_n=None,
  1330. test_post_nms_top_n=1000):
        self.init_params = locals()
        if backbone not in [
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
                'ResNet101_vd'
        ]:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', "
                "'ResNet101_vd')".format(backbone))
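
        # FPN variants get an '_fpn' suffix in the recorded backbone name;
        # DCN v2, when enabled, is applied to ResNet stages 1-3.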
        self.backbone_name = backbone + '_fpn' if with_fpn else backbone
        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
        if backbone == 'ResNet50':
            if with_fpn:
                backbone = self._get_backbone(
                    'ResNet',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if with_dcn:
                    logging.warning(
                        "Backbone {} should be used with DCN disabled; "
                        "'with_dcn' is forcibly set to False.".format(
                            backbone))
                backbone = self._get_backbone(
                    'ResNet',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[2],
                    num_stages=3)
        elif 'ResNet50_vd' in backbone:
            if not with_fpn:
                logging.warning(
                    "Backbone {} should be used with FPN enabled; "
                    "'with_fpn' is forcibly set to True.".format(backbone))
                with_fpn = True
            backbone = self._get_backbone(
                'ResNet',
                variant='d',
                norm_type='bn',
                freeze_at=0,
                return_idx=[0, 1, 2, 3],
                num_stages=4,
                lr_mult_list=[0.05, 0.05, 0.1, 0.15]
                if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
                dcn_v2_stages=dcn_v2_stages)
        else:
            if not with_fpn:
                logging.warning(
                    "Backbone {} should be used with FPN enabled; "
                    "'with_fpn' is forcibly set to True.".format(backbone))
                with_fpn = True
            backbone = self._get_backbone(
                'ResNet',
                variant='d' if '_vd' in backbone else 'b',
                depth=101,
                norm_type='bn',
                freeze_at=0,
                return_idx=[0, 1, 2, 3],
                num_stages=4,
                dcn_v2_stages=dcn_v2_stages)
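
        # The RPN consumes the first feature level; its channel count comes
        # from the FPN neck when present, otherwise from the raw backbone.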
        rpn_in_channel = backbone.out_shape[0].channels
        if with_fpn:
            neck = ppdet.modeling.FPN(
                in_channels=[i.channels for i in backbone.out_shape],
                out_channel=fpn_num_channels,
                spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
            rpn_in_channel = neck.out_shape[0].channels
            anchor_generator_cfg = {
                'aspect_ratios': aspect_ratios,
                'anchor_sizes': anchor_sizes,
                'strides': [4, 8, 16, 32, 64]
            }
            train_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 2000,
                'post_nms_top_n': 1000,
                'topk_after_collect': True
            }
            test_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 1000
                if test_pre_nms_top_n is None else test_pre_nms_top_n,
                'post_nms_top_n': test_post_nms_top_n
            }
            bb_head = ppdet.modeling.TwoFCHead(
                in_channel=neck.out_shape[0].channels, out_channel=1024)
            bb_roi_extractor_cfg = {
                'resolution': 7,
                'spatial_scale': [1. / i.stride for i in neck.out_shape],
                'sampling_ratio': 0,
                'aligned': True
            }
            with_pool = False
            m_head = ppdet.modeling.MaskFeat(
                in_channel=neck.out_shape[0].channels,
                out_channel=256,
                num_convs=4)
            m_roi_extractor_cfg = {
                'resolution': 14,
                'spatial_scale': [1. / i.stride for i in neck.out_shape],
                'sampling_ratio': 0,
                'aligned': True
            }
            mask_assigner = MaskAssigner(
                num_classes=num_classes, mask_resolution=28)
            share_bbox_feat = False
        else:
            neck = None
            anchor_generator_cfg = {
                'aspect_ratios': aspect_ratios,
                'anchor_sizes': anchor_sizes,
                'strides': [16]
            }
            train_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 12000,
                'post_nms_top_n': 2000,
                'topk_after_collect': False
            }
            test_proposal_cfg = {
                'min_size': 0.0,
                'nms_thresh': .7,
                'pre_nms_top_n': 6000
                if test_pre_nms_top_n is None else test_pre_nms_top_n,
                'post_nms_top_n': test_post_nms_top_n
            }
            bb_head = ppdet.modeling.Res5Head()
            bb_roi_extractor_cfg = {
                'resolution': 14,
                'spatial_scale': [1. / i.stride for i in backbone.out_shape],
                'sampling_ratio': 0,
                'aligned': True
            }
            with_pool = True
            m_head = ppdet.modeling.MaskFeat(
                in_channel=bb_head.out_shape[0].channels,
                out_channel=256,
                num_convs=0)
            m_roi_extractor_cfg = {
                'resolution': 14,
                'spatial_scale': [1. / i.stride for i in backbone.out_shape],
                'sampling_ratio': 0,
                'aligned': True
            }
            mask_assigner = MaskAssigner(
                num_classes=num_classes, mask_resolution=14)
            share_bbox_feat = True
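
        # Standard Faster R-CNN anchor assignment: positives above 0.7 IoU,
        # negatives below 0.3, with random sampling of the RPN minibatch.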
        rpn_target_assign_cfg = {
            'batch_size_per_im': rpn_batch_size_per_im,
            'fg_fraction': rpn_fg_fraction,
            'negative_overlap': .3,
            'positive_overlap': .7,
            'use_random': True
        }
        rpn_head = ppdet.modeling.RPNHead(
            anchor_generator=anchor_generator_cfg,
            rpn_target_assign=rpn_target_assign_cfg,
            train_proposal=train_proposal_cfg,
            test_proposal=test_proposal_cfg,
            in_channel=rpn_in_channel)
        bbox_assigner = BBoxAssigner(num_classes=num_classes)
        bbox_head = ppdet.modeling.BBoxHead(
            head=bb_head,
            in_channel=bb_head.out_shape[0].channels,
            roi_extractor=bb_roi_extractor_cfg,
            with_pool=with_pool,
            bbox_assigner=bbox_assigner,
            num_classes=num_classes)
        mask_head = ppdet.modeling.MaskHead(
            head=m_head,
            roi_extractor=m_roi_extractor_cfg,
            mask_assigner=mask_assigner,
            share_bbox_feat=share_bbox_feat,
            num_classes=num_classes)
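
        # Final predictions: RCNNBox decodes per-class box deltas, multi-class
        # NMS prunes them, and mask logits are binarized at a 0.5 threshold.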
        bbox_post_process = ppdet.modeling.BBoxPostProcess(
            num_classes=num_classes,
            decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
            nms=ppdet.modeling.MultiClassNMS(
                score_threshold=score_threshold,
                keep_top_k=keep_top_k,
                nms_threshold=nms_threshold))
        mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
        params = {
            'backbone': backbone,
            'neck': neck,
            'rpn_head': rpn_head,
            'bbox_head': bbox_head,
            'mask_head': mask_head,
            'bbox_post_process': bbox_post_process,
            'mask_post_process': mask_post_process
        }
        self.with_fpn = with_fpn
        super(MaskRCNN, self).__init__(
            model_name='MaskRCNN', num_classes=num_classes, **params)
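
    # Batch transforms pad all images in a batch to a common size; with an
    # FPN the padded size is additionally aligned to a stride of 32.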
    def _compose_batch_transform(self, transforms, mode='train'):
        if mode == 'train':
            default_batch_transforms = [
                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
            ]
            collate_batch = False
        else:
            default_batch_transforms = [
                _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
            ]
            collate_batch = True
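
        # User-supplied batch-level resize ops are only legal in training and
        # are moved ahead of the default padding transform.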
        custom_batch_transforms = []
        for op in transforms.transforms:
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. "
                        "Please check the {} transforms.".format(
                            op.__class__.__name__, mode, mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms
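
    # When a fixed export shape is set, swap any ResizeByShort for a fixed
    # keep-ratio Resize (inserting one before Normalize if none exists) and
    # append a zero-value padding op.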
    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                has_resize_op = False
                resize_op_idx = -1
                normalize_op_idx = len(self.test_transforms.transforms)
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'ResizeByShort':
                        has_resize_op = True
                        resize_op_idx = idx
                    if name == 'Normalize':
                        normalize_op_idx = idx
                if not has_resize_op:
                    self.test_transforms.transforms.insert(
                        normalize_op_idx,
                        Resize(
                            target_size=image_shape,
                            keep_ratio=True,
                            interp='CUBIC'))
                else:
                    self.test_transforms.transforms[resize_op_idx] = Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC')
                self.test_transforms.transforms.append(
                    Padding(im_padding_value=[0., 0., 0.]))