detector.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import paddlex.ppdet as ppdet
  24. from paddlex.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  # Detector classes re-exported by ``from <this module> import *``.
  __all__ = [
      "YOLOv3",
      "FasterRCNN",
      "PPYOLO",
      "PPYOLOTiny",
      "PPYOLOv2",
      "MaskRCNN",
  ]
  class BaseDetector(BaseModel):
      """Shared training, evaluation and inference logic for PaddleX detectors.

      Subclasses are expected to:
        * assign ``self.init_params`` before calling ``BaseDetector.__init__``
          (it is updated in place here);
        * set ``self.backbone_name`` (used by ``train`` to look up the default
          pretrained weights);
        * provide ``_compose_batch_transform`` (not defined in this class) and
          override ``_fix_transforms_shape``.
      """

      def __init__(self, model_name, num_classes=80, **params):
          # locals() is taken first so only the constructor arguments (and
          # self) are merged into the dict the subclass already stored.
          self.init_params.update(locals())
          super(BaseDetector, self).__init__('detector')
          if not hasattr(ppdet.modeling, model_name):
              raise Exception("ERROR: There's no model named {}.".format(
                  model_name))
          self.model_name = model_name
          self.num_classes = num_classes
          self.labels = None
          self.net = self.build_net(**params)

      def build_net(self, **params):
          """Instantiate ``ppdet.modeling.<self.model_name>(**params)``.

          Construction runs inside ``paddle.utils.unique_name.guard()``, so the
          network's generated names come from a fresh naming scope.
          """
          with paddle.utils.unique_name.guard():
              net = ppdet.modeling.__dict__[self.model_name](**params)
          return net

      def _fix_transforms_shape(self, image_shape):
          """Adapt the test transforms to a fixed [H, W]; subclasses must override."""
          raise NotImplementedError("_fix_transforms_shape: not implemented!")

      def _define_input_spec(self, image_shape):
          """Return the export InputSpec list for an [N, C, H, W] ``image_shape``.

          ``im_shape`` and ``scale_factor`` are float32 specs of shape [N, 2].
          """
          input_spec = [{
              "image": InputSpec(
                  shape=image_shape, name='image', dtype='float32'),
              "im_shape": InputSpec(
                  shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
              "scale_factor": InputSpec(
                  shape=[image_shape[0], 2],
                  name='scale_factor',
                  dtype='float32')
          }]
          return input_spec

      def _check_image_shape(self, image_shape):
          """Normalize ``image_shape`` to a list [N, C, H, W] and validate H and W.

          Args:
              image_shape(list or tuple): [H, W] (expanded to [1, 3, H, W]) or a
                  full 4-element shape.
          Returns:
              list: The normalized shape.
          Raises:
              Exception: If H or W is not a multiple of 32.
          """
          # list() lets tuples through; previously ``[1, 3] + tuple`` raised
          # a TypeError.
          image_shape = list(image_shape)
          if len(image_shape) == 2:
              image_shape = [1, 3] + image_shape
          if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
              raise Exception(
                  "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                  format(image_shape[-2:]))
          return image_shape

      def _get_test_inputs(self, image_shape):
          """Build the export input specs, optionally fixing the input size.

          If ``image_shape`` is given, it is validated/expanded and the test
          transforms are fixed to its [H, W]. Otherwise the dynamic shape
          [None, 3, -1, -1] is used. The chosen shape is stored in
          ``self.fixed_input_shape``.
          """
          if image_shape is not None:
              image_shape = self._check_image_shape(image_shape)
              self._fix_transforms_shape(image_shape[-2:])
          else:
              image_shape = [None, 3, -1, -1]
          self.fixed_input_shape = image_shape
          return self._define_input_spec(image_shape)

      def _get_backbone(self, backbone_name, **params):
          """Instantiate the backbone class ``ppdet.modeling.<backbone_name>``."""
          backbone = getattr(ppdet.modeling, backbone_name)(**params)
          return backbone

      def run(self, net, inputs, mode):
          """Run ``net`` on ``inputs``.

          In 'train' and 'eval' mode the raw network output is returned
          unchanged. In any other mode the input ``im_shape`` and
          ``scale_factor`` are attached to the outputs, and every entry is
          converted with ``.numpy()``.
          """
          net_out = net(inputs)
          if mode in ['train', 'eval']:
              return net_out
          for key in ['im_shape', 'scale_factor']:
              net_out[key] = inputs[key]
          return {key: net_out[key].numpy() for key in net_out}

      def default_optimizer(self, parameters, learning_rate, warmup_steps,
                            warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
                            num_steps_each_epoch):
          """Build the default Momentum optimizer with a piecewise-decay schedule.

          The learning rate is multiplied by ``lr_decay_gamma`` at each epoch in
          ``lr_decay_epochs``. If ``warmup_steps > 0``, the schedule is preceded
          by a linear warm-up from ``warmup_start_lr`` to ``learning_rate``.
          Momentum is 0.9 and L2 weight decay is 1e-4.
          """
          boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
          values = [(lr_decay_gamma**i) * learning_rate
                    for i in range(len(lr_decay_epochs) + 1)]
          scheduler = paddle.optimizer.lr.PiecewiseDecay(
              boundaries=boundaries, values=values)
          if warmup_steps > 0:
              # Guard against an empty lr_decay_epochs (previously an IndexError).
              if lr_decay_epochs and \
                      warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
                  logging.error(
                      "In function train(), parameters should satisfy: "
                      "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset//train_batch_size",
                      exit=False)
                  logging.error(
                      "See this doc for more information: "
                      "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
                      exit=False)
              scheduler = paddle.optimizer.lr.LinearWarmup(
                  learning_rate=scheduler,
                  warmup_steps=warmup_steps,
                  start_lr=warmup_start_lr,
                  end_lr=learning_rate)
          optimizer = paddle.optimizer.Momentum(
              scheduler,
              momentum=.9,
              weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
              parameters=parameters)
          return optimizer

      @staticmethod
      def _dataset_format(dataset):
          """Return 'voc' or 'coco' based on the dataset's class name, else None."""
          return {
              'VOCDetection': 'voc',
              'CocoDetection': 'coco'
          }.get(dataset.__class__.__name__)

      def _data_fields(self, data_format):
          """Return the set of sample fields to load for a 'voc' or 'coco' dataset."""
          fields = {'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class'}
          if data_format == 'voc':
              fields.add('difficult')
          else:
              # Instance segmentation additionally needs the polygons.
              if self.__class__.__name__ == 'MaskRCNN':
                  fields.add('gt_poly')
              fields.add('is_crowd')
          return fields

      @staticmethod
      def _check_metric(metric):
          """Validate a user-supplied metric name and return it lower-cased."""
          assert metric.lower() in ['coco', 'voc'], \
              "Evaluation metric {} is not supported, please choose from 'COCO' and 'VOC'".format(metric)
          return metric.lower()

      def _resolve_pretrain_weights(self, pretrain_weights):
          """Validate ``pretrain_weights``, falling back to a registered default.

          A value that is not an existing path and is not one of the names
          registered for this model/backbone is replaced, with a warning, by the
          first registered name. An existing path must be a '.pdparams' file.
          """
          if pretrain_weights is None:
              return None
          if not osp.exists(pretrain_weights):
              candidates = det_pretrain_weights_dict['_'.join(
                  [self.model_name, self.backbone_name])]
              if pretrain_weights not in candidates:
                  logging.warning(
                      "Path of pretrain_weights('{}') does not exist!".format(
                          pretrain_weights))
                  pretrain_weights = candidates[0]
                  logging.warning("Pretrain_weights is forcibly set to '{}'. "
                                  "If you don't want to use pretrain weights, "
                                  "set pretrain_weights to be None.".format(
                                      pretrain_weights))
          elif osp.splitext(pretrain_weights)[-1] != '.pdparams':
              logging.error(
                  "Invalid pretrain weights. Please specify a '.pdparams' file.",
                  exit=True)
          return pretrain_weights

      def train(self,
                num_epochs,
                train_dataset,
                train_batch_size=64,
                eval_dataset=None,
                optimizer=None,
                save_interval_epochs=1,
                log_interval_steps=10,
                save_dir='output',
                pretrain_weights='IMAGENET',
                learning_rate=.001,
                warmup_steps=0,
                warmup_start_lr=0.0,
                lr_decay_epochs=(216, 243),
                lr_decay_gamma=0.1,
                metric=None,
                use_ema=False,
                early_stop=False,
                early_stop_patience=5,
                use_vdl=True,
                resume_checkpoint=None):
          """
          Train the model.
          Args:
              num_epochs(int): The number of epochs.
              train_dataset(paddlex.dataset): Training dataset.
              train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
              eval_dataset(paddlex.dataset, optional):
                  Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
              optimizer(paddle.optimizer.Optimizer or None, optional):
                  Optimizer used for training. If None, a default optimizer is used. Defaults to None.
              save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
              log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
              save_dir(str, optional): Directory to save the model. Defaults to 'output'.
              pretrain_weights(str or None, optional):
                  None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
              learning_rate(float, optional): Learning rate for training. Defaults to .001.
              warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
              warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
              lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
              lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
              metric({'VOC', 'COCO', None}, optional):
                  Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
              use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
              early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
              early_stop_patience(int, optional): Early stop patience. Defaults to 5.
              use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
              resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                  If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                  `pretrain_weights` can be set simultaneously (so pass pretrain_weights=None when resuming).
                  Defaults to None.
          """
          if pretrain_weights is not None and resume_checkpoint is not None:
              logging.error(
                  "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                  exit=True)
          train_format = self._dataset_format(train_dataset)
          if train_format is not None:
              train_dataset.data_fields = self._data_fields(train_format)
          if metric is None:
              # Infer from the evaluation set; leave self.metric untouched if
              # it cannot be inferred (e.g. no eval_dataset).
              eval_format = self._dataset_format(eval_dataset)
              if eval_format is not None:
                  self.metric = eval_format
          else:
              self.metric = self._check_metric(metric)
          self.labels = train_dataset.labels
          # Read by the subclass's _compose_batch_transform (box padding).
          self.num_max_boxes = train_dataset.num_max_boxes
          train_dataset.batch_transforms = self._compose_batch_transform(
              train_dataset.transforms, mode='train')
          # build optimizer if not defined
          if optimizer is None:
              num_steps_each_epoch = len(train_dataset) // train_batch_size
              self.optimizer = self.default_optimizer(
                  parameters=self.net.parameters(),
                  learning_rate=learning_rate,
                  warmup_steps=warmup_steps,
                  warmup_start_lr=warmup_start_lr,
                  lr_decay_epochs=lr_decay_epochs,
                  lr_decay_gamma=lr_decay_gamma,
                  num_steps_each_epoch=num_steps_each_epoch)
          else:
              self.optimizer = optimizer
          # initiate weights
          pretrain_weights = self._resolve_pretrain_weights(pretrain_weights)
          pretrained_dir = osp.join(save_dir, 'pretrain')
          self.net_initialize(
              pretrain_weights=pretrain_weights,
              save_dir=pretrained_dir,
              resume_checkpoint=resume_checkpoint)
          if use_ema:
              ema = ExponentialMovingAverage(
                  decay=.9998, model=self.net, use_thres_step=True)
          else:
              ema = None
          # start train loop
          self.train_loop(
              num_epochs=num_epochs,
              train_dataset=train_dataset,
              train_batch_size=train_batch_size,
              eval_dataset=eval_dataset,
              save_interval_epochs=save_interval_epochs,
              log_interval_steps=log_interval_steps,
              save_dir=save_dir,
              ema=ema,
              early_stop=early_stop,
              early_stop_patience=early_stop_patience,
              use_vdl=use_vdl)

      def quant_aware_train(self,
                            num_epochs,
                            train_dataset,
                            train_batch_size=64,
                            eval_dataset=None,
                            optimizer=None,
                            save_interval_epochs=1,
                            log_interval_steps=10,
                            save_dir='output',
                            learning_rate=.00001,
                            warmup_steps=0,
                            warmup_start_lr=0.0,
                            lr_decay_epochs=(216, 243),
                            lr_decay_gamma=0.1,
                            metric=None,
                            use_ema=False,
                            early_stop=False,
                            early_stop_patience=5,
                            use_vdl=True,
                            resume_checkpoint=None,
                            quant_config=None):
          """
          Quantization-aware training.
          Args:
              num_epochs(int): The number of epochs.
              train_dataset(paddlex.dataset): Training dataset.
              train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
              eval_dataset(paddlex.dataset, optional):
                  Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
              optimizer(paddle.optimizer.Optimizer or None, optional):
                  Optimizer used for training. If None, a default optimizer is used. Defaults to None.
              save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
              log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
              save_dir(str, optional): Directory to save the model. Defaults to 'output'.
              learning_rate(float, optional): Learning rate for training. Defaults to .00001.
              warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
              warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
              lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
              lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
              metric({'VOC', 'COCO', None}, optional):
                  Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
              use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
              early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
              early_stop_patience(int, optional): Early stop patience. Defaults to 5.
              use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
              quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
                  configuration will be used. Defaults to None.
              resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
                  from. If None, no training checkpoint will be resumed. Defaults to None.
          """
          self._prepare_qat(quant_config)
          # The current (already trained) weights are fine-tuned, so no
          # pretrained weights are loaded.
          self.train(
              num_epochs=num_epochs,
              train_dataset=train_dataset,
              train_batch_size=train_batch_size,
              eval_dataset=eval_dataset,
              optimizer=optimizer,
              save_interval_epochs=save_interval_epochs,
              log_interval_steps=log_interval_steps,
              save_dir=save_dir,
              pretrain_weights=None,
              learning_rate=learning_rate,
              warmup_steps=warmup_steps,
              warmup_start_lr=warmup_start_lr,
              lr_decay_epochs=lr_decay_epochs,
              lr_decay_gamma=lr_decay_gamma,
              metric=metric,
              use_ema=use_ema,
              early_stop=early_stop,
              early_stop_patience=early_stop_patience,
              use_vdl=use_vdl,
              resume_checkpoint=resume_checkpoint)

      def evaluate(self,
                   eval_dataset,
                   batch_size=1,
                   metric=None,
                   return_details=False):
          """
          Evaluate the model.
          Args:
              eval_dataset(paddlex.dataset): Evaluation dataset.
              batch_size(int, optional): Total batch size among all cards used for evaluation. Values
                  greater than 1 are forcibly reset to 1. Defaults to 1.
              metric({'VOC', 'COCO', None}, optional):
                  Evaluation metric. If None, reuse the metric set by training or determine it according to
                  the dataset format. Defaults to None.
              return_details(bool, optional): Whether to return evaluation details. Defaults to False.
          Returns:
              collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`},
              or a (scores, eval_details) tuple when return_details is True. In multi-card runs only rank 0
              evaluates; the other ranks return None.
          """
          if metric is None:
              if not hasattr(self, 'metric'):
                  eval_format = self._dataset_format(eval_dataset)
                  if eval_format is not None:
                      self.metric = eval_format
          else:
              self.metric = self._check_metric(metric)
          if self.metric in ('voc', 'coco'):
              eval_dataset.data_fields = self._data_fields(self.metric)
          eval_dataset.batch_transforms = self._compose_batch_transform(
              eval_dataset.transforms, mode='eval')
          arrange_transforms(
              model_type=self.model_type,
              transforms=eval_dataset.transforms,
              mode='eval')
          self.net.eval()
          nranks = paddle.distributed.get_world_size()
          local_rank = paddle.distributed.get_rank()
          if nranks > 1:
              # Initialize parallel environment if not done.
              if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
              ):
                  paddle.distributed.init_parallel_env()
          if batch_size > 1:
              logging.warning(
                  "Detector only supports single card evaluation with batch_size=1 "
                  "during evaluation, so batch_size is forcibly set to 1.")
              batch_size = 1
          if nranks > 1 and local_rank != 0:
              # Evaluation runs on rank 0 only.
              return None
          self.eval_data_loader = self.build_data_loader(
              eval_dataset, batch_size=batch_size, mode='eval')
          is_bbox_normalized = False
          if eval_dataset.batch_transforms is not None:
              is_bbox_normalized = any(
                  isinstance(t, _NormalizeBox)
                  for t in eval_dataset.batch_transforms.batch_transforms)
          if self.metric == 'voc':
              eval_metric = VOCMetric(
                  labels=eval_dataset.labels,
                  coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                  is_bbox_normalized=is_bbox_normalized,
                  classwise=False)
          else:
              eval_metric = COCOMetric(
                  coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                  classwise=False)
          scores = collections.OrderedDict()
          # batch_size is always 1 here, so the step count equals the sample count.
          logging.info(
              "Start to evaluate(total_samples={}, total_steps={})...".format(
                  eval_dataset.num_samples, eval_dataset.num_samples))
          with paddle.no_grad():
              for data in self.eval_data_loader:
                  outputs = self.run(self.net, data, 'eval')
                  eval_metric.update(data, outputs)
              eval_metric.accumulate()
              self.eval_details = eval_metric.details
              scores.update(eval_metric.get())
              eval_metric.reset()
          if return_details:
              return scores, self.eval_details
          return scores

      def predict(self, img_file, transforms=None):
          """
          Do inference.
          Args:
              img_file(List[np.ndarray or str], str or np.ndarray):
                  Image path or decoded image data in a BGR format, which also could constitute a list,
                  meaning all images to be predicted as a mini-batch.
              transforms(paddlex.transforms.Compose or None, optional):
                  Transforms for inputs. If None, `self.test_transforms` is used. Defaults to None.
          Returns:
              If img_file is a string or np.ndarray, the result is a list of dicts (one per detected object)
              with key-value pairs:
                  {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
              If img_file is a list, the result is a list with one such list per image.
              Fields:
                  category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
                  category(str): category name
                  bbox(list): bounding box in [x, y, w, h] format
                  score(float): confidence
                  mask(dict): Only for instance segmentation task. Mask of the object in RLE format
          Raises:
              Exception: If transforms is None and the model has no `test_transforms`.
          """
          if transforms is None and not hasattr(self, 'test_transforms'):
              raise Exception("transforms need to be defined, now is None.")
          if transforms is None:
              transforms = self.test_transforms
          single_input = isinstance(img_file, (str, np.ndarray))
          images = [img_file] if single_input else img_file
          batch_samples = self._preprocess(images, transforms)
          self.net.eval()
          outputs = self.run(self.net, batch_samples, 'test')
          prediction = self._postprocess(outputs)
          if single_input:
              prediction = prediction[0]
          return prediction

      def _preprocess(self, images, transforms):
          """Apply per-sample and batch transforms; return a dict of paddle tensors."""
          arrange_transforms(
              model_type=self.model_type, transforms=transforms, mode='test')
          batch_samples = [transforms({'image': im}) for im in images]
          batch_transforms = self._compose_batch_transform(transforms, 'test')
          batch_samples = batch_transforms(batch_samples)
          for key, value in batch_samples.items():
              batch_samples[key] = paddle.to_tensor(value)
          return batch_samples

      def _postprocess(self, batch_pred):
          """Convert numpy network outputs into per-image lists of detection dicts.

          Args:
              batch_pred(dict): Output of ``run(..., 'test')``. Each ``bbox`` row is
                  [class_id, score, xmin, ymin, xmax, ymax]. ``bbox_num`` holds how
                  many consecutive rows belong to each image. The optional ``mask``
                  holds one mask per ``bbox`` row.
          Returns:
              list[list[dict]]: One list per image of dicts with ``category_id``,
              ``category``, ``bbox`` ([x, y, w, h]), ``score`` and, when masks are
              present, ``mask`` (RLE).
          """
          bboxes = batch_pred['bbox']
          bbox_nums = batch_pred['bbox_num']
          masks = batch_pred['mask'] if 'mask' in batch_pred else None
          if masks is not None:
              import pycocotools.mask as mask_util
          results = []
          start = 0
          for num in bbox_nums:
              end = start + int(num)
              image_res = []
              # Rows are grouped per image *before* filtering. Dropping rows first
              # and then slicing by bbox_num would shift detections into the
              # neighbouring image.
              for idx in range(start, end):
                  num_id, score, xmin, ymin, xmax, ymax = bboxes[idx].tolist()
                  category_id = int(num_id)
                  if category_id < 0:
                      # Negative class ids are not real detections (presumably
                      # padding rows -- verify against the NMS op).
                      continue
                  det = {
                      'category_id': category_id,
                      'category': self.labels[category_id],
                      'bbox': [xmin, ymin, xmax - xmin, ymax - ymin],
                      'score': score
                  }
                  if masks is not None:
                      mask = masks[idx].astype(np.uint8)
                      rle = mask_util.encode(
                          np.array(
                              mask[:, :, None], order="F", dtype="uint8"))[0]
                      if six.PY3 and 'counts' in rle:
                          # counts is bytes on Python 3; decode it to str.
                          rle['counts'] = rle['counts'].decode("utf8")
                      det['mask'] = rle
                  image_res.append(det)
              results.append(image_res)
              start = end
          return results
  546. class YOLOv3(BaseDetector):
  def __init__(self,
               num_classes=80,
               backbone='MobileNetV1',
               anchors=None,
               anchor_masks=None,
               ignore_threshold=0.7,
               nms_score_threshold=0.01,
               nms_topk=1000,
               nms_keep_topk=100,
               nms_iou_threshold=0.45,
               label_smooth=False):
      """Build a YOLOv3 detector.

      Args:
          num_classes(int, optional): Number of target classes. Defaults to 80.
          backbone(str, optional): One of 'MobileNetV1', 'MobileNetV1_ssld',
              'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn',
              'ResNet34'. Defaults to 'MobileNetV1'.
          anchors(list[list[int]] or None, optional): Anchor sizes passed to
              YOLOv3Head. None means [[10, 13], [16, 30], [33, 23], [30, 61],
              [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]].
              Defaults to None.
          anchor_masks(list[list[int]] or None, optional): Groups of indices into
              ``anchors`` passed to YOLOv3Head as ``anchor_masks``. None means
              [[6, 7, 8], [3, 4, 5], [0, 1, 2]]. Defaults to None.
          ignore_threshold(float, optional): ``ignore_thresh`` of YOLOv3Loss.
              Defaults to 0.7.
          nms_score_threshold(float, optional): Score threshold of MultiClassNMS.
              Defaults to 0.01.
          nms_topk(int, optional): ``nms_top_k`` of MultiClassNMS. Defaults to 1000.
          nms_keep_topk(int, optional): ``keep_top_k`` of MultiClassNMS.
              Defaults to 100.
          nms_iou_threshold(float, optional): ``nms_threshold`` of MultiClassNMS.
              Defaults to 0.45.
          label_smooth(bool, optional): ``label_smooth`` flag of YOLOv3Loss.
              Defaults to False.
      Raises:
          ValueError: If ``backbone`` is not supported.
      """
      # Build fresh lists on every call. List defaults in the signature would be
      # one object shared by every instance (and exposed via self.anchors /
      # self.anchor_masks), so mutating one model's anchors would change the
      # defaults of all later models.
      if anchors is None:
          anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                     [59, 119], [116, 90], [156, 198], [373, 326]]
      if anchor_masks is None:
          anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
      # Snapshot the resolved constructor arguments. The copy keeps later
      # locals out of it even if the frame's locals get refreshed.
      self.init_params = dict(locals())
      if backbone not in ('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
                          'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn',
                          'ResNet34'):
          raise ValueError(
              "backbone: {} is not supported. Please choose one of "
              "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
              format(backbone))
      # Synchronized BN for multi-GPU training, except while exporting.
      if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
              'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
          norm_type = 'sync_bn'
      else:
          norm_type = 'bn'
      self.backbone_name = backbone
      if 'MobileNetV1' in backbone:
          # MobileNetV1 (and, through norm_type, its neck) always uses plain BN.
          norm_type = 'bn'
          backbone_net = self._get_backbone('MobileNet', norm_type=norm_type)
      elif 'MobileNetV3' in backbone:
          backbone_net = self._get_backbone(
              'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
      elif backbone == 'ResNet50_vd_dcn':
          backbone_net = self._get_backbone(
              'ResNet',
              norm_type=norm_type,
              variant='d',
              return_idx=[1, 2, 3],
              dcn_v2_stages=[3],
              freeze_at=-1,
              freeze_norm=False)
      elif backbone == 'ResNet34':
          backbone_net = self._get_backbone(
              'ResNet',
              depth=34,
              norm_type=norm_type,
              return_idx=[1, 2, 3],
              freeze_at=-1,
              freeze_norm=False,
              norm_decay=0.)
      else:
          backbone_net = self._get_backbone('DarkNet', norm_type=norm_type)
      neck = ppdet.modeling.YOLOv3FPN(
          norm_type=norm_type,
          in_channels=[i.channels for i in backbone_net.out_shape])
      loss = ppdet.modeling.YOLOv3Loss(
          num_classes=num_classes,
          ignore_thresh=ignore_threshold,
          label_smooth=label_smooth)
      yolo_head = ppdet.modeling.YOLOv3Head(
          in_channels=[i.channels for i in neck.out_shape],
          anchors=anchors,
          anchor_masks=anchor_masks,
          num_classes=num_classes,
          loss=loss)
      post_process = ppdet.modeling.BBoxPostProcess(
          decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
          nms=ppdet.modeling.MultiClassNMS(
              score_threshold=nms_score_threshold,
              nms_top_k=nms_topk,
              keep_top_k=nms_keep_topk,
              nms_threshold=nms_iou_threshold))
      params = {
          'backbone': backbone_net,
          'neck': neck,
          'yolo_head': yolo_head,
          'post_process': post_process
      }
      super(YOLOv3, self).__init__(
          model_name='YOLOv3', num_classes=num_classes, **params)
      self.anchors = anchors
      self.anchor_masks = anchor_masks
  630. def _compose_batch_transform(self, transforms, mode='train'):
  631. if mode == 'train':
  632. default_batch_transforms = [
  633. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  634. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  635. _Gt2YoloTarget(
  636. anchor_masks=self.anchor_masks,
  637. anchors=self.anchors,
  638. downsample_ratios=getattr(self, 'downsample_ratios',
  639. [32, 16, 8]),
  640. num_classes=self.num_classes)
  641. ]
  642. else:
  643. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  644. if mode == 'eval' and self.metric == 'voc':
  645. collate_batch = False
  646. else:
  647. collate_batch = True
  648. custom_batch_transforms = []
  649. for i, op in enumerate(transforms.transforms):
  650. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  651. if mode != 'train':
  652. raise Exception(
  653. "{} cannot be present in the {} transforms. ".format(
  654. op.__class__.__name__, mode) +
  655. "Please check the {} transforms.".format(mode))
  656. custom_batch_transforms.insert(0, copy.deepcopy(op))
  657. batch_transforms = BatchCompose(
  658. custom_batch_transforms + default_batch_transforms,
  659. collate_batch=collate_batch)
  660. return batch_transforms
  661. def _fix_transforms_shape(self, image_shape):
  662. if hasattr(self, 'test_transforms'):
  663. if self.test_transforms is not None:
  664. has_resize_op = False
  665. resize_op_idx = -1
  666. normalize_op_idx = len(self.test_transforms.transforms)
  667. for idx, op in enumerate(self.test_transforms.transforms):
  668. name = op.__class__.__name__
  669. if name == 'Resize':
  670. has_resize_op = True
  671. resize_op_idx = idx
  672. if name == 'Normalize':
  673. normalize_op_idx = idx
  674. if not has_resize_op:
  675. self.test_transforms.transforms.insert(
  676. normalize_op_idx,
  677. Resize(
  678. target_size=image_shape, interp='CUBIC'))
  679. else:
  680. self.test_transforms.transforms[
  681. resize_op_idx].target_size = image_shape
  682. class FasterRCNN(BaseDetector):
  683. def __init__(self,
  684. num_classes=80,
  685. backbone='ResNet50',
  686. with_fpn=True,
  687. with_dcn=False,
  688. aspect_ratios=[0.5, 1.0, 2.0],
  689. anchor_sizes=[[32], [64], [128], [256], [512]],
  690. keep_top_k=100,
  691. nms_threshold=0.5,
  692. score_threshold=0.05,
  693. fpn_num_channels=256,
  694. rpn_batch_size_per_im=256,
  695. rpn_fg_fraction=0.5,
  696. test_pre_nms_top_n=None,
  697. test_post_nms_top_n=1000):
  698. self.init_params = locals()
  699. if backbone not in [
  700. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  701. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  702. ]:
  703. raise ValueError(
  704. "backbone: {} is not supported. Please choose one of "
  705. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  706. "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
  707. self.backbone_name = backbone
  708. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  709. if backbone == 'HRNet_W18':
  710. if not with_fpn:
  711. logging.warning(
  712. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  713. format(backbone))
  714. with_fpn = True
  715. if with_dcn:
  716. logging.warning(
  717. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  718. format(backbone))
  719. backbone = self._get_backbone(
  720. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  721. elif backbone == 'ResNet50_vd_ssld':
  722. if not with_fpn:
  723. logging.warning(
  724. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  725. format(backbone))
  726. with_fpn = True
  727. backbone = self._get_backbone(
  728. 'ResNet',
  729. variant='d',
  730. norm_type='bn',
  731. freeze_at=0,
  732. return_idx=[0, 1, 2, 3],
  733. num_stages=4,
  734. lr_mult_list=[0.05, 0.05, 0.1, 0.15],
  735. dcn_v2_stages=dcn_v2_stages)
  736. elif 'ResNet50' in backbone:
  737. if with_fpn:
  738. backbone = self._get_backbone(
  739. 'ResNet',
  740. variant='d' if '_vd' in backbone else 'b',
  741. norm_type='bn',
  742. freeze_at=0,
  743. return_idx=[0, 1, 2, 3],
  744. num_stages=4,
  745. dcn_v2_stages=dcn_v2_stages)
  746. else:
  747. if with_dcn:
  748. logging.warning(
  749. "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  750. format(backbone))
  751. backbone = self._get_backbone(
  752. 'ResNet',
  753. variant='d' if '_vd' in backbone else 'b',
  754. norm_type='bn',
  755. freeze_at=0,
  756. return_idx=[2],
  757. num_stages=3)
  758. elif 'ResNet34' in backbone:
  759. if not with_fpn:
  760. logging.warning(
  761. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  762. format(backbone))
  763. with_fpn = True
  764. backbone = self._get_backbone(
  765. 'ResNet',
  766. depth=34,
  767. variant='d' if 'vd' in backbone else 'b',
  768. norm_type='bn',
  769. freeze_at=0,
  770. return_idx=[0, 1, 2, 3],
  771. num_stages=4,
  772. dcn_v2_stages=dcn_v2_stages)
  773. else:
  774. if not with_fpn:
  775. logging.warning(
  776. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  777. format(backbone))
  778. with_fpn = True
  779. backbone = self._get_backbone(
  780. 'ResNet',
  781. depth=101,
  782. variant='d' if 'vd' in backbone else 'b',
  783. norm_type='bn',
  784. freeze_at=0,
  785. return_idx=[0, 1, 2, 3],
  786. num_stages=4,
  787. dcn_v2_stages=dcn_v2_stages)
  788. rpn_in_channel = backbone.out_shape[0].channels
  789. if with_fpn:
  790. self.backbone_name = self.backbone_name + '_fpn'
  791. if 'HRNet' in self.backbone_name:
  792. neck = ppdet.modeling.HRFPN(
  793. in_channels=[i.channels for i in backbone.out_shape],
  794. out_channel=fpn_num_channels,
  795. spatial_scales=[
  796. 1.0 / i.stride for i in backbone.out_shape
  797. ],
  798. share_conv=False)
  799. else:
  800. neck = ppdet.modeling.FPN(
  801. in_channels=[i.channels for i in backbone.out_shape],
  802. out_channel=fpn_num_channels,
  803. spatial_scales=[
  804. 1.0 / i.stride for i in backbone.out_shape
  805. ])
  806. rpn_in_channel = neck.out_shape[0].channels
  807. anchor_generator_cfg = {
  808. 'aspect_ratios': aspect_ratios,
  809. 'anchor_sizes': anchor_sizes,
  810. 'strides': [4, 8, 16, 32, 64]
  811. }
  812. train_proposal_cfg = {
  813. 'min_size': 0.0,
  814. 'nms_thresh': .7,
  815. 'pre_nms_top_n': 2000,
  816. 'post_nms_top_n': 1000,
  817. 'topk_after_collect': True
  818. }
  819. test_proposal_cfg = {
  820. 'min_size': 0.0,
  821. 'nms_thresh': .7,
  822. 'pre_nms_top_n': 1000
  823. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  824. 'post_nms_top_n': test_post_nms_top_n
  825. }
  826. head = ppdet.modeling.TwoFCHead(
  827. in_channel=neck.out_shape[0].channels, out_channel=1024)
  828. roi_extractor_cfg = {
  829. 'resolution': 7,
  830. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  831. 'sampling_ratio': 0,
  832. 'aligned': True
  833. }
  834. with_pool = False
  835. else:
  836. neck = None
  837. anchor_generator_cfg = {
  838. 'aspect_ratios': aspect_ratios,
  839. 'anchor_sizes': anchor_sizes,
  840. 'strides': [16]
  841. }
  842. train_proposal_cfg = {
  843. 'min_size': 0.0,
  844. 'nms_thresh': .7,
  845. 'pre_nms_top_n': 12000,
  846. 'post_nms_top_n': 2000,
  847. 'topk_after_collect': False
  848. }
  849. test_proposal_cfg = {
  850. 'min_size': 0.0,
  851. 'nms_thresh': .7,
  852. 'pre_nms_top_n': 6000
  853. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  854. 'post_nms_top_n': test_post_nms_top_n
  855. }
  856. head = ppdet.modeling.Res5Head()
  857. roi_extractor_cfg = {
  858. 'resolution': 14,
  859. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  860. 'sampling_ratio': 0,
  861. 'aligned': True
  862. }
  863. with_pool = True
  864. rpn_target_assign_cfg = {
  865. 'batch_size_per_im': rpn_batch_size_per_im,
  866. 'fg_fraction': rpn_fg_fraction,
  867. 'negative_overlap': .3,
  868. 'positive_overlap': .7,
  869. 'use_random': True
  870. }
  871. rpn_head = ppdet.modeling.RPNHead(
  872. anchor_generator=anchor_generator_cfg,
  873. rpn_target_assign=rpn_target_assign_cfg,
  874. train_proposal=train_proposal_cfg,
  875. test_proposal=test_proposal_cfg,
  876. in_channel=rpn_in_channel)
  877. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  878. bbox_head = ppdet.modeling.BBoxHead(
  879. head=head,
  880. in_channel=head.out_shape[0].channels,
  881. roi_extractor=roi_extractor_cfg,
  882. with_pool=with_pool,
  883. bbox_assigner=bbox_assigner,
  884. num_classes=num_classes)
  885. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  886. num_classes=num_classes,
  887. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  888. nms=ppdet.modeling.MultiClassNMS(
  889. score_threshold=score_threshold,
  890. keep_top_k=keep_top_k,
  891. nms_threshold=nms_threshold))
  892. params = {
  893. 'backbone': backbone,
  894. 'neck': neck,
  895. 'rpn_head': rpn_head,
  896. 'bbox_head': bbox_head,
  897. 'bbox_post_process': bbox_post_process
  898. }
  899. self.with_fpn = with_fpn
  900. super(FasterRCNN, self).__init__(
  901. model_name='FasterRCNN', num_classes=num_classes, **params)
  902. def _compose_batch_transform(self, transforms, mode='train'):
  903. if mode == 'train':
  904. default_batch_transforms = [
  905. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  906. ]
  907. collate_batch = False
  908. else:
  909. default_batch_transforms = [
  910. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  911. ]
  912. collate_batch = True
  913. custom_batch_transforms = []
  914. for i, op in enumerate(transforms.transforms):
  915. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  916. if mode != 'train':
  917. raise Exception(
  918. "{} cannot be present in the {} transforms. ".format(
  919. op.__class__.__name__, mode) +
  920. "Please check the {} transforms.".format(mode))
  921. custom_batch_transforms.insert(0, copy.deepcopy(op))
  922. batch_transforms = BatchCompose(
  923. custom_batch_transforms + default_batch_transforms,
  924. collate_batch=collate_batch)
  925. return batch_transforms
  926. def _fix_transforms_shape(self, image_shape):
  927. if hasattr(self, 'test_transforms'):
  928. if self.test_transforms is not None:
  929. has_resize_op = False
  930. resize_op_idx = -1
  931. normalize_op_idx = len(self.test_transforms.transforms)
  932. for idx, op in enumerate(self.test_transforms.transforms):
  933. name = op.__class__.__name__
  934. if name == 'ResizeByShort':
  935. has_resize_op = True
  936. resize_op_idx = idx
  937. if name == 'Normalize':
  938. normalize_op_idx = idx
  939. if not has_resize_op:
  940. self.test_transforms.transforms.insert(
  941. normalize_op_idx,
  942. Resize(
  943. target_size=image_shape,
  944. keep_ratio=True,
  945. interp='CUBIC'))
  946. else:
  947. self.test_transforms.transforms[resize_op_idx] = Resize(
  948. target_size=image_shape,
  949. keep_ratio=True,
  950. interp='CUBIC')
  951. self.test_transforms.transforms.append(
  952. Padding(im_padding_value=[0., 0., 0.]))
  953. def _get_test_inputs(self, image_shape):
  954. if image_shape is not None:
  955. image_shape = self._check_image_shape(image_shape)
  956. self._fix_transforms_shape(image_shape[-2:])
  957. else:
  958. image_shape = [None, 3, -1, -1]
  959. if self.with_fpn:
  960. self.test_transforms.transforms.append(
  961. Padding(im_padding_value=[0., 0., 0.]))
  962. self.fixed_input_shape = image_shape
  963. return self._define_input_spec(image_shape)
  964. class PPYOLO(YOLOv3):
  965. def __init__(self,
  966. num_classes=80,
  967. backbone='ResNet50_vd_dcn',
  968. anchors=None,
  969. anchor_masks=None,
  970. use_coord_conv=True,
  971. use_iou_aware=True,
  972. use_spp=True,
  973. use_drop_block=True,
  974. scale_x_y=1.05,
  975. ignore_threshold=0.7,
  976. label_smooth=False,
  977. use_iou_loss=True,
  978. use_matrix_nms=True,
  979. nms_score_threshold=0.01,
  980. nms_topk=-1,
  981. nms_keep_topk=100,
  982. nms_iou_threshold=0.45):
  983. self.init_params = locals()
  984. if backbone not in [
  985. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  986. 'MobileNetV3_small'
  987. ]:
  988. raise ValueError(
  989. "backbone: {} is not supported. Please choose one of "
  990. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  991. format(backbone))
  992. self.backbone_name = backbone
  993. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  994. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  995. norm_type = 'sync_bn'
  996. else:
  997. norm_type = 'bn'
  998. if anchors is None and anchor_masks is None:
  999. if 'MobileNetV3' in backbone:
  1000. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  1001. [120, 195], [254, 235]]
  1002. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1003. elif backbone == 'ResNet50_vd_dcn':
  1004. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1005. [59, 119], [116, 90], [156, 198], [373, 326]]
  1006. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  1007. else:
  1008. anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
  1009. [344, 319]]
  1010. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1011. elif anchors is None or anchor_masks is None:
  1012. raise ValueError("Please define both anchors and anchor_masks.")
  1013. if backbone == 'ResNet50_vd_dcn':
  1014. backbone = self._get_backbone(
  1015. 'ResNet',
  1016. variant='d',
  1017. norm_type=norm_type,
  1018. return_idx=[1, 2, 3],
  1019. dcn_v2_stages=[3],
  1020. freeze_at=-1,
  1021. freeze_norm=False,
  1022. norm_decay=0.)
  1023. downsample_ratios = [32, 16, 8]
  1024. elif backbone == 'ResNet18_vd':
  1025. backbone = self._get_backbone(
  1026. 'ResNet',
  1027. depth=18,
  1028. variant='d',
  1029. norm_type=norm_type,
  1030. return_idx=[2, 3],
  1031. freeze_at=-1,
  1032. freeze_norm=False,
  1033. norm_decay=0.)
  1034. downsample_ratios = [32, 16]
  1035. elif backbone == 'MobileNetV3_large':
  1036. backbone = self._get_backbone(
  1037. 'MobileNetV3',
  1038. model_name='large',
  1039. norm_type=norm_type,
  1040. scale=1,
  1041. with_extra_blocks=False,
  1042. extra_block_filters=[],
  1043. feature_maps=[13, 16])
  1044. downsample_ratios = [32, 16]
  1045. elif backbone == 'MobileNetV3_small':
  1046. backbone = self._get_backbone(
  1047. 'MobileNetV3',
  1048. model_name='small',
  1049. norm_type=norm_type,
  1050. scale=1,
  1051. with_extra_blocks=False,
  1052. extra_block_filters=[],
  1053. feature_maps=[9, 12])
  1054. downsample_ratios = [32, 16]
  1055. neck = ppdet.modeling.PPYOLOFPN(
  1056. norm_type=norm_type,
  1057. in_channels=[i.channels for i in backbone.out_shape],
  1058. coord_conv=use_coord_conv,
  1059. drop_block=use_drop_block,
  1060. spp=use_spp,
  1061. conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
  1062. self.backbone_name == 'ResNet18_vd') else 2)
  1063. loss = ppdet.modeling.YOLOv3Loss(
  1064. num_classes=num_classes,
  1065. ignore_thresh=ignore_threshold,
  1066. downsample=downsample_ratios,
  1067. label_smooth=label_smooth,
  1068. scale_x_y=scale_x_y,
  1069. iou_loss=ppdet.modeling.IouLoss(
  1070. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1071. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1072. if use_iou_aware else None)
  1073. yolo_head = ppdet.modeling.YOLOv3Head(
  1074. in_channels=[i.channels for i in neck.out_shape],
  1075. anchors=anchors,
  1076. anchor_masks=anchor_masks,
  1077. num_classes=num_classes,
  1078. loss=loss,
  1079. iou_aware=use_iou_aware)
  1080. if use_matrix_nms:
  1081. nms = ppdet.modeling.MatrixNMS(
  1082. keep_top_k=nms_keep_topk,
  1083. score_threshold=nms_score_threshold,
  1084. post_threshold=.05
  1085. if 'MobileNetV3' in self.backbone_name else .01,
  1086. nms_top_k=nms_topk,
  1087. background_label=-1)
  1088. else:
  1089. nms = ppdet.modeling.MultiClassNMS(
  1090. score_threshold=nms_score_threshold,
  1091. nms_top_k=nms_topk,
  1092. keep_top_k=nms_keep_topk,
  1093. nms_threshold=nms_iou_threshold)
  1094. post_process = ppdet.modeling.BBoxPostProcess(
  1095. decode=ppdet.modeling.YOLOBox(
  1096. num_classes=num_classes,
  1097. conf_thresh=.005
  1098. if 'MobileNetV3' in self.backbone_name else .01,
  1099. scale_x_y=scale_x_y),
  1100. nms=nms)
  1101. params = {
  1102. 'backbone': backbone,
  1103. 'neck': neck,
  1104. 'yolo_head': yolo_head,
  1105. 'post_process': post_process
  1106. }
  1107. super(YOLOv3, self).__init__(
  1108. model_name='YOLOv3', num_classes=num_classes, **params)
  1109. self.anchors = anchors
  1110. self.anchor_masks = anchor_masks
  1111. self.downsample_ratios = downsample_ratios
  1112. self.model_name = 'PPYOLO'
  1113. class PPYOLOTiny(YOLOv3):
  1114. def __init__(self,
  1115. num_classes=80,
  1116. backbone='MobileNetV3',
  1117. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1118. [60, 170], [220, 125], [128, 222], [264, 266]],
  1119. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1120. use_iou_aware=False,
  1121. use_spp=True,
  1122. use_drop_block=True,
  1123. scale_x_y=1.05,
  1124. ignore_threshold=0.5,
  1125. label_smooth=False,
  1126. use_iou_loss=True,
  1127. use_matrix_nms=False,
  1128. nms_score_threshold=0.005,
  1129. nms_topk=1000,
  1130. nms_keep_topk=100,
  1131. nms_iou_threshold=0.45):
  1132. self.init_params = locals()
  1133. if backbone != 'MobileNetV3':
  1134. logging.warning(
  1135. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1136. "Backbone is forcibly set to MobileNetV3.")
  1137. self.backbone_name = 'MobileNetV3'
  1138. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1139. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1140. norm_type = 'sync_bn'
  1141. else:
  1142. norm_type = 'bn'
  1143. backbone = self._get_backbone(
  1144. 'MobileNetV3',
  1145. model_name='large',
  1146. norm_type=norm_type,
  1147. scale=.5,
  1148. with_extra_blocks=False,
  1149. extra_block_filters=[],
  1150. feature_maps=[7, 13, 16])
  1151. downsample_ratios = [32, 16, 8]
  1152. neck = ppdet.modeling.PPYOLOTinyFPN(
  1153. detection_block_channels=[160, 128, 96],
  1154. in_channels=[i.channels for i in backbone.out_shape],
  1155. spp=use_spp,
  1156. drop_block=use_drop_block)
  1157. loss = ppdet.modeling.YOLOv3Loss(
  1158. num_classes=num_classes,
  1159. ignore_thresh=ignore_threshold,
  1160. downsample=downsample_ratios,
  1161. label_smooth=label_smooth,
  1162. scale_x_y=scale_x_y,
  1163. iou_loss=ppdet.modeling.IouLoss(
  1164. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1165. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1166. if use_iou_aware else None)
  1167. yolo_head = ppdet.modeling.YOLOv3Head(
  1168. in_channels=[i.channels for i in neck.out_shape],
  1169. anchors=anchors,
  1170. anchor_masks=anchor_masks,
  1171. num_classes=num_classes,
  1172. loss=loss,
  1173. iou_aware=use_iou_aware)
  1174. if use_matrix_nms:
  1175. nms = ppdet.modeling.MatrixNMS(
  1176. keep_top_k=nms_keep_topk,
  1177. score_threshold=nms_score_threshold,
  1178. post_threshold=.05,
  1179. nms_top_k=nms_topk,
  1180. background_label=-1)
  1181. else:
  1182. nms = ppdet.modeling.MultiClassNMS(
  1183. score_threshold=nms_score_threshold,
  1184. nms_top_k=nms_topk,
  1185. keep_top_k=nms_keep_topk,
  1186. nms_threshold=nms_iou_threshold)
  1187. post_process = ppdet.modeling.BBoxPostProcess(
  1188. decode=ppdet.modeling.YOLOBox(
  1189. num_classes=num_classes,
  1190. conf_thresh=.005,
  1191. downsample_ratio=32,
  1192. clip_bbox=True,
  1193. scale_x_y=scale_x_y),
  1194. nms=nms)
  1195. params = {
  1196. 'backbone': backbone,
  1197. 'neck': neck,
  1198. 'yolo_head': yolo_head,
  1199. 'post_process': post_process
  1200. }
  1201. super(YOLOv3, self).__init__(
  1202. model_name='YOLOv3', num_classes=num_classes, **params)
  1203. self.anchors = anchors
  1204. self.anchor_masks = anchor_masks
  1205. self.downsample_ratios = downsample_ratios
  1206. self.model_name = 'PPYOLOTiny'
  1207. class PPYOLOv2(YOLOv3):
  1208. def __init__(self,
  1209. num_classes=80,
  1210. backbone='ResNet50_vd_dcn',
  1211. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1212. [59, 119], [116, 90], [156, 198], [373, 326]],
  1213. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1214. use_iou_aware=True,
  1215. use_spp=True,
  1216. use_drop_block=True,
  1217. scale_x_y=1.05,
  1218. ignore_threshold=0.7,
  1219. label_smooth=False,
  1220. use_iou_loss=True,
  1221. use_matrix_nms=True,
  1222. nms_score_threshold=0.01,
  1223. nms_topk=-1,
  1224. nms_keep_topk=100,
  1225. nms_iou_threshold=0.45):
  1226. self.init_params = locals()
  1227. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1228. raise ValueError(
  1229. "backbone: {} is not supported. Please choose one of "
  1230. "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
  1231. self.backbone_name = backbone
  1232. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1233. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1234. norm_type = 'sync_bn'
  1235. else:
  1236. norm_type = 'bn'
  1237. if backbone == 'ResNet50_vd_dcn':
  1238. backbone = self._get_backbone(
  1239. 'ResNet',
  1240. variant='d',
  1241. norm_type=norm_type,
  1242. return_idx=[1, 2, 3],
  1243. dcn_v2_stages=[3],
  1244. freeze_at=-1,
  1245. freeze_norm=False,
  1246. norm_decay=0.)
  1247. downsample_ratios = [32, 16, 8]
  1248. elif backbone == 'ResNet101_vd_dcn':
  1249. backbone = self._get_backbone(
  1250. 'ResNet',
  1251. depth=101,
  1252. variant='d',
  1253. norm_type=norm_type,
  1254. return_idx=[1, 2, 3],
  1255. dcn_v2_stages=[3],
  1256. freeze_at=-1,
  1257. freeze_norm=False,
  1258. norm_decay=0.)
  1259. downsample_ratios = [32, 16, 8]
  1260. neck = ppdet.modeling.PPYOLOPAN(
  1261. norm_type=norm_type,
  1262. in_channels=[i.channels for i in backbone.out_shape],
  1263. drop_block=use_drop_block,
  1264. block_size=3,
  1265. keep_prob=.9,
  1266. spp=use_spp)
  1267. loss = ppdet.modeling.YOLOv3Loss(
  1268. num_classes=num_classes,
  1269. ignore_thresh=ignore_threshold,
  1270. downsample=downsample_ratios,
  1271. label_smooth=label_smooth,
  1272. scale_x_y=scale_x_y,
  1273. iou_loss=ppdet.modeling.IouLoss(
  1274. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1275. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1276. if use_iou_aware else None)
  1277. yolo_head = ppdet.modeling.YOLOv3Head(
  1278. in_channels=[i.channels for i in neck.out_shape],
  1279. anchors=anchors,
  1280. anchor_masks=anchor_masks,
  1281. num_classes=num_classes,
  1282. loss=loss,
  1283. iou_aware=use_iou_aware,
  1284. iou_aware_factor=.5)
  1285. if use_matrix_nms:
  1286. nms = ppdet.modeling.MatrixNMS(
  1287. keep_top_k=nms_keep_topk,
  1288. score_threshold=nms_score_threshold,
  1289. post_threshold=.01,
  1290. nms_top_k=nms_topk,
  1291. background_label=-1)
  1292. else:
  1293. nms = ppdet.modeling.MultiClassNMS(
  1294. score_threshold=nms_score_threshold,
  1295. nms_top_k=nms_topk,
  1296. keep_top_k=nms_keep_topk,
  1297. nms_threshold=nms_iou_threshold)
  1298. post_process = ppdet.modeling.BBoxPostProcess(
  1299. decode=ppdet.modeling.YOLOBox(
  1300. num_classes=num_classes,
  1301. conf_thresh=.01,
  1302. downsample_ratio=32,
  1303. clip_bbox=True,
  1304. scale_x_y=scale_x_y),
  1305. nms=nms)
  1306. params = {
  1307. 'backbone': backbone,
  1308. 'neck': neck,
  1309. 'yolo_head': yolo_head,
  1310. 'post_process': post_process
  1311. }
  1312. super(YOLOv3, self).__init__(
  1313. model_name='YOLOv3', num_classes=num_classes, **params)
  1314. self.anchors = anchors
  1315. self.anchor_masks = anchor_masks
  1316. self.downsample_ratios = downsample_ratios
  1317. self.model_name = 'PPYOLOv2'
  1318. def _get_test_inputs(self, image_shape):
  1319. if image_shape is not None:
  1320. image_shape = self._check_image_shape(image_shape)
  1321. self._fix_transforms_shape(image_shape[-2:])
  1322. else:
  1323. image_shape = [None, 3, 608, 608]
  1324. logging.warning(
  1325. '[Important!!!] When exporting inference model for {},'.format(
  1326. self.__class__.__name__) +
  1327. ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 608, 608]. '
  1328. +
  1329. 'Please check image shape after transforms is [3, 608, 608], if not, fixed_input_shape '
  1330. + 'should be specified manually.')
  1331. self.fixed_input_shape = image_shape
  1332. return self._define_input_spec(image_shape)
  1333. class MaskRCNN(BaseDetector):
  1334. def __init__(self,
  1335. num_classes=80,
  1336. backbone='ResNet50_vd',
  1337. with_fpn=True,
  1338. with_dcn=False,
  1339. aspect_ratios=[0.5, 1.0, 2.0],
  1340. anchor_sizes=[[32], [64], [128], [256], [512]],
  1341. keep_top_k=100,
  1342. nms_threshold=0.5,
  1343. score_threshold=0.05,
  1344. fpn_num_channels=256,
  1345. rpn_batch_size_per_im=256,
  1346. rpn_fg_fraction=0.5,
  1347. test_pre_nms_top_n=None,
  1348. test_post_nms_top_n=1000):
  1349. self.init_params = locals()
  1350. if backbone not in [
  1351. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1352. 'ResNet101_vd'
  1353. ]:
  1354. raise ValueError(
  1355. "backbone: {} is not supported. Please choose one of "
  1356. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1357. format(backbone))
  1358. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1359. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  1360. if backbone == 'ResNet50':
  1361. if with_fpn:
  1362. backbone = self._get_backbone(
  1363. 'ResNet',
  1364. norm_type='bn',
  1365. freeze_at=0,
  1366. return_idx=[0, 1, 2, 3],
  1367. num_stages=4,
  1368. dcn_v2_stages=dcn_v2_stages)
  1369. else:
  1370. if with_dcn:
  1371. logging.warning(
  1372. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1373. format(backbone))
  1374. backbone = self._get_backbone(
  1375. 'ResNet',
  1376. norm_type='bn',
  1377. freeze_at=0,
  1378. return_idx=[2],
  1379. num_stages=3)
  1380. elif 'ResNet50_vd' in backbone:
  1381. if not with_fpn:
  1382. logging.warning(
  1383. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1384. format(backbone))
  1385. with_fpn = True
  1386. backbone = self._get_backbone(
  1387. 'ResNet',
  1388. variant='d',
  1389. norm_type='bn',
  1390. freeze_at=0,
  1391. return_idx=[0, 1, 2, 3],
  1392. num_stages=4,
  1393. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1394. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
  1395. dcn_v2_stages=dcn_v2_stages)
  1396. else:
  1397. if not with_fpn:
  1398. logging.warning(
  1399. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1400. format(backbone))
  1401. with_fpn = True
  1402. backbone = self._get_backbone(
  1403. 'ResNet',
  1404. variant='d' if '_vd' in backbone else 'b',
  1405. depth=101,
  1406. norm_type='bn',
  1407. freeze_at=0,
  1408. return_idx=[0, 1, 2, 3],
  1409. num_stages=4,
  1410. dcn_v2_stages=dcn_v2_stages)
  1411. rpn_in_channel = backbone.out_shape[0].channels
  1412. if with_fpn:
  1413. neck = ppdet.modeling.FPN(
  1414. in_channels=[i.channels for i in backbone.out_shape],
  1415. out_channel=fpn_num_channels,
  1416. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  1417. rpn_in_channel = neck.out_shape[0].channels
  1418. anchor_generator_cfg = {
  1419. 'aspect_ratios': aspect_ratios,
  1420. 'anchor_sizes': anchor_sizes,
  1421. 'strides': [4, 8, 16, 32, 64]
  1422. }
  1423. train_proposal_cfg = {
  1424. 'min_size': 0.0,
  1425. 'nms_thresh': .7,
  1426. 'pre_nms_top_n': 2000,
  1427. 'post_nms_top_n': 1000,
  1428. 'topk_after_collect': True
  1429. }
  1430. test_proposal_cfg = {
  1431. 'min_size': 0.0,
  1432. 'nms_thresh': .7,
  1433. 'pre_nms_top_n': 1000
  1434. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1435. 'post_nms_top_n': test_post_nms_top_n
  1436. }
  1437. bb_head = ppdet.modeling.TwoFCHead(
  1438. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1439. bb_roi_extractor_cfg = {
  1440. 'resolution': 7,
  1441. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1442. 'sampling_ratio': 0,
  1443. 'aligned': True
  1444. }
  1445. with_pool = False
  1446. m_head = ppdet.modeling.MaskFeat(
  1447. in_channel=neck.out_shape[0].channels,
  1448. out_channel=256,
  1449. num_convs=4)
  1450. m_roi_extractor_cfg = {
  1451. 'resolution': 14,
  1452. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1453. 'sampling_ratio': 0,
  1454. 'aligned': True
  1455. }
  1456. mask_assigner = MaskAssigner(
  1457. num_classes=num_classes, mask_resolution=28)
  1458. share_bbox_feat = False
  1459. else:
  1460. neck = None
  1461. anchor_generator_cfg = {
  1462. 'aspect_ratios': aspect_ratios,
  1463. 'anchor_sizes': anchor_sizes,
  1464. 'strides': [16]
  1465. }
  1466. train_proposal_cfg = {
  1467. 'min_size': 0.0,
  1468. 'nms_thresh': .7,
  1469. 'pre_nms_top_n': 12000,
  1470. 'post_nms_top_n': 2000,
  1471. 'topk_after_collect': False
  1472. }
  1473. test_proposal_cfg = {
  1474. 'min_size': 0.0,
  1475. 'nms_thresh': .7,
  1476. 'pre_nms_top_n': 6000
  1477. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1478. 'post_nms_top_n': test_post_nms_top_n
  1479. }
  1480. bb_head = ppdet.modeling.Res5Head()
  1481. bb_roi_extractor_cfg = {
  1482. 'resolution': 14,
  1483. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1484. 'sampling_ratio': 0,
  1485. 'aligned': True
  1486. }
  1487. with_pool = True
  1488. m_head = ppdet.modeling.MaskFeat(
  1489. in_channel=bb_head.out_shape[0].channels,
  1490. out_channel=256,
  1491. num_convs=0)
  1492. m_roi_extractor_cfg = {
  1493. 'resolution': 14,
  1494. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1495. 'sampling_ratio': 0,
  1496. 'aligned': True
  1497. }
  1498. mask_assigner = MaskAssigner(
  1499. num_classes=num_classes, mask_resolution=14)
  1500. share_bbox_feat = True
  1501. rpn_target_assign_cfg = {
  1502. 'batch_size_per_im': rpn_batch_size_per_im,
  1503. 'fg_fraction': rpn_fg_fraction,
  1504. 'negative_overlap': .3,
  1505. 'positive_overlap': .7,
  1506. 'use_random': True
  1507. }
  1508. rpn_head = ppdet.modeling.RPNHead(
  1509. anchor_generator=anchor_generator_cfg,
  1510. rpn_target_assign=rpn_target_assign_cfg,
  1511. train_proposal=train_proposal_cfg,
  1512. test_proposal=test_proposal_cfg,
  1513. in_channel=rpn_in_channel)
  1514. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1515. bbox_head = ppdet.modeling.BBoxHead(
  1516. head=bb_head,
  1517. in_channel=bb_head.out_shape[0].channels,
  1518. roi_extractor=bb_roi_extractor_cfg,
  1519. with_pool=with_pool,
  1520. bbox_assigner=bbox_assigner,
  1521. num_classes=num_classes)
  1522. mask_head = ppdet.modeling.MaskHead(
  1523. head=m_head,
  1524. roi_extractor=m_roi_extractor_cfg,
  1525. mask_assigner=mask_assigner,
  1526. share_bbox_feat=share_bbox_feat,
  1527. num_classes=num_classes)
  1528. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1529. num_classes=num_classes,
  1530. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1531. nms=ppdet.modeling.MultiClassNMS(
  1532. score_threshold=score_threshold,
  1533. keep_top_k=keep_top_k,
  1534. nms_threshold=nms_threshold))
  1535. mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
  1536. params = {
  1537. 'backbone': backbone,
  1538. 'neck': neck,
  1539. 'rpn_head': rpn_head,
  1540. 'bbox_head': bbox_head,
  1541. 'mask_head': mask_head,
  1542. 'bbox_post_process': bbox_post_process,
  1543. 'mask_post_process': mask_post_process
  1544. }
  1545. self.with_fpn = with_fpn
  1546. super(MaskRCNN, self).__init__(
  1547. model_name='MaskRCNN', num_classes=num_classes, **params)
  def _compose_batch_transform(self, transforms, mode='train'):
      """Build the batch-level transform pipeline for MaskRCNN.

      Any ``BatchRandomResize`` / ``BatchRandomResizeByShort`` ops found in
      ``transforms.transforms`` are deep-copied and placed in front of the
      default ``_BatchPadding`` op.

      Args:
          transforms: Sample-level pipeline whose ``transforms`` list is
              scanned for batch-level resize operators.
          mode (str): 'train' or another mode (e.g. evaluation). Batch
              resize operators are only accepted in 'train'.

      Returns:
          BatchCompose: custom batch-resize ops followed by the default
          padding op. ``collate_batch`` is False in 'train' and True in
          every other mode.

      Raises:
          ValueError: if a batch resize operator appears while ``mode`` is
              not 'train'. (ValueError subclasses Exception, so existing
              ``except Exception`` handlers still catch it.)
      """
      # Same padding for every mode: pad to stride 32 when FPN is used,
      # otherwise -1 (presumably "no stride alignment" -- see _BatchPadding).
      default_batch_transforms = [
          _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
      ]
      collate_batch = mode != 'train'

      custom_batch_transforms = []
      for op in transforms.transforms:
          if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
              if mode != 'train':
                  raise ValueError(
                      "{} cannot be present in the {} transforms. ".format(
                          op.__class__.__name__, mode) +
                      "Please check the {} transforms.".format(mode))
              # Deep-copy so the batch pipeline does not share the op object
              # with the sample-level pipeline. Inserting at index 0 keeps the
              # original ordering: later ops in `transforms` end up first.
              custom_batch_transforms.insert(0, copy.deepcopy(op))

      return BatchCompose(
          custom_batch_transforms + default_batch_transforms,
          collate_batch=collate_batch)
  def _fix_transforms_shape(self, image_shape):
      """Pin the test-time transform pipeline to a fixed spatial size.

      Steps:
      - If ``self.test_transforms`` holds a ``ResizeByShort`` op, the last
        one is replaced by a ratio-preserving ``Resize`` to ``image_shape``.
      - Otherwise such a ``Resize`` is inserted at the index of the last
        ``Normalize`` op, or appended if there is no ``Normalize``.
      - A ``Padding`` op with a zero padding value is always appended
        afterwards.

      Nothing happens when the model has no ``test_transforms`` attribute or
      it is None.

      Args:
          image_shape: Target size forwarded to ``Resize(target_size=...)``.
              ``_get_test_inputs`` passes the last two entries of a
              validated input shape here.
      """
      test_transforms = getattr(self, 'test_transforms', None)
      if test_transforms is None:
          return

      ops = test_transforms.transforms
      resize_pos = None  # index of the last ResizeByShort, if any
      normalize_pos = len(ops)  # default: insert at the end of the list
      for pos, op in enumerate(ops):
          op_name = op.__class__.__name__
          if op_name == 'ResizeByShort':
              resize_pos = pos
          elif op_name == 'Normalize':
              normalize_pos = pos

      if resize_pos is None:
          # Placed at Normalize's index so it sits just before that op
          # (assuming ops are applied in list order).
          ops.insert(
              normalize_pos,
              Resize(target_size=image_shape, keep_ratio=True, interp='CUBIC'))
      else:
          ops[resize_pos] = Resize(
              target_size=image_shape, keep_ratio=True, interp='CUBIC')
      ops.append(Padding(im_padding_value=[0., 0., 0.]))
  def _get_test_inputs(self, image_shape):
      """Prepare the test transforms and input spec used for export.

      Args:
          image_shape: Requested input shape, or None for a dynamic shape.
              When given, it is normalized by ``self._check_image_shape``
              and its last two entries fix the test-time resize via
              ``self._fix_transforms_shape``.

      Returns:
          The result of ``self._define_input_spec`` for the final shape,
          which is also stored on ``self.fixed_input_shape``.
      """
      if image_shape is not None:
          image_shape = self._check_image_shape(image_shape)
          self._fix_transforms_shape(image_shape[-2:])
      else:
          # Dynamic shape: batch unspecified, 3 channels, H/W = -1
          # (presumably NCHW -- confirm against _define_input_spec).
          image_shape = [None, 3, -1, -1]
          # Guard like _fix_transforms_shape does: test_transforms may be
          # absent or None, and dereferencing it would raise AttributeError.
          test_transforms = getattr(self, 'test_transforms', None)
          if self.with_fpn and test_transforms is not None:
              test_transforms.transforms.append(
                  Padding(im_padding_value=[0., 0., 0.]))
      self.fixed_input_shape = image_shape
      return self._define_input_spec(image_shape)