detector.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import paddlex.ppdet as ppdet
  24. from paddlex.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  def __init__(self, model_name, num_classes=80, **params):
      """Record constructor arguments and optionally build the network.

      ``self.init_params`` must already exist on the instance (this method
      only updates it); subclasses such as ``YOLOv3`` assign it before
      delegating here.

      Args:
          model_name(str): Name of an architecture exposed by
              ``ppdet.modeling``.
          num_classes(int, optional): Number of object classes.
              Defaults to 80.
          **params: Forwarded to ``build_net``. A ``with_net`` entry
              (default True) controls whether the network is built at all.

      Raises:
          Exception: If ``ppdet.modeling`` has no attribute ``model_name``.
      """
      # Must remain the first statement: locals() has to capture exactly the
      # constructor arguments and nothing defined later in this method.
      self.init_params.update(locals())
      self.init_params.pop('with_net', None)
      # Explicit two-argument super(): the zero-argument form would add a
      # `__class__` cell to this frame and therefore to locals() above.
      super(BaseDetector, self).__init__('detector')
      if not hasattr(ppdet.modeling, model_name):
          raise Exception(
              "ERROR: There's no model named {}.".format(model_name))

      self.model_name = model_name
      self.num_classes = num_classes
      self.labels = None

      if not params.get('with_net', True):
          # Caller asked for a shell object without a network.
          return
      params.pop('with_net', None)
      self.net = self.build_net(**params)
  def build_net(self, **params):
      """Instantiate the ``ppdet.modeling`` class named ``self.model_name``.

      Args:
          **params: Keyword arguments passed to the architecture constructor.

      Returns:
          The constructed network object.
      """
      # Construction happens inside paddle's unique_name guard.
      with paddle.utils.unique_name.guard():
          architecture = ppdet.modeling.__dict__[self.model_name]
          return architecture(**params)
  56. def _fix_transforms_shape(self, image_shape):
  57. raise NotImplementedError("_fix_transforms_shape: not implemented!")
  def _define_input_spec(self, image_shape):
      """Describe the network inputs as paddle ``InputSpec`` objects.

      Args:
          image_shape(list): Shape used for the ``image`` input; its first
              entry is reused as the leading dimension of the two
              auxiliary inputs.

      Returns:
          list: A one-element list holding a dict with the keys ``image``,
          ``im_shape`` and ``scale_factor``.
      """
      leading_dim = image_shape[0]
      image_spec = InputSpec(
          shape=image_shape, name='image', dtype='float32')
      im_shape_spec = InputSpec(
          shape=[leading_dim, 2], name='im_shape', dtype='float32')
      scale_factor_spec = InputSpec(
          shape=[leading_dim, 2], name='scale_factor', dtype='float32')
      return [{
          "image": image_spec,
          "im_shape": im_shape_spec,
          "scale_factor": scale_factor_spec
      }]
  70. def _check_image_shape(self, image_shape):
  71. if len(image_shape) == 2:
  72. image_shape = [1, 3] + image_shape
  73. if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
  74. raise Exception(
  75. "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
  76. format(image_shape[-2:]))
  77. return image_shape
  def _get_test_inputs(self, image_shape):
      """Build the input specs for a given (or dynamic) input shape.

      Args:
          image_shape(list or None): Fixed input shape, or None to use a
              dynamic ``[None, 3, -1, -1]`` shape.

      Returns:
          The value produced by ``_define_input_spec`` for the resolved
          shape. The resolved shape is also stored on
          ``self.fixed_input_shape``.
      """
      if image_shape is None:
          image_shape = [None, 3, -1, -1]
      else:
          image_shape = self._check_image_shape(image_shape)
          # Let the transforms adapt to the fixed height/width.
          self._fix_transforms_shape(image_shape[-2:])
      self.fixed_input_shape = image_shape
      return self._define_input_spec(image_shape)
  def _get_backbone(self, backbone_name, **params):
      """Instantiate a backbone class looked up on ``ppdet.modeling``.

      Args:
          backbone_name(str): Attribute name of the backbone class.
          **params: Keyword arguments for the backbone constructor.

      Returns:
          The constructed backbone object.
      """
      backbone_cls = getattr(ppdet.modeling, backbone_name)
      return backbone_cls(**params)
  89. def run(self, net, inputs, mode):
  90. net_out = net(inputs)
  91. if mode in ['train', 'eval']:
  92. outputs = net_out
  93. else:
  94. outputs = dict()
  95. for key in net_out:
  96. outputs[key] = net_out[key].numpy()
  97. return outputs
  def default_optimizer(self, parameters, learning_rate, warmup_steps,
                        warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
                        num_steps_each_epoch):
      """Build the default Momentum optimizer with a piecewise LR schedule.

      The learning rate is multiplied by ``lr_decay_gamma`` at each epoch in
      ``lr_decay_epochs``; when ``warmup_steps`` is positive the schedule is
      wrapped in a linear warm-up.

      Args:
          parameters: Parameters handed to the optimizer.
          learning_rate(float): Base learning rate.
          warmup_steps(int): Number of warm-up steps; 0 disables warm-up.
          warmup_start_lr(float): Learning rate at the start of warm-up.
          lr_decay_epochs(list or tuple): Epoch milestones for decay.
          lr_decay_gamma(float): Multiplicative decay factor.
          num_steps_each_epoch(int): Steps per epoch, used to convert epoch
              milestones into step boundaries.

      Returns:
          paddle.optimizer.Momentum: Optimizer with momentum 0.9 and L2
          weight decay of 1e-4.
      """
      boundaries = [
          epoch * num_steps_each_epoch for epoch in lr_decay_epochs
      ]
      values = []
      for power in range(len(lr_decay_epochs) + 1):
          values.append((lr_decay_gamma**power) * learning_rate)
      scheduler = paddle.optimizer.lr.PiecewiseDecay(
          boundaries=boundaries, values=values)

      if warmup_steps > 0:
          first_boundary = lr_decay_epochs[0] * num_steps_each_epoch
          if warmup_steps > first_boundary:
              # Only reported (exit=False); the warm-up is still applied.
              messages = (
                  "In function train(), parameters should satisfy: "
                  "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
                  "See this doc for more information: "
                  "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
              )
              for message in messages:
                  logging.error(message, exit=False)
          scheduler = paddle.optimizer.lr.LinearWarmup(
              learning_rate=scheduler,
              warmup_steps=warmup_steps,
              start_lr=warmup_start_lr,
              end_lr=learning_rate)

      return paddle.optimizer.Momentum(
          scheduler,
          momentum=.9,
          weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
          parameters=parameters)
  127. def train(self,
  128. num_epochs,
  129. train_dataset,
  130. train_batch_size=64,
  131. eval_dataset=None,
  132. optimizer=None,
  133. save_interval_epochs=1,
  134. log_interval_steps=10,
  135. save_dir='output',
  136. pretrain_weights='IMAGENET',
  137. learning_rate=.001,
  138. warmup_steps=0,
  139. warmup_start_lr=0.0,
  140. lr_decay_epochs=(216, 243),
  141. lr_decay_gamma=0.1,
  142. metric=None,
  143. use_ema=False,
  144. early_stop=False,
  145. early_stop_patience=5,
  146. use_vdl=True,
  147. resume_checkpoint=None):
  148. """
  149. Train the model.
  150. Args:
  151. num_epochs(int): The number of epochs.
  152. train_dataset(paddlex.dataset): Training dataset.
  153. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  154. eval_dataset(paddlex.dataset, optional):
  155. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  156. optimizer(paddle.optimizer.Optimizer or None, optional):
  157. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  158. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  159. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  160. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  161. pretrain_weights(str or None, optional):
  162. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  163. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  164. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  165. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  166. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  167. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  168. metric({'VOC', 'COCO', None}, optional):
  169. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  170. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  171. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  172. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  173. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  174. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  175. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  176. `pretrain_weights` can be set simultaneously. Defaults to None.
  177. """
  178. if self.status == 'Infer':
  179. logging.error(
  180. "Exported inference model does not support training.",
  181. exit=True)
  182. if pretrain_weights is not None and resume_checkpoint is not None:
  183. logging.error(
  184. "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
  185. exit=True)
  186. if train_dataset.__class__.__name__ == 'VOCDetection':
  187. train_dataset.data_fields = {
  188. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  189. 'difficult'
  190. }
  191. elif train_dataset.__class__.__name__ == 'CocoDetection':
  192. if self.__class__.__name__ == 'MaskRCNN':
  193. train_dataset.data_fields = {
  194. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  195. 'gt_poly', 'is_crowd'
  196. }
  197. else:
  198. train_dataset.data_fields = {
  199. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  200. 'is_crowd'
  201. }
  202. if metric is None:
  203. if eval_dataset.__class__.__name__ == 'VOCDetection':
  204. self.metric = 'voc'
  205. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  206. self.metric = 'coco'
  207. else:
  208. assert metric.lower() in ['coco', 'voc'], \
  209. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  210. self.metric = metric.lower()
  211. self.labels = train_dataset.labels
  212. self.num_max_boxes = train_dataset.num_max_boxes
  213. train_dataset.batch_transforms = self._compose_batch_transform(
  214. train_dataset.transforms, mode='train')
  215. # build optimizer if not defined
  216. if optimizer is None:
  217. num_steps_each_epoch = len(train_dataset) // train_batch_size
  218. self.optimizer = self.default_optimizer(
  219. parameters=self.net.parameters(),
  220. learning_rate=learning_rate,
  221. warmup_steps=warmup_steps,
  222. warmup_start_lr=warmup_start_lr,
  223. lr_decay_epochs=lr_decay_epochs,
  224. lr_decay_gamma=lr_decay_gamma,
  225. num_steps_each_epoch=num_steps_each_epoch)
  226. else:
  227. self.optimizer = optimizer
  228. # initiate weights
  229. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  230. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  231. [self.model_name, self.backbone_name])]:
  232. logging.warning(
  233. "Path of pretrain_weights('{}') does not exist!".format(
  234. pretrain_weights))
  235. pretrain_weights = det_pretrain_weights_dict['_'.join(
  236. [self.model_name, self.backbone_name])][0]
  237. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  238. "If you don't want to use pretrain weights, "
  239. "set pretrain_weights to be None.".format(
  240. pretrain_weights))
  241. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  242. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  243. logging.error(
  244. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  245. exit=True)
  246. pretrained_dir = osp.join(save_dir, 'pretrain')
  247. self.net_initialize(
  248. pretrain_weights=pretrain_weights,
  249. save_dir=pretrained_dir,
  250. resume_checkpoint=resume_checkpoint)
  251. if use_ema:
  252. ema = ExponentialMovingAverage(
  253. decay=.9998, model=self.net, use_thres_step=True)
  254. else:
  255. ema = None
  256. # start train loop
  257. self.train_loop(
  258. num_epochs=num_epochs,
  259. train_dataset=train_dataset,
  260. train_batch_size=train_batch_size,
  261. eval_dataset=eval_dataset,
  262. save_interval_epochs=save_interval_epochs,
  263. log_interval_steps=log_interval_steps,
  264. save_dir=save_dir,
  265. ema=ema,
  266. early_stop=early_stop,
  267. early_stop_patience=early_stop_patience,
  268. use_vdl=use_vdl)
  def quant_aware_train(self,
                        num_epochs,
                        train_dataset,
                        train_batch_size=64,
                        eval_dataset=None,
                        optimizer=None,
                        save_interval_epochs=1,
                        log_interval_steps=10,
                        save_dir='output',
                        learning_rate=.00001,
                        warmup_steps=0,
                        warmup_start_lr=0.0,
                        lr_decay_epochs=(216, 243),
                        lr_decay_gamma=0.1,
                        metric=None,
                        use_ema=False,
                        early_stop=False,
                        early_stop_patience=5,
                        use_vdl=True,
                        resume_checkpoint=None,
                        quant_config=None):
      """
      Quantization-aware training.

      Prepares the model with `_prepare_qat(quant_config)` and then runs
      `train` with the same arguments; `pretrain_weights` is always passed
      as None.

      Args:
          num_epochs(int): The number of epochs.
          train_dataset(paddlex.dataset): Training dataset.
          train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
          eval_dataset(paddlex.dataset, optional):
              Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
          optimizer(paddle.optimizer.Optimizer or None, optional):
              Optimizer used for training. If None, a default optimizer is used. Defaults to None.
          save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
          log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
          save_dir(str, optional): Directory to save the model. Defaults to 'output'.
          learning_rate(float, optional): Learning rate for training. Defaults to .00001.
          warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
          warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
          lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
          lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
          metric({'VOC', 'COCO', None}, optional):
              Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
          use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
          early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
          early_stop_patience(int, optional): Early stop patience. Defaults to 5.
          use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
          resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware
              training from. If None, no training checkpoint will be resumed. Defaults to None.
          quant_config(dict or None, optional): Quantization configuration, passed to `_prepare_qat`.
              Defaults to None.
      """
      self._prepare_qat(quant_config)
      train_kwargs = dict(
          num_epochs=num_epochs,
          train_dataset=train_dataset,
          train_batch_size=train_batch_size,
          eval_dataset=eval_dataset,
          optimizer=optimizer,
          save_interval_epochs=save_interval_epochs,
          log_interval_steps=log_interval_steps,
          save_dir=save_dir,
          # Pretrained weights are never loaded on this path.
          pretrain_weights=None,
          learning_rate=learning_rate,
          warmup_steps=warmup_steps,
          warmup_start_lr=warmup_start_lr,
          lr_decay_epochs=lr_decay_epochs,
          lr_decay_gamma=lr_decay_gamma,
          metric=metric,
          use_ema=use_ema,
          early_stop=early_stop,
          early_stop_patience=early_stop_patience,
          use_vdl=use_vdl,
          resume_checkpoint=resume_checkpoint)
      self.train(**train_kwargs)
  341. def evaluate(self,
  342. eval_dataset,
  343. batch_size=1,
  344. metric=None,
  345. return_details=False):
  346. """
  347. Evaluate the model.
  348. Args:
  349. eval_dataset(paddlex.dataset): Evaluation dataset.
  350. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  351. metric({'VOC', 'COCO', None}, optional):
  352. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  353. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  354. Returns:
  355. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  356. """
  357. if metric is None:
  358. if not hasattr(self, 'metric'):
  359. if eval_dataset.__class__.__name__ == 'VOCDetection':
  360. self.metric = 'voc'
  361. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  362. self.metric = 'coco'
  363. else:
  364. assert metric.lower() in ['coco', 'voc'], \
  365. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  366. self.metric = metric.lower()
  367. if self.metric == 'voc':
  368. eval_dataset.data_fields = {
  369. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  370. 'difficult'
  371. }
  372. elif self.metric == 'coco':
  373. if self.__class__.__name__ == 'MaskRCNN':
  374. eval_dataset.data_fields = {
  375. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  376. 'gt_poly', 'is_crowd'
  377. }
  378. else:
  379. eval_dataset.data_fields = {
  380. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  381. 'is_crowd'
  382. }
  383. eval_dataset.batch_transforms = self._compose_batch_transform(
  384. eval_dataset.transforms, mode='eval')
  385. arrange_transforms(
  386. model_type=self.model_type,
  387. transforms=eval_dataset.transforms,
  388. mode='eval')
  389. self.net.eval()
  390. nranks = paddle.distributed.get_world_size()
  391. local_rank = paddle.distributed.get_rank()
  392. if nranks > 1:
  393. # Initialize parallel environment if not done.
  394. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  395. ):
  396. paddle.distributed.init_parallel_env()
  397. if batch_size > 1:
  398. logging.warning(
  399. "Detector only supports single card evaluation with batch_size=1 "
  400. "during evaluation, so batch_size is forcibly set to 1.")
  401. batch_size = 1
  402. if nranks < 2 or local_rank == 0:
  403. self.eval_data_loader = self.build_data_loader(
  404. eval_dataset, batch_size=batch_size, mode='eval')
  405. is_bbox_normalized = False
  406. if eval_dataset.batch_transforms is not None:
  407. is_bbox_normalized = any(
  408. isinstance(t, _NormalizeBox)
  409. for t in eval_dataset.batch_transforms.batch_transforms)
  410. if self.metric == 'voc':
  411. eval_metric = VOCMetric(
  412. labels=eval_dataset.labels,
  413. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  414. is_bbox_normalized=is_bbox_normalized,
  415. classwise=False)
  416. else:
  417. eval_metric = COCOMetric(
  418. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  419. classwise=False)
  420. scores = collections.OrderedDict()
  421. logging.info(
  422. "Start to evaluate(total_samples={}, total_steps={})...".
  423. format(eval_dataset.num_samples, eval_dataset.num_samples))
  424. with paddle.no_grad():
  425. for step, data in enumerate(self.eval_data_loader):
  426. outputs = self.run(self.net, data, 'eval')
  427. eval_metric.update(data, outputs)
  428. eval_metric.accumulate()
  429. self.eval_details = eval_metric.details
  430. scores.update(eval_metric.get())
  431. eval_metric.reset()
  432. if return_details:
  433. return scores, self.eval_details
  434. return scores
  435. def predict(self, img_file, transforms=None):
  436. """
  437. Do inference.
  438. Args:
  439. img_file(List[np.ndarray or str], str or np.ndarray):
  440. Image path or decoded image data in a BGR format, which also could constitute a list,
  441. meaning all images to be predicted as a mini-batch.
  442. transforms(paddlex.transforms.Compose or None, optional):
  443. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  444. Returns:
  445. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  446. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  447. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  448. category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
  449. category(str): category name
  450. bbox(list): bounding box in [x, y, w, h] format
  451. score(str): confidence
  452. mask(dict): Only for instance segmentation task. Mask of the object in RLE format
  453. """
  454. if transforms is None and not hasattr(self, 'test_transforms'):
  455. raise Exception("transforms need to be defined, now is None.")
  456. if transforms is None:
  457. transforms = self.test_transforms
  458. if isinstance(img_file, (str, np.ndarray)):
  459. images = [img_file]
  460. else:
  461. images = img_file
  462. batch_samples = self._preprocess(images, transforms)
  463. self.net.eval()
  464. outputs = self.run(self.net, batch_samples, 'test')
  465. prediction = self._postprocess(outputs)
  466. if isinstance(img_file, (str, np.ndarray)):
  467. prediction = prediction[0]
  468. return prediction
  def _preprocess(self, images, transforms, to_tensor=True):
      """Apply per-sample and batch transforms to a list of images.

      Args:
          images(list): Images (paths or arrays) to transform; each is
              wrapped as ``{'image': im}`` before being passed to
              ``transforms``.
          transforms: Callable applied to every sample; also used to build
              the batch transform via ``_compose_batch_transform``.
          to_tensor(bool, optional): Whether to convert every value of the
              batched result with ``paddle.to_tensor``. Defaults to True.

      Returns:
          The output of the batch transform: a dict, or an iterable of
          dicts, with values converted in place when ``to_tensor`` is True.
      """
      arrange_transforms(
          model_type=self.model_type, transforms=transforms, mode='test')
      batch_samples = [transforms({'image': im}) for im in images]
      batch_transforms = self._compose_batch_transform(transforms, 'test')
      batch_samples = batch_transforms(batch_samples)
      if not to_tensor:
          return batch_samples

      # Treat a single dict like a one-element collection of samples.
      if isinstance(batch_samples, dict):
          samples = [batch_samples]
      else:
          samples = batch_samples
      for sample in samples:
          for field in sample:
              sample[field] = paddle.to_tensor(sample[field])
      return batch_samples
  487. def _postprocess(self, batch_pred):
  488. infer_result = {}
  489. if 'bbox' in batch_pred:
  490. bboxes = batch_pred['bbox']
  491. bbox_nums = batch_pred['bbox_num']
  492. det_res = []
  493. k = 0
  494. for i in range(len(bbox_nums)):
  495. det_nums = bbox_nums[i]
  496. for j in range(det_nums):
  497. dt = bboxes[k]
  498. k = k + 1
  499. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  500. if int(num_id) < 0:
  501. continue
  502. category = self.labels[int(num_id)]
  503. w = xmax - xmin
  504. h = ymax - ymin
  505. bbox = [xmin, ymin, w, h]
  506. dt_res = {
  507. 'category_id': int(num_id),
  508. 'category': category,
  509. 'bbox': bbox,
  510. 'score': score
  511. }
  512. det_res.append(dt_res)
  513. infer_result['bbox'] = det_res
  514. if 'mask' in batch_pred:
  515. masks = batch_pred['mask']
  516. bboxes = batch_pred['bbox']
  517. mask_nums = batch_pred['bbox_num']
  518. seg_res = []
  519. k = 0
  520. for i in range(len(mask_nums)):
  521. det_nums = mask_nums[i]
  522. for j in range(det_nums):
  523. mask = masks[k].astype(np.uint8)
  524. score = float(bboxes[k][1])
  525. label = int(bboxes[k][0])
  526. k = k + 1
  527. if label == -1:
  528. continue
  529. category = self.labels[int(label)]
  530. import pycocotools.mask as mask_util
  531. rle = mask_util.encode(
  532. np.array(
  533. mask[:, :, None], order="F", dtype="uint8"))[0]
  534. if six.PY3:
  535. if 'counts' in rle:
  536. rle['counts'] = rle['counts'].decode("utf8")
  537. sg_res = {
  538. 'category_id': int(label),
  539. 'category': category,
  540. 'mask': rle,
  541. 'score': score
  542. }
  543. seg_res.append(sg_res)
  544. infer_result['mask'] = seg_res
  545. bbox_num = batch_pred['bbox_num']
  546. results = []
  547. start = 0
  548. for num in bbox_num:
  549. end = start + num
  550. curr_res = infer_result['bbox'][start:end]
  551. if 'mask' in infer_result:
  552. mask_res = infer_result['mask'][start:end]
  553. for box, mask in zip(curr_res, mask_res):
  554. box.update(mask)
  555. results.append(curr_res)
  556. start = end
  557. return results
  558. class YOLOv3(BaseDetector):
  559. def __init__(self,
  560. num_classes=80,
  561. backbone='MobileNetV1',
  562. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  563. [59, 119], [116, 90], [156, 198], [373, 326]],
  564. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  565. ignore_threshold=0.7,
  566. nms_score_threshold=0.01,
  567. nms_topk=1000,
  568. nms_keep_topk=100,
  569. nms_iou_threshold=0.45,
  570. label_smooth=False,
  571. **params):
  572. self.init_params = locals()
  573. if backbone not in [
  574. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  575. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  576. ]:
  577. raise ValueError(
  578. "backbone: {} is not supported. Please choose one of "
  579. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  580. format(backbone))
  581. self.backbone_name = backbone
  582. if params.get('with_net', True):
  583. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  584. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  585. norm_type = 'sync_bn'
  586. else:
  587. norm_type = 'bn'
  588. if 'MobileNetV1' in backbone:
  589. norm_type = 'bn'
  590. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  591. elif 'MobileNetV3' in backbone:
  592. backbone = self._get_backbone(
  593. 'MobileNetV3',
  594. norm_type=norm_type,
  595. feature_maps=[7, 13, 16])
  596. elif backbone == 'ResNet50_vd_dcn':
  597. backbone = self._get_backbone(
  598. 'ResNet',
  599. norm_type=norm_type,
  600. variant='d',
  601. return_idx=[1, 2, 3],
  602. dcn_v2_stages=[3],
  603. freeze_at=-1,
  604. freeze_norm=False)
  605. elif backbone == 'ResNet34':
  606. backbone = self._get_backbone(
  607. 'ResNet',
  608. depth=34,
  609. norm_type=norm_type,
  610. return_idx=[1, 2, 3],
  611. freeze_at=-1,
  612. freeze_norm=False,
  613. norm_decay=0.)
  614. else:
  615. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  616. neck = ppdet.modeling.YOLOv3FPN(
  617. norm_type=norm_type,
  618. in_channels=[i.channels for i in backbone.out_shape])
  619. loss = ppdet.modeling.YOLOv3Loss(
  620. num_classes=num_classes,
  621. ignore_thresh=ignore_threshold,
  622. label_smooth=label_smooth)
  623. yolo_head = ppdet.modeling.YOLOv3Head(
  624. in_channels=[i.channels for i in neck.out_shape],
  625. anchors=anchors,
  626. anchor_masks=anchor_masks,
  627. num_classes=num_classes,
  628. loss=loss)
  629. post_process = ppdet.modeling.BBoxPostProcess(
  630. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  631. nms=ppdet.modeling.MultiClassNMS(
  632. score_threshold=nms_score_threshold,
  633. nms_top_k=nms_topk,
  634. keep_top_k=nms_keep_topk,
  635. nms_threshold=nms_iou_threshold))
  636. params.update({
  637. 'backbone': backbone,
  638. 'neck': neck,
  639. 'yolo_head': yolo_head,
  640. 'post_process': post_process
  641. })
  642. super(YOLOv3, self).__init__(
  643. model_name='YOLOv3', num_classes=num_classes, **params)
  644. self.anchors = anchors
  645. self.anchor_masks = anchor_masks
  646. def _compose_batch_transform(self, transforms, mode='train'):
  647. if mode == 'train':
  648. default_batch_transforms = [
  649. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  650. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  651. _Gt2YoloTarget(
  652. anchor_masks=self.anchor_masks,
  653. anchors=self.anchors,
  654. downsample_ratios=getattr(self, 'downsample_ratios',
  655. [32, 16, 8]),
  656. num_classes=self.num_classes)
  657. ]
  658. else:
  659. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  660. if mode == 'eval' and self.metric == 'voc':
  661. collate_batch = False
  662. else:
  663. collate_batch = True
  664. custom_batch_transforms = []
  665. for i, op in enumerate(transforms.transforms):
  666. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  667. if mode != 'train':
  668. raise Exception(
  669. "{} cannot be present in the {} transforms. ".format(
  670. op.__class__.__name__, mode) +
  671. "Please check the {} transforms.".format(mode))
  672. custom_batch_transforms.insert(0, copy.deepcopy(op))
  673. batch_transforms = BatchCompose(
  674. custom_batch_transforms + default_batch_transforms,
  675. collate_batch=collate_batch)
  676. return batch_transforms
  677. def _fix_transforms_shape(self, image_shape):
  678. if hasattr(self, 'test_transforms'):
  679. if self.test_transforms is not None:
  680. has_resize_op = False
  681. resize_op_idx = -1
  682. normalize_op_idx = len(self.test_transforms.transforms)
  683. for idx, op in enumerate(self.test_transforms.transforms):
  684. name = op.__class__.__name__
  685. if name == 'Resize':
  686. has_resize_op = True
  687. resize_op_idx = idx
  688. if name == 'Normalize':
  689. normalize_op_idx = idx
  690. if not has_resize_op:
  691. self.test_transforms.transforms.insert(
  692. normalize_op_idx,
  693. Resize(
  694. target_size=image_shape, interp='CUBIC'))
  695. else:
  696. self.test_transforms.transforms[
  697. resize_op_idx].target_size = image_shape
  698. class FasterRCNN(BaseDetector):
  699. def __init__(self,
  700. num_classes=80,
  701. backbone='ResNet50',
  702. with_fpn=True,
  703. with_dcn=False,
  704. aspect_ratios=[0.5, 1.0, 2.0],
  705. anchor_sizes=[[32], [64], [128], [256], [512]],
  706. keep_top_k=100,
  707. nms_threshold=0.5,
  708. score_threshold=0.05,
  709. fpn_num_channels=256,
  710. rpn_batch_size_per_im=256,
  711. rpn_fg_fraction=0.5,
  712. test_pre_nms_top_n=None,
  713. test_post_nms_top_n=1000,
  714. **params):
  715. self.init_params = locals()
  716. if backbone not in [
  717. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  718. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  719. ]:
  720. raise ValueError(
  721. "backbone: {} is not supported. Please choose one of "
  722. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  723. "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
  724. self.backbone_name = backbone
  725. if params.get('with_net', True):
  726. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  727. if backbone == 'HRNet_W18':
  728. if not with_fpn:
  729. logging.warning(
  730. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  731. format(backbone))
  732. with_fpn = True
  733. if with_dcn:
  734. logging.warning(
  735. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  736. format(backbone))
  737. backbone = self._get_backbone(
  738. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  739. elif backbone == 'ResNet50_vd_ssld':
  740. if not with_fpn:
  741. logging.warning(
  742. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  743. format(backbone))
  744. with_fpn = True
  745. backbone = self._get_backbone(
  746. 'ResNet',
  747. variant='d',
  748. norm_type='bn',
  749. freeze_at=0,
  750. return_idx=[0, 1, 2, 3],
  751. num_stages=4,
  752. lr_mult_list=[0.05, 0.05, 0.1, 0.15],
  753. dcn_v2_stages=dcn_v2_stages)
  754. elif 'ResNet50' in backbone:
  755. if with_fpn:
  756. backbone = self._get_backbone(
  757. 'ResNet',
  758. variant='d' if '_vd' in backbone else 'b',
  759. norm_type='bn',
  760. freeze_at=0,
  761. return_idx=[0, 1, 2, 3],
  762. num_stages=4,
  763. dcn_v2_stages=dcn_v2_stages)
  764. else:
  765. if with_dcn:
  766. logging.warning(
  767. "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  768. format(backbone))
  769. backbone = self._get_backbone(
  770. 'ResNet',
  771. variant='d' if '_vd' in backbone else 'b',
  772. norm_type='bn',
  773. freeze_at=0,
  774. return_idx=[2],
  775. num_stages=3)
  776. elif 'ResNet34' in backbone:
  777. if not with_fpn:
  778. logging.warning(
  779. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  780. format(backbone))
  781. with_fpn = True
  782. backbone = self._get_backbone(
  783. 'ResNet',
  784. depth=34,
  785. variant='d' if 'vd' in backbone else 'b',
  786. norm_type='bn',
  787. freeze_at=0,
  788. return_idx=[0, 1, 2, 3],
  789. num_stages=4,
  790. dcn_v2_stages=dcn_v2_stages)
  791. else:
  792. if not with_fpn:
  793. logging.warning(
  794. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  795. format(backbone))
  796. with_fpn = True
  797. backbone = self._get_backbone(
  798. 'ResNet',
  799. depth=101,
  800. variant='d' if 'vd' in backbone else 'b',
  801. norm_type='bn',
  802. freeze_at=0,
  803. return_idx=[0, 1, 2, 3],
  804. num_stages=4,
  805. dcn_v2_stages=dcn_v2_stages)
  806. rpn_in_channel = backbone.out_shape[0].channels
  807. if with_fpn:
  808. self.backbone_name = self.backbone_name + '_fpn'
  809. if 'HRNet' in self.backbone_name:
  810. neck = ppdet.modeling.HRFPN(
  811. in_channels=[i.channels for i in backbone.out_shape],
  812. out_channel=fpn_num_channels,
  813. spatial_scales=[
  814. 1.0 / i.stride for i in backbone.out_shape
  815. ],
  816. share_conv=False)
  817. else:
  818. neck = ppdet.modeling.FPN(
  819. in_channels=[i.channels for i in backbone.out_shape],
  820. out_channel=fpn_num_channels,
  821. spatial_scales=[
  822. 1.0 / i.stride for i in backbone.out_shape
  823. ])
  824. rpn_in_channel = neck.out_shape[0].channels
  825. anchor_generator_cfg = {
  826. 'aspect_ratios': aspect_ratios,
  827. 'anchor_sizes': anchor_sizes,
  828. 'strides': [4, 8, 16, 32, 64]
  829. }
  830. train_proposal_cfg = {
  831. 'min_size': 0.0,
  832. 'nms_thresh': .7,
  833. 'pre_nms_top_n': 2000,
  834. 'post_nms_top_n': 1000,
  835. 'topk_after_collect': True
  836. }
  837. test_proposal_cfg = {
  838. 'min_size': 0.0,
  839. 'nms_thresh': .7,
  840. 'pre_nms_top_n': 1000
  841. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  842. 'post_nms_top_n': test_post_nms_top_n
  843. }
  844. head = ppdet.modeling.TwoFCHead(
  845. in_channel=neck.out_shape[0].channels, out_channel=1024)
  846. roi_extractor_cfg = {
  847. 'resolution': 7,
  848. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  849. 'sampling_ratio': 0,
  850. 'aligned': True
  851. }
  852. with_pool = False
  853. else:
  854. neck = None
  855. anchor_generator_cfg = {
  856. 'aspect_ratios': aspect_ratios,
  857. 'anchor_sizes': anchor_sizes,
  858. 'strides': [16]
  859. }
  860. train_proposal_cfg = {
  861. 'min_size': 0.0,
  862. 'nms_thresh': .7,
  863. 'pre_nms_top_n': 12000,
  864. 'post_nms_top_n': 2000,
  865. 'topk_after_collect': False
  866. }
  867. test_proposal_cfg = {
  868. 'min_size': 0.0,
  869. 'nms_thresh': .7,
  870. 'pre_nms_top_n': 6000
  871. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  872. 'post_nms_top_n': test_post_nms_top_n
  873. }
  874. head = ppdet.modeling.Res5Head()
  875. roi_extractor_cfg = {
  876. 'resolution': 14,
  877. 'spatial_scale':
  878. [1. / i.stride for i in backbone.out_shape],
  879. 'sampling_ratio': 0,
  880. 'aligned': True
  881. }
  882. with_pool = True
  883. rpn_target_assign_cfg = {
  884. 'batch_size_per_im': rpn_batch_size_per_im,
  885. 'fg_fraction': rpn_fg_fraction,
  886. 'negative_overlap': .3,
  887. 'positive_overlap': .7,
  888. 'use_random': True
  889. }
  890. rpn_head = ppdet.modeling.RPNHead(
  891. anchor_generator=anchor_generator_cfg,
  892. rpn_target_assign=rpn_target_assign_cfg,
  893. train_proposal=train_proposal_cfg,
  894. test_proposal=test_proposal_cfg,
  895. in_channel=rpn_in_channel)
  896. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  897. bbox_head = ppdet.modeling.BBoxHead(
  898. head=head,
  899. in_channel=head.out_shape[0].channels,
  900. roi_extractor=roi_extractor_cfg,
  901. with_pool=with_pool,
  902. bbox_assigner=bbox_assigner,
  903. num_classes=num_classes)
  904. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  905. num_classes=num_classes,
  906. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  907. nms=ppdet.modeling.MultiClassNMS(
  908. score_threshold=score_threshold,
  909. keep_top_k=keep_top_k,
  910. nms_threshold=nms_threshold))
  911. params.update({
  912. 'backbone': backbone,
  913. 'neck': neck,
  914. 'rpn_head': rpn_head,
  915. 'bbox_head': bbox_head,
  916. 'bbox_post_process': bbox_post_process
  917. })
  918. else:
  919. if backbone not in ['ResNet50', 'ResNet50_vd']:
  920. with_fpn = True
  921. self.with_fpn = with_fpn
  922. super(FasterRCNN, self).__init__(
  923. model_name='FasterRCNN', num_classes=num_classes, **params)
  924. def run(self, net, inputs, mode):
  925. if mode in ['train', 'eval']:
  926. outputs = net(inputs)
  927. else:
  928. outputs = []
  929. for sample in inputs:
  930. net_out = net(sample)
  931. for key in net_out:
  932. net_out[key] = net_out[key].numpy()
  933. outputs.append(net_out)
  934. return outputs
  935. def _compose_batch_transform(self, transforms, mode='train'):
  936. if mode == 'train':
  937. default_batch_transforms = [
  938. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  939. ]
  940. collate_batch = False
  941. else:
  942. default_batch_transforms = [
  943. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  944. ]
  945. collate_batch = True
  946. custom_batch_transforms = []
  947. for i, op in enumerate(transforms.transforms):
  948. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  949. if mode != 'train':
  950. raise Exception(
  951. "{} cannot be present in the {} transforms. ".format(
  952. op.__class__.__name__, mode) +
  953. "Please check the {} transforms.".format(mode))
  954. custom_batch_transforms.insert(0, copy.deepcopy(op))
  955. batch_transforms = BatchCompose(
  956. custom_batch_transforms + default_batch_transforms,
  957. collate_batch=collate_batch,
  958. return_list=mode == 'test')
  959. return batch_transforms
  960. def _fix_transforms_shape(self, image_shape):
  961. if hasattr(self, 'test_transforms'):
  962. if self.test_transforms is not None:
  963. has_resize_op = False
  964. resize_op_idx = -1
  965. normalize_op_idx = len(self.test_transforms.transforms)
  966. for idx, op in enumerate(self.test_transforms.transforms):
  967. name = op.__class__.__name__
  968. if name == 'ResizeByShort':
  969. has_resize_op = True
  970. resize_op_idx = idx
  971. if name == 'Normalize':
  972. normalize_op_idx = idx
  973. if not has_resize_op:
  974. self.test_transforms.transforms.insert(
  975. normalize_op_idx,
  976. Resize(
  977. target_size=image_shape,
  978. keep_ratio=True,
  979. interp='CUBIC'))
  980. else:
  981. self.test_transforms.transforms[resize_op_idx] = Resize(
  982. target_size=image_shape,
  983. keep_ratio=True,
  984. interp='CUBIC')
  985. self.test_transforms.transforms.append(
  986. Padding(im_padding_value=[0., 0., 0.]))
  987. def _get_test_inputs(self, image_shape):
  988. if image_shape is not None:
  989. image_shape = self._check_image_shape(image_shape)
  990. self._fix_transforms_shape(image_shape[-2:])
  991. else:
  992. image_shape = [None, 3, -1, -1]
  993. if self.with_fpn:
  994. self.test_transforms.transforms.append(
  995. Padding(im_padding_value=[0., 0., 0.]))
  996. self.fixed_input_shape = image_shape
  997. return self._define_input_spec(image_shape)
  998. def _postprocess(self, batch_pred):
  999. prediction = [
  1000. super(FasterRCNN, self)._postprocess(pred)[0]
  1001. for pred in batch_pred
  1002. ]
  1003. return prediction
  1004. class PPYOLO(YOLOv3):
  1005. def __init__(self,
  1006. num_classes=80,
  1007. backbone='ResNet50_vd_dcn',
  1008. anchors=None,
  1009. anchor_masks=None,
  1010. use_coord_conv=True,
  1011. use_iou_aware=True,
  1012. use_spp=True,
  1013. use_drop_block=True,
  1014. scale_x_y=1.05,
  1015. ignore_threshold=0.7,
  1016. label_smooth=False,
  1017. use_iou_loss=True,
  1018. use_matrix_nms=True,
  1019. nms_score_threshold=0.01,
  1020. nms_topk=-1,
  1021. nms_keep_topk=100,
  1022. nms_iou_threshold=0.45,
  1023. **params):
  1024. self.init_params = locals()
  1025. if backbone not in [
  1026. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  1027. 'MobileNetV3_small'
  1028. ]:
  1029. raise ValueError(
  1030. "backbone: {} is not supported. Please choose one of "
  1031. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  1032. format(backbone))
  1033. self.backbone_name = backbone
  1034. if params.get('with_net', True):
  1035. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1036. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1037. norm_type = 'sync_bn'
  1038. else:
  1039. norm_type = 'bn'
  1040. if anchors is None and anchor_masks is None:
  1041. if 'MobileNetV3' in backbone:
  1042. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  1043. [120, 195], [254, 235]]
  1044. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1045. elif backbone == 'ResNet50_vd_dcn':
  1046. anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
  1047. [62, 45], [59, 119], [116, 90], [156, 198],
  1048. [373, 326]]
  1049. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  1050. else:
  1051. anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
  1052. [135, 169], [344, 319]]
  1053. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1054. elif anchors is None or anchor_masks is None:
  1055. raise ValueError(
  1056. "Please define both anchors and anchor_masks.")
  1057. if backbone == 'ResNet50_vd_dcn':
  1058. backbone = self._get_backbone(
  1059. 'ResNet',
  1060. variant='d',
  1061. norm_type=norm_type,
  1062. return_idx=[1, 2, 3],
  1063. dcn_v2_stages=[3],
  1064. freeze_at=-1,
  1065. freeze_norm=False,
  1066. norm_decay=0.)
  1067. downsample_ratios = [32, 16, 8]
  1068. elif backbone == 'ResNet18_vd':
  1069. backbone = self._get_backbone(
  1070. 'ResNet',
  1071. depth=18,
  1072. variant='d',
  1073. norm_type=norm_type,
  1074. return_idx=[2, 3],
  1075. freeze_at=-1,
  1076. freeze_norm=False,
  1077. norm_decay=0.)
  1078. downsample_ratios = [32, 16]
  1079. elif backbone == 'MobileNetV3_large':
  1080. backbone = self._get_backbone(
  1081. 'MobileNetV3',
  1082. model_name='large',
  1083. norm_type=norm_type,
  1084. scale=1,
  1085. with_extra_blocks=False,
  1086. extra_block_filters=[],
  1087. feature_maps=[13, 16])
  1088. downsample_ratios = [32, 16]
  1089. elif backbone == 'MobileNetV3_small':
  1090. backbone = self._get_backbone(
  1091. 'MobileNetV3',
  1092. model_name='small',
  1093. norm_type=norm_type,
  1094. scale=1,
  1095. with_extra_blocks=False,
  1096. extra_block_filters=[],
  1097. feature_maps=[9, 12])
  1098. downsample_ratios = [32, 16]
  1099. neck = ppdet.modeling.PPYOLOFPN(
  1100. norm_type=norm_type,
  1101. in_channels=[i.channels for i in backbone.out_shape],
  1102. coord_conv=use_coord_conv,
  1103. drop_block=use_drop_block,
  1104. spp=use_spp,
  1105. conv_block_num=0
  1106. if ('MobileNetV3' in self.backbone_name or
  1107. self.backbone_name == 'ResNet18_vd') else 2)
  1108. loss = ppdet.modeling.YOLOv3Loss(
  1109. num_classes=num_classes,
  1110. ignore_thresh=ignore_threshold,
  1111. downsample=downsample_ratios,
  1112. label_smooth=label_smooth,
  1113. scale_x_y=scale_x_y,
  1114. iou_loss=ppdet.modeling.IouLoss(
  1115. loss_weight=2.5, loss_square=True)
  1116. if use_iou_loss else None,
  1117. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1118. if use_iou_aware else None)
  1119. yolo_head = ppdet.modeling.YOLOv3Head(
  1120. in_channels=[i.channels for i in neck.out_shape],
  1121. anchors=anchors,
  1122. anchor_masks=anchor_masks,
  1123. num_classes=num_classes,
  1124. loss=loss,
  1125. iou_aware=use_iou_aware)
  1126. if use_matrix_nms:
  1127. nms = ppdet.modeling.MatrixNMS(
  1128. keep_top_k=nms_keep_topk,
  1129. score_threshold=nms_score_threshold,
  1130. post_threshold=.05
  1131. if 'MobileNetV3' in self.backbone_name else .01,
  1132. nms_top_k=nms_topk,
  1133. background_label=-1)
  1134. else:
  1135. nms = ppdet.modeling.MultiClassNMS(
  1136. score_threshold=nms_score_threshold,
  1137. nms_top_k=nms_topk,
  1138. keep_top_k=nms_keep_topk,
  1139. nms_threshold=nms_iou_threshold)
  1140. post_process = ppdet.modeling.BBoxPostProcess(
  1141. decode=ppdet.modeling.YOLOBox(
  1142. num_classes=num_classes,
  1143. conf_thresh=.005
  1144. if 'MobileNetV3' in self.backbone_name else .01,
  1145. scale_x_y=scale_x_y),
  1146. nms=nms)
  1147. params.update({
  1148. 'backbone': backbone,
  1149. 'neck': neck,
  1150. 'yolo_head': yolo_head,
  1151. 'post_process': post_process
  1152. })
  1153. super(YOLOv3, self).__init__(
  1154. model_name='YOLOv3', num_classes=num_classes, **params)
  1155. self.anchors = anchors
  1156. self.anchor_masks = anchor_masks
  1157. self.downsample_ratios = downsample_ratios
  1158. self.model_name = 'PPYOLO'
  1159. def _get_test_inputs(self, image_shape):
  1160. if image_shape is not None:
  1161. image_shape = self._check_image_shape(image_shape)
  1162. self._fix_transforms_shape(image_shape[-2:])
  1163. else:
  1164. image_shape = [None, 3, 608, 608]
  1165. if hasattr(self, 'test_transforms'):
  1166. if self.test_transforms is not None:
  1167. for idx, op in enumerate(self.test_transforms.transforms):
  1168. name = op.__class__.__name__
  1169. if name == 'Resize':
  1170. image_shape = [None, 3] + list(
  1171. self.test_transforms.transforms[
  1172. idx].target_size)
  1173. logging.warning(
  1174. '[Important!!!] When exporting inference model for {},'.format(
  1175. self.__class__.__name__) +
  1176. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1177. format(image_shape) +
  1178. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1179. format(image_shape[1:]) + 'should be specified manually.')
  1180. self.fixed_input_shape = image_shape
  1181. return self._define_input_spec(image_shape)
  1182. class PPYOLOTiny(YOLOv3):
  1183. def __init__(self,
  1184. num_classes=80,
  1185. backbone='MobileNetV3',
  1186. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1187. [60, 170], [220, 125], [128, 222], [264, 266]],
  1188. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1189. use_iou_aware=False,
  1190. use_spp=True,
  1191. use_drop_block=True,
  1192. scale_x_y=1.05,
  1193. ignore_threshold=0.5,
  1194. label_smooth=False,
  1195. use_iou_loss=True,
  1196. use_matrix_nms=False,
  1197. nms_score_threshold=0.005,
  1198. nms_topk=1000,
  1199. nms_keep_topk=100,
  1200. nms_iou_threshold=0.45,
  1201. **params):
  1202. self.init_params = locals()
  1203. if backbone != 'MobileNetV3':
  1204. logging.warning(
  1205. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1206. "Backbone is forcibly set to MobileNetV3.")
  1207. self.backbone_name = 'MobileNetV3'
  1208. if params.get('with_net', True):
  1209. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1210. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1211. norm_type = 'sync_bn'
  1212. else:
  1213. norm_type = 'bn'
  1214. backbone = self._get_backbone(
  1215. 'MobileNetV3',
  1216. model_name='large',
  1217. norm_type=norm_type,
  1218. scale=.5,
  1219. with_extra_blocks=False,
  1220. extra_block_filters=[],
  1221. feature_maps=[7, 13, 16])
  1222. downsample_ratios = [32, 16, 8]
  1223. neck = ppdet.modeling.PPYOLOTinyFPN(
  1224. detection_block_channels=[160, 128, 96],
  1225. in_channels=[i.channels for i in backbone.out_shape],
  1226. spp=use_spp,
  1227. drop_block=use_drop_block)
  1228. loss = ppdet.modeling.YOLOv3Loss(
  1229. num_classes=num_classes,
  1230. ignore_thresh=ignore_threshold,
  1231. downsample=downsample_ratios,
  1232. label_smooth=label_smooth,
  1233. scale_x_y=scale_x_y,
  1234. iou_loss=ppdet.modeling.IouLoss(
  1235. loss_weight=2.5, loss_square=True)
  1236. if use_iou_loss else None,
  1237. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1238. if use_iou_aware else None)
  1239. yolo_head = ppdet.modeling.YOLOv3Head(
  1240. in_channels=[i.channels for i in neck.out_shape],
  1241. anchors=anchors,
  1242. anchor_masks=anchor_masks,
  1243. num_classes=num_classes,
  1244. loss=loss,
  1245. iou_aware=use_iou_aware)
  1246. if use_matrix_nms:
  1247. nms = ppdet.modeling.MatrixNMS(
  1248. keep_top_k=nms_keep_topk,
  1249. score_threshold=nms_score_threshold,
  1250. post_threshold=.05,
  1251. nms_top_k=nms_topk,
  1252. background_label=-1)
  1253. else:
  1254. nms = ppdet.modeling.MultiClassNMS(
  1255. score_threshold=nms_score_threshold,
  1256. nms_top_k=nms_topk,
  1257. keep_top_k=nms_keep_topk,
  1258. nms_threshold=nms_iou_threshold)
  1259. post_process = ppdet.modeling.BBoxPostProcess(
  1260. decode=ppdet.modeling.YOLOBox(
  1261. num_classes=num_classes,
  1262. conf_thresh=.005,
  1263. downsample_ratio=32,
  1264. clip_bbox=True,
  1265. scale_x_y=scale_x_y),
  1266. nms=nms)
  1267. params.update({
  1268. 'backbone': backbone,
  1269. 'neck': neck,
  1270. 'yolo_head': yolo_head,
  1271. 'post_process': post_process
  1272. })
  1273. super(YOLOv3, self).__init__(
  1274. model_name='YOLOv3', num_classes=num_classes, **params)
  1275. self.anchors = anchors
  1276. self.anchor_masks = anchor_masks
  1277. self.downsample_ratios = downsample_ratios
  1278. self.model_name = 'PPYOLOTiny'
  1279. def _get_test_inputs(self, image_shape):
  1280. if image_shape is not None:
  1281. image_shape = self._check_image_shape(image_shape)
  1282. self._fix_transforms_shape(image_shape[-2:])
  1283. else:
  1284. image_shape = [None, 3, 320, 320]
  1285. if hasattr(self, 'test_transforms'):
  1286. if self.test_transforms is not None:
  1287. for idx, op in enumerate(self.test_transforms.transforms):
  1288. name = op.__class__.__name__
  1289. if name == 'Resize':
  1290. image_shape = [None, 3] + list(
  1291. self.test_transforms.transforms[
  1292. idx].target_size)
  1293. logging.warning(
  1294. '[Important!!!] When exporting inference model for {},'.format(
  1295. self.__class__.__name__) +
  1296. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1297. format(image_shape) +
  1298. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1299. format(image_shape[1:]) + 'should be specified manually.')
  1300. self.fixed_input_shape = image_shape
  1301. return self._define_input_spec(image_shape)
  1302. class PPYOLOv2(YOLOv3):
  1303. def __init__(self,
  1304. num_classes=80,
  1305. backbone='ResNet50_vd_dcn',
  1306. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1307. [59, 119], [116, 90], [156, 198], [373, 326]],
  1308. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1309. use_iou_aware=True,
  1310. use_spp=True,
  1311. use_drop_block=True,
  1312. scale_x_y=1.05,
  1313. ignore_threshold=0.7,
  1314. label_smooth=False,
  1315. use_iou_loss=True,
  1316. use_matrix_nms=True,
  1317. nms_score_threshold=0.01,
  1318. nms_topk=-1,
  1319. nms_keep_topk=100,
  1320. nms_iou_threshold=0.45,
  1321. **params):
  1322. self.init_params = locals()
  1323. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1324. raise ValueError(
  1325. "backbone: {} is not supported. Please choose one of "
  1326. "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
  1327. self.backbone_name = backbone
  1328. if params.get('with_net', True):
  1329. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1330. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1331. norm_type = 'sync_bn'
  1332. else:
  1333. norm_type = 'bn'
  1334. if backbone == 'ResNet50_vd_dcn':
  1335. backbone = self._get_backbone(
  1336. 'ResNet',
  1337. variant='d',
  1338. norm_type=norm_type,
  1339. return_idx=[1, 2, 3],
  1340. dcn_v2_stages=[3],
  1341. freeze_at=-1,
  1342. freeze_norm=False,
  1343. norm_decay=0.)
  1344. downsample_ratios = [32, 16, 8]
  1345. elif backbone == 'ResNet101_vd_dcn':
  1346. backbone = self._get_backbone(
  1347. 'ResNet',
  1348. depth=101,
  1349. variant='d',
  1350. norm_type=norm_type,
  1351. return_idx=[1, 2, 3],
  1352. dcn_v2_stages=[3],
  1353. freeze_at=-1,
  1354. freeze_norm=False,
  1355. norm_decay=0.)
  1356. downsample_ratios = [32, 16, 8]
  1357. neck = ppdet.modeling.PPYOLOPAN(
  1358. norm_type=norm_type,
  1359. in_channels=[i.channels for i in backbone.out_shape],
  1360. drop_block=use_drop_block,
  1361. block_size=3,
  1362. keep_prob=.9,
  1363. spp=use_spp)
  1364. loss = ppdet.modeling.YOLOv3Loss(
  1365. num_classes=num_classes,
  1366. ignore_thresh=ignore_threshold,
  1367. downsample=downsample_ratios,
  1368. label_smooth=label_smooth,
  1369. scale_x_y=scale_x_y,
  1370. iou_loss=ppdet.modeling.IouLoss(
  1371. loss_weight=2.5, loss_square=True)
  1372. if use_iou_loss else None,
  1373. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1374. if use_iou_aware else None)
  1375. yolo_head = ppdet.modeling.YOLOv3Head(
  1376. in_channels=[i.channels for i in neck.out_shape],
  1377. anchors=anchors,
  1378. anchor_masks=anchor_masks,
  1379. num_classes=num_classes,
  1380. loss=loss,
  1381. iou_aware=use_iou_aware,
  1382. iou_aware_factor=.5)
  1383. if use_matrix_nms:
  1384. nms = ppdet.modeling.MatrixNMS(
  1385. keep_top_k=nms_keep_topk,
  1386. score_threshold=nms_score_threshold,
  1387. post_threshold=.01,
  1388. nms_top_k=nms_topk,
  1389. background_label=-1)
  1390. else:
  1391. nms = ppdet.modeling.MultiClassNMS(
  1392. score_threshold=nms_score_threshold,
  1393. nms_top_k=nms_topk,
  1394. keep_top_k=nms_keep_topk,
  1395. nms_threshold=nms_iou_threshold)
  1396. post_process = ppdet.modeling.BBoxPostProcess(
  1397. decode=ppdet.modeling.YOLOBox(
  1398. num_classes=num_classes,
  1399. conf_thresh=.01,
  1400. downsample_ratio=32,
  1401. clip_bbox=True,
  1402. scale_x_y=scale_x_y),
  1403. nms=nms)
  1404. params.update({
  1405. 'backbone': backbone,
  1406. 'neck': neck,
  1407. 'yolo_head': yolo_head,
  1408. 'post_process': post_process
  1409. })
  1410. super(YOLOv3, self).__init__(
  1411. model_name='YOLOv3', num_classes=num_classes, **params)
  1412. self.anchors = anchors
  1413. self.anchor_masks = anchor_masks
  1414. self.downsample_ratios = downsample_ratios
  1415. self.model_name = 'PPYOLOv2'
  1416. def _get_test_inputs(self, image_shape):
  1417. if image_shape is not None:
  1418. image_shape = self._check_image_shape(image_shape)
  1419. self._fix_transforms_shape(image_shape[-2:])
  1420. else:
  1421. image_shape = [None, 3, 640, 640]
  1422. if hasattr(self, 'test_transforms'):
  1423. if self.test_transforms is not None:
  1424. for idx, op in enumerate(self.test_transforms.transforms):
  1425. name = op.__class__.__name__
  1426. if name == 'Resize':
  1427. image_shape = [None, 3] + list(
  1428. self.test_transforms.transforms[
  1429. idx].target_size)
  1430. logging.warning(
  1431. '[Important!!!] When exporting inference model for {},'.format(
  1432. self.__class__.__name__) +
  1433. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1434. format(image_shape) +
  1435. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1436. format(image_shape[1:]) + 'should be specified manually.')
  1437. self.fixed_input_shape = image_shape
  1438. return self._define_input_spec(image_shape)
  1439. class MaskRCNN(FasterRCNN):
  1440. def __init__(self,
  1441. num_classes=80,
  1442. backbone='ResNet50_vd',
  1443. with_fpn=True,
  1444. with_dcn=False,
  1445. aspect_ratios=[0.5, 1.0, 2.0],
  1446. anchor_sizes=[[32], [64], [128], [256], [512]],
  1447. keep_top_k=100,
  1448. nms_threshold=0.5,
  1449. score_threshold=0.05,
  1450. fpn_num_channels=256,
  1451. rpn_batch_size_per_im=256,
  1452. rpn_fg_fraction=0.5,
  1453. test_pre_nms_top_n=None,
  1454. test_post_nms_top_n=1000,
  1455. **params):
  1456. self.init_params = locals()
  1457. if backbone not in [
  1458. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1459. 'ResNet101_vd'
  1460. ]:
  1461. raise ValueError(
  1462. "backbone: {} is not supported. Please choose one of "
  1463. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1464. format(backbone))
  1465. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1466. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  1467. if params.get('with_net', True):
  1468. if backbone == 'ResNet50':
  1469. if with_fpn:
  1470. backbone = self._get_backbone(
  1471. 'ResNet',
  1472. norm_type='bn',
  1473. freeze_at=0,
  1474. return_idx=[0, 1, 2, 3],
  1475. num_stages=4,
  1476. dcn_v2_stages=dcn_v2_stages)
  1477. else:
  1478. if with_dcn:
  1479. logging.warning(
  1480. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1481. format(backbone))
  1482. backbone = self._get_backbone(
  1483. 'ResNet',
  1484. norm_type='bn',
  1485. freeze_at=0,
  1486. return_idx=[2],
  1487. num_stages=3)
  1488. elif 'ResNet50_vd' in backbone:
  1489. if not with_fpn:
  1490. logging.warning(
  1491. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1492. format(backbone))
  1493. with_fpn = True
  1494. backbone = self._get_backbone(
  1495. 'ResNet',
  1496. variant='d',
  1497. norm_type='bn',
  1498. freeze_at=0,
  1499. return_idx=[0, 1, 2, 3],
  1500. num_stages=4,
  1501. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1502. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
  1503. dcn_v2_stages=dcn_v2_stages)
  1504. else:
  1505. if not with_fpn:
  1506. logging.warning(
  1507. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1508. format(backbone))
  1509. with_fpn = True
  1510. backbone = self._get_backbone(
  1511. 'ResNet',
  1512. variant='d' if '_vd' in backbone else 'b',
  1513. depth=101,
  1514. norm_type='bn',
  1515. freeze_at=0,
  1516. return_idx=[0, 1, 2, 3],
  1517. num_stages=4,
  1518. dcn_v2_stages=dcn_v2_stages)
  1519. rpn_in_channel = backbone.out_shape[0].channels
  1520. if with_fpn:
  1521. neck = ppdet.modeling.FPN(
  1522. in_channels=[i.channels for i in backbone.out_shape],
  1523. out_channel=fpn_num_channels,
  1524. spatial_scales=[
  1525. 1.0 / i.stride for i in backbone.out_shape
  1526. ])
  1527. rpn_in_channel = neck.out_shape[0].channels
  1528. anchor_generator_cfg = {
  1529. 'aspect_ratios': aspect_ratios,
  1530. 'anchor_sizes': anchor_sizes,
  1531. 'strides': [4, 8, 16, 32, 64]
  1532. }
  1533. train_proposal_cfg = {
  1534. 'min_size': 0.0,
  1535. 'nms_thresh': .7,
  1536. 'pre_nms_top_n': 2000,
  1537. 'post_nms_top_n': 1000,
  1538. 'topk_after_collect': True
  1539. }
  1540. test_proposal_cfg = {
  1541. 'min_size': 0.0,
  1542. 'nms_thresh': .7,
  1543. 'pre_nms_top_n': 1000
  1544. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1545. 'post_nms_top_n': test_post_nms_top_n
  1546. }
  1547. bb_head = ppdet.modeling.TwoFCHead(
  1548. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1549. bb_roi_extractor_cfg = {
  1550. 'resolution': 7,
  1551. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1552. 'sampling_ratio': 0,
  1553. 'aligned': True
  1554. }
  1555. with_pool = False
  1556. m_head = ppdet.modeling.MaskFeat(
  1557. in_channel=neck.out_shape[0].channels,
  1558. out_channel=256,
  1559. num_convs=4)
  1560. m_roi_extractor_cfg = {
  1561. 'resolution': 14,
  1562. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1563. 'sampling_ratio': 0,
  1564. 'aligned': True
  1565. }
  1566. mask_assigner = MaskAssigner(
  1567. num_classes=num_classes, mask_resolution=28)
  1568. share_bbox_feat = False
  1569. else:
  1570. neck = None
  1571. anchor_generator_cfg = {
  1572. 'aspect_ratios': aspect_ratios,
  1573. 'anchor_sizes': anchor_sizes,
  1574. 'strides': [16]
  1575. }
  1576. train_proposal_cfg = {
  1577. 'min_size': 0.0,
  1578. 'nms_thresh': .7,
  1579. 'pre_nms_top_n': 12000,
  1580. 'post_nms_top_n': 2000,
  1581. 'topk_after_collect': False
  1582. }
  1583. test_proposal_cfg = {
  1584. 'min_size': 0.0,
  1585. 'nms_thresh': .7,
  1586. 'pre_nms_top_n': 6000
  1587. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1588. 'post_nms_top_n': test_post_nms_top_n
  1589. }
  1590. bb_head = ppdet.modeling.Res5Head()
  1591. bb_roi_extractor_cfg = {
  1592. 'resolution': 14,
  1593. 'spatial_scale':
  1594. [1. / i.stride for i in backbone.out_shape],
  1595. 'sampling_ratio': 0,
  1596. 'aligned': True
  1597. }
  1598. with_pool = True
  1599. m_head = ppdet.modeling.MaskFeat(
  1600. in_channel=bb_head.out_shape[0].channels,
  1601. out_channel=256,
  1602. num_convs=0)
  1603. m_roi_extractor_cfg = {
  1604. 'resolution': 14,
  1605. 'spatial_scale':
  1606. [1. / i.stride for i in backbone.out_shape],
  1607. 'sampling_ratio': 0,
  1608. 'aligned': True
  1609. }
  1610. mask_assigner = MaskAssigner(
  1611. num_classes=num_classes, mask_resolution=14)
  1612. share_bbox_feat = True
  1613. rpn_target_assign_cfg = {
  1614. 'batch_size_per_im': rpn_batch_size_per_im,
  1615. 'fg_fraction': rpn_fg_fraction,
  1616. 'negative_overlap': .3,
  1617. 'positive_overlap': .7,
  1618. 'use_random': True
  1619. }
  1620. rpn_head = ppdet.modeling.RPNHead(
  1621. anchor_generator=anchor_generator_cfg,
  1622. rpn_target_assign=rpn_target_assign_cfg,
  1623. train_proposal=train_proposal_cfg,
  1624. test_proposal=test_proposal_cfg,
  1625. in_channel=rpn_in_channel)
  1626. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1627. bbox_head = ppdet.modeling.BBoxHead(
  1628. head=bb_head,
  1629. in_channel=bb_head.out_shape[0].channels,
  1630. roi_extractor=bb_roi_extractor_cfg,
  1631. with_pool=with_pool,
  1632. bbox_assigner=bbox_assigner,
  1633. num_classes=num_classes)
  1634. mask_head = ppdet.modeling.MaskHead(
  1635. head=m_head,
  1636. roi_extractor=m_roi_extractor_cfg,
  1637. mask_assigner=mask_assigner,
  1638. share_bbox_feat=share_bbox_feat,
  1639. num_classes=num_classes)
  1640. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1641. num_classes=num_classes,
  1642. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1643. nms=ppdet.modeling.MultiClassNMS(
  1644. score_threshold=score_threshold,
  1645. keep_top_k=keep_top_k,
  1646. nms_threshold=nms_threshold))
  1647. mask_post_process = ppdet.modeling.MaskPostProcess(
  1648. binary_thresh=.5)
  1649. params.update({
  1650. 'backbone': backbone,
  1651. 'neck': neck,
  1652. 'rpn_head': rpn_head,
  1653. 'bbox_head': bbox_head,
  1654. 'mask_head': mask_head,
  1655. 'bbox_post_process': bbox_post_process,
  1656. 'mask_post_process': mask_post_process
  1657. })
  1658. self.with_fpn = with_fpn
  1659. super(FasterRCNN, self).__init__(
  1660. model_name='MaskRCNN', num_classes=num_classes, **params)
  1661. def _compose_batch_transform(self, transforms, mode='train'):
  1662. if mode == 'train':
  1663. default_batch_transforms = [
  1664. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1665. ]
  1666. collate_batch = False
  1667. else:
  1668. default_batch_transforms = [
  1669. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1670. ]
  1671. collate_batch = True
  1672. custom_batch_transforms = []
  1673. for i, op in enumerate(transforms.transforms):
  1674. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1675. if mode != 'train':
  1676. raise Exception(
  1677. "{} cannot be present in the {} transforms. ".format(
  1678. op.__class__.__name__, mode) +
  1679. "Please check the {} transforms.".format(mode))
  1680. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1681. batch_transforms = BatchCompose(
  1682. custom_batch_transforms + default_batch_transforms,
  1683. collate_batch=collate_batch,
  1684. return_list=mode == 'test')
  1685. return batch_transforms
  1686. def _fix_transforms_shape(self, image_shape):
  1687. if hasattr(self, 'test_transforms'):
  1688. if self.test_transforms is not None:
  1689. has_resize_op = False
  1690. resize_op_idx = -1
  1691. normalize_op_idx = len(self.test_transforms.transforms)
  1692. for idx, op in enumerate(self.test_transforms.transforms):
  1693. name = op.__class__.__name__
  1694. if name == 'ResizeByShort':
  1695. has_resize_op = True
  1696. resize_op_idx = idx
  1697. if name == 'Normalize':
  1698. normalize_op_idx = idx
  1699. if not has_resize_op:
  1700. self.test_transforms.transforms.insert(
  1701. normalize_op_idx,
  1702. Resize(
  1703. target_size=image_shape,
  1704. keep_ratio=True,
  1705. interp='CUBIC'))
  1706. else:
  1707. self.test_transforms.transforms[resize_op_idx] = Resize(
  1708. target_size=image_shape,
  1709. keep_ratio=True,
  1710. interp='CUBIC')
  1711. self.test_transforms.transforms.append(
  1712. Padding(im_padding_value=[0., 0., 0.]))
  1713. def _get_test_inputs(self, image_shape):
  1714. if image_shape is not None:
  1715. image_shape = self._check_image_shape(image_shape)
  1716. self._fix_transforms_shape(image_shape[-2:])
  1717. else:
  1718. image_shape = [None, 3, -1, -1]
  1719. if self.with_fpn:
  1720. self.test_transforms.transforms.append(
  1721. Padding(im_padding_value=[0., 0., 0.]))
  1722. self.fixed_input_shape = image_shape
  1723. return self._define_input_spec(image_shape)