# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import ppdet
  24. from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. super(BaseDetector, self).__init__('detector')
  41. if not hasattr(ppdet.modeling, model_name):
  42. raise Exception("ERROR: There's no model named {}.".format(
  43. model_name))
  44. self.model_name = model_name
  45. self.num_classes = num_classes
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. def build_net(self, **params):
  49. with paddle.utils.unique_name.guard():
  50. net = ppdet.modeling.__dict__[self.model_name](**params)
  51. return net
  52. def _fix_transforms_shape(self, image_shape):
  53. raise NotImplementedError("_fix_transforms_shape: not implemented!")
  54. def _define_input_spec(self, image_shape):
  55. input_spec = [{
  56. "image": InputSpec(
  57. shape=image_shape, name='image', dtype='float32'),
  58. "im_shape": InputSpec(
  59. shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
  60. "scale_factor": InputSpec(
  61. shape=[image_shape[0], 2],
  62. name='scale_factor',
  63. dtype='float32')
  64. }]
  65. return input_spec
  66. def _check_image_shape(self, image_shape):
  67. if len(image_shape) == 2:
  68. image_shape = [None, 3] + image_shape
  69. if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
  70. raise Exception(
  71. "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
  72. format(image_shape[-2:]))
  73. return image_shape
  74. def _get_test_inputs(self, image_shape):
  75. if image_shape is not None:
  76. image_shape = self._check_image_shape(image_shape)
  77. self._fix_transforms_shape(image_shape[-2:])
  78. else:
  79. image_shape = [None, 3, -1, -1]
  80. return self._define_input_spec(image_shape)
  81. def _get_backbone(self, backbone_name, **params):
  82. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  83. return backbone
  84. def run(self, net, inputs, mode):
  85. net_out = net(inputs)
  86. if mode in ['train', 'eval']:
  87. outputs = net_out
  88. else:
  89. for key in ['im_shape', 'scale_factor']:
  90. net_out[key] = inputs[key]
  91. outputs = dict()
  92. for key in net_out:
  93. outputs[key] = net_out[key].numpy()
  94. return outputs
  95. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  96. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  97. num_steps_each_epoch):
  98. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  99. values = [(lr_decay_gamma**i) * learning_rate
  100. for i in range(len(lr_decay_epochs) + 1)]
  101. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  102. boundaries=boundaries, values=values)
  103. if warmup_steps > 0:
  104. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  105. logging.error(
  106. "In function train(), parameters should satisfy: "
  107. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  108. exit=False)
  109. logging.error(
  110. "See this doc for more information: "
  111. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  112. exit=False)
  113. scheduler = paddle.optimizer.lr.LinearWarmup(
  114. learning_rate=scheduler,
  115. warmup_steps=warmup_steps,
  116. start_lr=warmup_start_lr,
  117. end_lr=learning_rate)
  118. optimizer = paddle.optimizer.Momentum(
  119. scheduler,
  120. momentum=.9,
  121. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  122. parameters=parameters)
  123. return optimizer
  124. def train(self,
  125. num_epochs,
  126. train_dataset,
  127. train_batch_size=64,
  128. eval_dataset=None,
  129. optimizer=None,
  130. save_interval_epochs=1,
  131. log_interval_steps=10,
  132. save_dir='output',
  133. pretrain_weights='IMAGENET',
  134. learning_rate=.001,
  135. warmup_steps=0,
  136. warmup_start_lr=0.0,
  137. lr_decay_epochs=(216, 243),
  138. lr_decay_gamma=0.1,
  139. metric=None,
  140. use_ema=False,
  141. early_stop=False,
  142. early_stop_patience=5,
  143. use_vdl=True):
  144. """
  145. Train the model.
  146. Args:
  147. num_epochs(int): The number of epochs.
  148. train_dataset(paddlex.dataset): Training dataset.
  149. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  150. eval_dataset(paddlex.dataset, optional):
  151. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  152. optimizer(paddle.optimizer.Optimizer or None, optional):
  153. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  154. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  155. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  156. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  157. pretrain_weights(str or None, optional):
  158. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  159. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  160. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  161. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  162. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  163. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  164. metric({'VOC', 'COCO', None}, optional):
  165. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  166. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  167. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  168. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  169. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  170. """
  171. if train_dataset.__class__.__name__ == 'VOCDetection':
  172. train_dataset.data_fields = {
  173. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  174. 'difficult'
  175. }
  176. elif train_dataset.__class__.__name__ == 'CocoDetection':
  177. if self.__class__.__name__ == 'MaskRCNN':
  178. train_dataset.data_fields = {
  179. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  180. 'gt_poly', 'is_crowd'
  181. }
  182. else:
  183. train_dataset.data_fields = {
  184. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  185. 'is_crowd'
  186. }
  187. if metric is None:
  188. if eval_dataset.__class__.__name__ == 'VOCDetection':
  189. self.metric = 'voc'
  190. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  191. self.metric = 'coco'
  192. else:
  193. assert metric.lower() in ['coco', 'voc'], \
  194. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  195. self.metric = metric.lower()
  196. self.labels = train_dataset.labels
  197. self.num_max_boxes = train_dataset.num_max_boxes
  198. train_dataset.batch_transforms = self._compose_batch_transform(
  199. train_dataset.transforms, mode='train')
  200. # build optimizer if not defined
  201. if optimizer is None:
  202. num_steps_each_epoch = len(train_dataset) // train_batch_size
  203. self.optimizer = self.default_optimizer(
  204. parameters=self.net.parameters(),
  205. learning_rate=learning_rate,
  206. warmup_steps=warmup_steps,
  207. warmup_start_lr=warmup_start_lr,
  208. lr_decay_epochs=lr_decay_epochs,
  209. lr_decay_gamma=lr_decay_gamma,
  210. num_steps_each_epoch=num_steps_each_epoch)
  211. else:
  212. self.optimizer = optimizer
  213. # initiate weights
  214. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  215. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  216. [self.model_name, self.backbone_name])]:
  217. logging.warning(
  218. "Path of pretrain_weights('{}') does not exist!".format(
  219. pretrain_weights))
  220. pretrain_weights = det_pretrain_weights_dict['_'.join(
  221. [self.model_name, self.backbone_name])][0]
  222. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  223. "If you don't want to use pretrain weights, "
  224. "set pretrain_weights to be None.".format(
  225. pretrain_weights))
  226. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  227. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  228. logging.error(
  229. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  230. exit=True)
  231. pretrained_dir = osp.join(save_dir, 'pretrain')
  232. self.net_initialize(
  233. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  234. if use_ema:
  235. ema = ExponentialMovingAverage(
  236. decay=.9998, model=self.net, use_thres_step=True)
  237. else:
  238. ema = None
  239. # start train loop
  240. self.train_loop(
  241. num_epochs=num_epochs,
  242. train_dataset=train_dataset,
  243. train_batch_size=train_batch_size,
  244. eval_dataset=eval_dataset,
  245. save_interval_epochs=save_interval_epochs,
  246. log_interval_steps=log_interval_steps,
  247. save_dir=save_dir,
  248. ema=ema,
  249. early_stop=early_stop,
  250. early_stop_patience=early_stop_patience,
  251. use_vdl=use_vdl)
  252. def quant_aware_train(self,
  253. num_epochs,
  254. train_dataset,
  255. train_batch_size=64,
  256. eval_dataset=None,
  257. optimizer=None,
  258. save_interval_epochs=1,
  259. log_interval_steps=10,
  260. save_dir='output',
  261. learning_rate=.00001,
  262. warmup_steps=0,
  263. warmup_start_lr=0.0,
  264. lr_decay_epochs=(216, 243),
  265. lr_decay_gamma=0.1,
  266. metric=None,
  267. use_ema=False,
  268. early_stop=False,
  269. early_stop_patience=5,
  270. use_vdl=True,
  271. quant_config=None):
  272. """
  273. Quantization-aware training.
  274. Args:
  275. num_epochs(int): The number of epochs.
  276. train_dataset(paddlex.dataset): Training dataset.
  277. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  278. eval_dataset(paddlex.dataset, optional):
  279. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  280. optimizer(paddle.optimizer.Optimizer or None, optional):
  281. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  282. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  283. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  284. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  285. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  286. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  287. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  288. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  289. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  290. metric({'VOC', 'COCO', None}, optional):
  291. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  292. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  293. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  294. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  295. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  296. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  297. configuration will be used. Defaults to None.
  298. """
  299. self._prepare_qat(quant_config)
  300. self.train(
  301. num_epochs=num_epochs,
  302. train_dataset=train_dataset,
  303. train_batch_size=train_batch_size,
  304. eval_dataset=eval_dataset,
  305. optimizer=optimizer,
  306. save_interval_epochs=save_interval_epochs,
  307. log_interval_steps=log_interval_steps,
  308. save_dir=save_dir,
  309. pretrain_weights=None,
  310. learning_rate=learning_rate,
  311. warmup_steps=warmup_steps,
  312. warmup_start_lr=warmup_start_lr,
  313. lr_decay_epochs=lr_decay_epochs,
  314. lr_decay_gamma=lr_decay_gamma,
  315. metric=metric,
  316. use_ema=use_ema,
  317. early_stop=early_stop,
  318. early_stop_patience=early_stop_patience,
  319. use_vdl=use_vdl)
  320. def evaluate(self,
  321. eval_dataset,
  322. batch_size=1,
  323. metric=None,
  324. return_details=False):
  325. """
  326. Evaluate the model.
  327. Args:
  328. eval_dataset(paddlex.dataset): Evaluation dataset.
  329. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  330. metric({'VOC', 'COCO', None}, optional):
  331. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  332. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  333. Returns:
  334. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  335. """
  336. if metric is None:
  337. if not hasattr(self, 'metric'):
  338. if eval_dataset.__class__.__name__ == 'VOCDetection':
  339. self.metric = 'voc'
  340. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  341. self.metric = 'coco'
  342. else:
  343. assert metric.lower() in ['coco', 'voc'], \
  344. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  345. self.metric = metric.lower()
  346. if self.metric == 'voc':
  347. eval_dataset.data_fields = {
  348. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  349. 'difficult'
  350. }
  351. elif self.metric == 'coco':
  352. if self.__class__.__name__ == 'MaskRCNN':
  353. eval_dataset.data_fields = {
  354. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  355. 'gt_poly', 'is_crowd'
  356. }
  357. else:
  358. eval_dataset.data_fields = {
  359. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  360. 'is_crowd'
  361. }
  362. eval_dataset.batch_transforms = self._compose_batch_transform(
  363. eval_dataset.transforms, mode='eval')
  364. arrange_transforms(
  365. model_type=self.model_type,
  366. transforms=eval_dataset.transforms,
  367. mode='eval')
  368. self.net.eval()
  369. nranks = paddle.distributed.get_world_size()
  370. local_rank = paddle.distributed.get_rank()
  371. if nranks > 1:
  372. # Initialize parallel environment if not done.
  373. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  374. ):
  375. paddle.distributed.init_parallel_env()
  376. if batch_size > 1:
  377. logging.warning(
  378. "Detector only supports single card evaluation with batch_size=1 "
  379. "during evaluation, so batch_size is forcibly set to 1.")
  380. batch_size = 1
  381. if nranks < 2 or local_rank == 0:
  382. self.eval_data_loader = self.build_data_loader(
  383. eval_dataset, batch_size=batch_size, mode='eval')
  384. is_bbox_normalized = False
  385. if eval_dataset.batch_transforms is not None:
  386. is_bbox_normalized = any(
  387. isinstance(t, _NormalizeBox)
  388. for t in eval_dataset.batch_transforms.batch_transforms)
  389. if self.metric == 'voc':
  390. eval_metric = VOCMetric(
  391. labels=eval_dataset.labels,
  392. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  393. is_bbox_normalized=is_bbox_normalized,
  394. classwise=False)
  395. else:
  396. eval_metric = COCOMetric(
  397. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  398. classwise=False)
  399. scores = collections.OrderedDict()
  400. logging.info(
  401. "Start to evaluate(total_samples={}, total_steps={})...".
  402. format(eval_dataset.num_samples, eval_dataset.num_samples))
  403. with paddle.no_grad():
  404. for step, data in enumerate(self.eval_data_loader):
  405. outputs = self.run(self.net, data, 'eval')
  406. eval_metric.update(data, outputs)
  407. eval_metric.accumulate()
  408. self.eval_details = eval_metric.details
  409. scores.update(eval_metric.get())
  410. eval_metric.reset()
  411. if return_details:
  412. return scores, self.eval_details
  413. return scores
  414. def predict(self, img_file, transforms=None):
  415. """
  416. Do inference.
  417. Args:
  418. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  419. Image path or decoded image data in a BGR format, which also could constitute a list,
  420. meaning all images to be predicted as a mini-batch.
  421. transforms(paddlex.transforms.Compose or None, optional):
  422. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  423. Returns:
  424. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  425. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  426. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  427. category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
  428. category(str): category name
  429. bbox(list): bounding box in [x, y, w, h] format
  430. score(str): confidence
  431. mask(dict): Only for instance segmentation task. Mask of the object in RLE format
  432. """
  433. if transforms is None and not hasattr(self, 'test_transforms'):
  434. raise Exception("transforms need to be defined, now is None.")
  435. if transforms is None:
  436. transforms = self.test_transforms
  437. if isinstance(img_file, (str, np.ndarray)):
  438. images = [img_file]
  439. else:
  440. images = img_file
  441. batch_samples = self._preprocess(images, transforms)
  442. self.net.eval()
  443. outputs = self.run(self.net, batch_samples, 'test')
  444. prediction = self._postprocess(outputs)
  445. if isinstance(img_file, (str, np.ndarray)):
  446. prediction = prediction[0]
  447. return prediction
  448. def _preprocess(self, images, transforms):
  449. arrange_transforms(
  450. model_type=self.model_type, transforms=transforms, mode='test')
  451. batch_samples = list()
  452. for im in images:
  453. sample = {'image': im}
  454. batch_samples.append(transforms(sample))
  455. batch_transforms = self._compose_batch_transform(transforms, 'test')
  456. batch_samples = batch_transforms(batch_samples)
  457. for k, v in batch_samples.items():
  458. batch_samples[k] = paddle.to_tensor(v)
  459. return batch_samples
  460. def _postprocess(self, batch_pred):
  461. infer_result = {}
  462. if 'bbox' in batch_pred:
  463. bboxes = batch_pred['bbox']
  464. bbox_nums = batch_pred['bbox_num']
  465. det_res = []
  466. k = 0
  467. for i in range(len(bbox_nums)):
  468. det_nums = bbox_nums[i]
  469. for j in range(det_nums):
  470. dt = bboxes[k]
  471. k = k + 1
  472. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  473. if int(num_id) < 0:
  474. continue
  475. category = self.labels[int(num_id)]
  476. w = xmax - xmin
  477. h = ymax - ymin
  478. bbox = [xmin, ymin, w, h]
  479. dt_res = {
  480. 'category_id': int(num_id),
  481. 'category': category,
  482. 'bbox': bbox,
  483. 'score': score
  484. }
  485. det_res.append(dt_res)
  486. infer_result['bbox'] = det_res
  487. if 'mask' in batch_pred:
  488. masks = batch_pred['mask']
  489. bboxes = batch_pred['bbox']
  490. mask_nums = batch_pred['bbox_num']
  491. seg_res = []
  492. k = 0
  493. for i in range(len(mask_nums)):
  494. det_nums = mask_nums[i]
  495. for j in range(det_nums):
  496. mask = masks[k].astype(np.uint8)
  497. score = float(bboxes[k][1])
  498. label = int(bboxes[k][0])
  499. k = k + 1
  500. if label == -1:
  501. continue
  502. category = self.labels[int(label)]
  503. import pycocotools.mask as mask_util
  504. rle = mask_util.encode(
  505. np.array(
  506. mask[:, :, None], order="F", dtype="uint8"))[0]
  507. if six.PY3:
  508. if 'counts' in rle:
  509. rle['counts'] = rle['counts'].decode("utf8")
  510. sg_res = {
  511. 'category_id': int(label),
  512. 'category': category,
  513. 'mask': rle,
  514. 'score': score
  515. }
  516. seg_res.append(sg_res)
  517. infer_result['mask'] = seg_res
  518. bbox_num = batch_pred['bbox_num']
  519. results = []
  520. start = 0
  521. for num in bbox_num:
  522. end = start + num
  523. curr_res = infer_result['bbox'][start:end]
  524. if 'mask' in infer_result:
  525. mask_res = infer_result['mask'][start:end]
  526. for box, mask in zip(curr_res, mask_res):
  527. box.update(mask)
  528. results.append(curr_res)
  529. start = end
  530. return results
  531. class YOLOv3(BaseDetector):
  532. def __init__(self,
  533. num_classes=80,
  534. backbone='MobileNetV1',
  535. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  536. [59, 119], [116, 90], [156, 198], [373, 326]],
  537. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  538. ignore_threshold=0.7,
  539. nms_score_threshold=0.01,
  540. nms_topk=1000,
  541. nms_keep_topk=100,
  542. nms_iou_threshold=0.45,
  543. label_smooth=False):
  544. self.init_params = locals()
  545. if backbone not in [
  546. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  547. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  548. ]:
  549. raise ValueError(
  550. "backbone: {} is not supported. Please choose one of "
  551. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  552. format(backbone))
  553. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  554. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  555. norm_type = 'sync_bn'
  556. else:
  557. norm_type = 'bn'
  558. self.backbone_name = backbone
  559. if 'MobileNetV1' in backbone:
  560. norm_type = 'bn'
  561. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  562. elif 'MobileNetV3' in backbone:
  563. backbone = self._get_backbone(
  564. 'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
  565. elif backbone == 'ResNet50_vd_dcn':
  566. backbone = self._get_backbone(
  567. 'ResNet',
  568. norm_type=norm_type,
  569. variant='d',
  570. return_idx=[1, 2, 3],
  571. dcn_v2_stages=[3],
  572. freeze_at=-1,
  573. freeze_norm=False)
  574. elif backbone == 'ResNet34':
  575. backbone = self._get_backbone(
  576. 'ResNet',
  577. depth=34,
  578. norm_type=norm_type,
  579. return_idx=[1, 2, 3],
  580. freeze_at=-1,
  581. freeze_norm=False,
  582. norm_decay=0.)
  583. else:
  584. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  585. neck = ppdet.modeling.YOLOv3FPN(
  586. norm_type=norm_type,
  587. in_channels=[i.channels for i in backbone.out_shape])
  588. loss = ppdet.modeling.YOLOv3Loss(
  589. num_classes=num_classes,
  590. ignore_thresh=ignore_threshold,
  591. label_smooth=label_smooth)
  592. yolo_head = ppdet.modeling.YOLOv3Head(
  593. in_channels=[i.channels for i in neck.out_shape],
  594. anchors=anchors,
  595. anchor_masks=anchor_masks,
  596. num_classes=num_classes,
  597. loss=loss)
  598. post_process = ppdet.modeling.BBoxPostProcess(
  599. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  600. nms=ppdet.modeling.MultiClassNMS(
  601. score_threshold=nms_score_threshold,
  602. nms_top_k=nms_topk,
  603. keep_top_k=nms_keep_topk,
  604. nms_threshold=nms_iou_threshold))
  605. params = {
  606. 'backbone': backbone,
  607. 'neck': neck,
  608. 'yolo_head': yolo_head,
  609. 'post_process': post_process
  610. }
  611. super(YOLOv3, self).__init__(
  612. model_name='YOLOv3', num_classes=num_classes, **params)
  613. self.anchors = anchors
  614. self.anchor_masks = anchor_masks
  def _compose_batch_transform(self, transforms, mode='train'):
      """Assemble the batch-level transform pipeline for ``mode``.

      Args:
          transforms: object whose ``transforms`` attribute is scanned for
              ``BatchRandomResize`` / ``BatchRandomResizeByShort`` ops.
          mode (str): 'train' adds the box-normalisation and YOLO target
              ops after the batch padding; any other value only pads.

      Returns:
          BatchCompose: deep copies of the batch-resize ops found in
          ``transforms`` (in reverse order of appearance) followed by the
          default ops.

      Raises:
          Exception: if a batch-resize op is present and ``mode`` is not
              'train'.
      """
      batch_defaults = [_BatchPadding(pad_to_stride=-1)]
      if mode == 'train':
          batch_defaults += [
              _NormalizeBox(),
              _PadBox(getattr(self, 'num_max_boxes', 50)),
              _BboxXYXY2XYWH(),
              _Gt2YoloTarget(
                  anchor_masks=self.anchor_masks,
                  anchors=self.anchors,
                  downsample_ratios=getattr(self, 'downsample_ratios',
                                            [32, 16, 8]),
                  num_classes=self.num_classes),
          ]

      # Batches are collated except when evaluating with the 'voc' metric.
      # self.metric is only read in 'eval' mode (short-circuit).
      collate_batch = not (mode == 'eval' and self.metric == 'voc')

      resize_ops = []
      for op in transforms.transforms:
          if not isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
              continue
          if mode != 'train':
              raise Exception(
                  "{} cannot be present in the {} transforms. ".format(
                      op.__class__.__name__, mode) +
                  "Please check the {} transforms.".format(mode))
          resize_ops.append(copy.deepcopy(op))
      # Later ops go first, matching a repeated insert-at-front.
      resize_ops.reverse()

      return BatchCompose(
          resize_ops + batch_defaults, collate_batch=collate_batch)
  def _fix_transforms_shape(self, image_shape):
      """Make ``self.test_transforms`` resize images to ``image_shape``.

      Does nothing when the instance has no ``test_transforms`` attribute
      or it is None.  Otherwise the last op whose class is named 'Resize'
      gets its ``target_size`` overwritten; when there is no such op, a new
      ``Resize(target_size=image_shape, interp='CUBIC')`` is inserted at the
      position of the last 'Normalize' op (or appended if there is none).

      Args:
          image_shape: value stored as the resize target size.
      """
      if getattr(self, 'test_transforms', None) is None:
          return

      resize_at = None
      insert_at = len(self.test_transforms.transforms)
      for position, op in enumerate(self.test_transforms.transforms):
          op_name = op.__class__.__name__
          if op_name == 'Resize':
              resize_at = position
          if op_name == 'Normalize':
              insert_at = position

      if resize_at is not None:
          self.test_transforms.transforms[resize_at].target_size = image_shape
          return

      self.test_transforms.transforms.insert(
          insert_at, Resize(
              target_size=image_shape, interp='CUBIC'))
  class FasterRCNN(BaseDetector):
      """Faster R-CNN detector assembled from ``ppdet.modeling`` components.

      The constructor validates the backbone name, builds the backbone, an
      optional FPN/HRFPN neck, the RPN head, the box head and the box
      post-processing, and passes them to the parent initializer.
      """

      def __init__(self,
                   num_classes=80,
                   backbone='ResNet50',
                   with_fpn=True,
                   with_dcn=False,
                   aspect_ratios=[0.5, 1.0, 2.0],
                   anchor_sizes=[[32], [64], [128], [256], [512]],
                   keep_top_k=100,
                   nms_threshold=0.5,
                   score_threshold=0.05,
                   fpn_num_channels=256,
                   rpn_batch_size_per_im=256,
                   rpn_fg_fraction=0.5,
                   test_pre_nms_top_n=None,
                   test_post_nms_top_n=1000):
          """Build the network components.

          Args:
              num_classes (int): forwarded to the box assigner, box head,
                  post-processing and the parent initializer.
              backbone (str): one of 'ResNet50', 'ResNet50_vd',
                  'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', 'ResNet101',
                  'ResNet101_vd', 'HRNet_W18'.
              with_fpn (bool): whether to add a neck.  Forced to True (with
                  a warning) for every backbone except 'ResNet50' and
                  'ResNet50_vd'.
              with_dcn (bool): selects ``dcn_v2_stages=[1, 2, 3]`` instead of
                  ``[-1]`` for ResNet backbones.  Not applied to 'HRNet_W18'
                  nor to ResNet50 variants built without FPN.
              aspect_ratios (list): anchor generator aspect ratios.
              anchor_sizes (list): anchor generator sizes.
              keep_top_k (int): forwarded to ``MultiClassNMS``.
              nms_threshold (float): forwarded to ``MultiClassNMS``.
              score_threshold (float): forwarded to ``MultiClassNMS``.
              fpn_num_channels (int): ``out_channel`` of the neck.
              rpn_batch_size_per_im (int): RPN target assign batch size.
              rpn_fg_fraction (float): RPN target assign foreground fraction.
              test_pre_nms_top_n (int|None): test-time ``pre_nms_top_n``;
                  None selects 1000 with FPN and 6000 without.
              test_post_nms_top_n (int): test-time ``post_nms_top_n``.

          Raises:
              ValueError: if ``backbone`` is not a supported name.
          """
          # Must remain the first statement so that locals() only holds the
          # call arguments.  The list defaults above are kept as-is because
          # they are recorded here; they are never mutated in this method.
          self.init_params = locals()
          if backbone not in [
                  'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
                  'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
          ]:
              raise ValueError(
                  "backbone: {} is not supported. Please choose one of "
                  "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                  "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
          self.backbone_name = backbone
          dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]

          # From here on `backbone` is rebound from the name string to the
          # backbone object returned by _get_backbone.
          if backbone == 'HRNet_W18':
              if not with_fpn:
                  logging.warning(
                      "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                      format(backbone))
                  with_fpn = True
              if with_dcn:
                  logging.warning(
                      "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                      format(backbone))
              backbone = self._get_backbone(
                  'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
          elif backbone == 'ResNet50_vd_ssld':
              if not with_fpn:
                  logging.warning(
                      "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                      format(backbone))
                  with_fpn = True
              backbone = self._get_backbone(
                  'ResNet',
                  variant='d',
                  norm_type='bn',
                  freeze_at=0,
                  return_idx=[0, 1, 2, 3],
                  num_stages=4,
                  lr_mult_list=[0.05, 0.05, 0.1, 0.15],
                  dcn_v2_stages=dcn_v2_stages)
          elif 'ResNet50' in backbone:
              if with_fpn:
                  backbone = self._get_backbone(
                      'ResNet',
                      variant='d' if '_vd' in backbone else 'b',
                      norm_type='bn',
                      freeze_at=0,
                      return_idx=[0, 1, 2, 3],
                      num_stages=4,
                      dcn_v2_stages=dcn_v2_stages)
              else:
                  if with_dcn:
                      logging.warning(
                          "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                          format(backbone))
                  # No dcn_v2_stages is passed on this path.
                  backbone = self._get_backbone(
                      'ResNet',
                      variant='d' if '_vd' in backbone else 'b',
                      norm_type='bn',
                      freeze_at=0,
                      return_idx=[2],
                      num_stages=3)
          elif 'ResNet34' in backbone:
              if not with_fpn:
                  logging.warning(
                      "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                      format(backbone))
                  with_fpn = True
              backbone = self._get_backbone(
                  'ResNet',
                  depth=34,
                  variant='d' if 'vd' in backbone else 'b',
                  norm_type='bn',
                  freeze_at=0,
                  return_idx=[0, 1, 2, 3],
                  num_stages=4,
                  dcn_v2_stages=dcn_v2_stages)
          else:
              # Remaining supported names: 'ResNet101' and 'ResNet101_vd'.
              if not with_fpn:
                  logging.warning(
                      "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                      format(backbone))
                  with_fpn = True
              backbone = self._get_backbone(
                  'ResNet',
                  depth=101,
                  variant='d' if 'vd' in backbone else 'b',
                  norm_type='bn',
                  freeze_at=0,
                  return_idx=[0, 1, 2, 3],
                  num_stages=4,
                  dcn_v2_stages=dcn_v2_stages)

          rpn_in_channel = backbone.out_shape[0].channels
          if with_fpn:
              # The recorded name gets its suffix only after with_fpn has
              # been finalised above.
              self.backbone_name = self.backbone_name + '_fpn'
              if 'HRNet' in self.backbone_name:
                  neck = ppdet.modeling.HRFPN(
                      in_channels=[i.channels for i in backbone.out_shape],
                      out_channel=fpn_num_channels,
                      spatial_scales=[
                          1.0 / i.stride for i in backbone.out_shape
                      ],
                      share_conv=False)
              else:
                  neck = ppdet.modeling.FPN(
                      in_channels=[i.channels for i in backbone.out_shape],
                      out_channel=fpn_num_channels,
                      spatial_scales=[
                          1.0 / i.stride for i in backbone.out_shape
                      ])
              rpn_in_channel = neck.out_shape[0].channels
              anchor_generator_cfg = {
                  'aspect_ratios': aspect_ratios,
                  'anchor_sizes': anchor_sizes,
                  'strides': [4, 8, 16, 32, 64]
              }
              train_proposal_cfg = {
                  'min_size': 0.0,
                  'nms_thresh': .7,
                  'pre_nms_top_n': 2000,
                  'post_nms_top_n': 1000,
                  'topk_after_collect': True
              }
              test_proposal_cfg = {
                  'min_size': 0.0,
                  'nms_thresh': .7,
                  'pre_nms_top_n': 1000
                  if test_pre_nms_top_n is None else test_pre_nms_top_n,
                  'post_nms_top_n': test_post_nms_top_n
              }
              head = ppdet.modeling.TwoFCHead(
                  in_channel=neck.out_shape[0].channels, out_channel=1024)
              roi_extractor_cfg = {
                  'resolution': 7,
                  'spatial_scale': [1. / i.stride for i in neck.out_shape],
                  'sampling_ratio': 0,
                  'aligned': True
              }
              with_pool = False
          else:
              neck = None
              anchor_generator_cfg = {
                  'aspect_ratios': aspect_ratios,
                  'anchor_sizes': anchor_sizes,
                  'strides': [16]
              }
              train_proposal_cfg = {
                  'min_size': 0.0,
                  'nms_thresh': .7,
                  'pre_nms_top_n': 12000,
                  'post_nms_top_n': 2000,
                  'topk_after_collect': False
              }
              test_proposal_cfg = {
                  'min_size': 0.0,
                  'nms_thresh': .7,
                  'pre_nms_top_n': 6000
                  if test_pre_nms_top_n is None else test_pre_nms_top_n,
                  'post_nms_top_n': test_post_nms_top_n
              }
              head = ppdet.modeling.Res5Head()
              roi_extractor_cfg = {
                  'resolution': 14,
                  'spatial_scale':
                  [1. / i.stride for i in backbone.out_shape],
                  'sampling_ratio': 0,
                  'aligned': True
              }
              with_pool = True

          rpn_target_assign_cfg = {
              'batch_size_per_im': rpn_batch_size_per_im,
              'fg_fraction': rpn_fg_fraction,
              'negative_overlap': .3,
              'positive_overlap': .7,
              'use_random': True
          }
          rpn_head = ppdet.modeling.RPNHead(
              anchor_generator=anchor_generator_cfg,
              rpn_target_assign=rpn_target_assign_cfg,
              train_proposal=train_proposal_cfg,
              test_proposal=test_proposal_cfg,
              in_channel=rpn_in_channel)
          bbox_assigner = BBoxAssigner(num_classes=num_classes)
          bbox_head = ppdet.modeling.BBoxHead(
              head=head,
              in_channel=head.out_shape[0].channels,
              roi_extractor=roi_extractor_cfg,
              with_pool=with_pool,
              bbox_assigner=bbox_assigner,
              num_classes=num_classes)
          bbox_post_process = ppdet.modeling.BBoxPostProcess(
              num_classes=num_classes,
              decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
              nms=ppdet.modeling.MultiClassNMS(
                  score_threshold=score_threshold,
                  keep_top_k=keep_top_k,
                  nms_threshold=nms_threshold))
          params = {
              'backbone': backbone,
              'neck': neck,
              'rpn_head': rpn_head,
              'bbox_head': bbox_head,
              'bbox_post_process': bbox_post_process
          }
          self.with_fpn = with_fpn
          # Explicit two-argument super(): the zero-argument form would add
          # a __class__ cell to this function's locals().
          super(FasterRCNN, self).__init__(
              model_name='FasterRCNN', num_classes=num_classes, **params)

      def _compose_batch_transform(self, transforms, mode='train'):
          """Assemble the batch-level transform pipeline for ``mode``.

          Args:
              transforms: object whose ``transforms`` attribute is scanned
                  for ``BatchRandomResize`` / ``BatchRandomResizeByShort``.
              mode (str): batches are collated for every mode except
                  'train'.

          Returns:
              BatchCompose: deep copies of the batch-resize ops (in reverse
              order of appearance) followed by a ``_BatchPadding`` op that
              pads to stride 32 with FPN and uses -1 otherwise.

          Raises:
              Exception: if a batch-resize op is present and ``mode`` is not
                  'train'.
          """
          # The default ops are the same for every mode; only collation
          # differs.
          default_batch_transforms = [
              _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
          ]
          collate_batch = mode != 'train'

          custom_batch_transforms = []
          for op in transforms.transforms:
              if isinstance(op,
                            (BatchRandomResize, BatchRandomResizeByShort)):
                  if mode != 'train':
                      raise Exception(
                          "{} cannot be present in the {} transforms. ".
                          format(op.__class__.__name__, mode) +
                          "Please check the {} transforms.".format(mode))
                  custom_batch_transforms.insert(0, copy.deepcopy(op))

          batch_transforms = BatchCompose(
              custom_batch_transforms + default_batch_transforms,
              collate_batch=collate_batch)
          return batch_transforms

      def _fix_transforms_shape(self, image_shape):
          """Make ``self.test_transforms`` resize and pad for a fixed shape.

          Does nothing when the instance has no ``test_transforms`` or it is
          None.  Otherwise the last op whose class is named 'ResizeByShort'
          is replaced by ``Resize(target_size=image_shape, keep_ratio=True,
          interp='CUBIC')``; when there is none, that ``Resize`` is inserted
          at the position of the last 'Normalize' op (or appended).  A
          ``Padding`` op is then appended in both cases.

          Args:
              image_shape: value used as the resize target size.
          """
          if hasattr(self, 'test_transforms'):
              if self.test_transforms is not None:
                  has_resize_op = False
                  resize_op_idx = -1
                  normalize_op_idx = len(self.test_transforms.transforms)
                  for idx, op in enumerate(self.test_transforms.transforms):
                      name = op.__class__.__name__
                      if name == 'ResizeByShort':
                          has_resize_op = True
                          resize_op_idx = idx
                      if name == 'Normalize':
                          normalize_op_idx = idx
                  if not has_resize_op:
                      self.test_transforms.transforms.insert(
                          normalize_op_idx,
                          Resize(
                              target_size=image_shape,
                              keep_ratio=True,
                              interp='CUBIC'))
                  else:
                      self.test_transforms.transforms[
                          resize_op_idx] = Resize(
                              target_size=image_shape,
                              keep_ratio=True,
                              interp='CUBIC')
                  self.test_transforms.transforms.append(
                      Padding(im_padding_value=[0., 0., 0.]))

      def _get_test_inputs(self, image_shape):
          """Return the input spec for ``image_shape``.

          With a concrete ``image_shape`` the shape is checked through
          ``_check_image_shape`` and the test transforms are adjusted to it.
          With None the shape becomes ``[None, 3, -1, -1]`` and, when FPN is
          used, a ``Padding`` op is appended to the test transforms.

          Args:
              image_shape: fixed input shape, or None.

          Returns:
              Whatever ``self._define_input_spec`` returns for the shape.
          """
          if image_shape is not None:
              image_shape = self._check_image_shape(image_shape)
              self._fix_transforms_shape(image_shape[-2:])
          else:
              image_shape = [None, 3, -1, -1]
              # Same missing/None guard as _fix_transforms_shape, so that a
              # model without test transforms does not raise AttributeError.
              if self.with_fpn and getattr(self, 'test_transforms',
                                           None) is not None:
                  self.test_transforms.transforms.append(
                      Padding(im_padding_value=[0., 0., 0.]))
          return self._define_input_spec(image_shape)
  class PPYOLO(YOLOv3):
      """PP-YOLO detector assembled from ``ppdet.modeling`` components.

      Only ``__init__`` is defined here; everything else is inherited from
      ``YOLOv3``.
      """

      def __init__(self,
                   num_classes=80,
                   backbone='ResNet50_vd_dcn',
                   anchors=None,
                   anchor_masks=None,
                   use_coord_conv=True,
                   use_iou_aware=True,
                   use_spp=True,
                   use_drop_block=True,
                   scale_x_y=1.05,
                   ignore_threshold=0.7,
                   label_smooth=False,
                   use_iou_loss=True,
                   use_matrix_nms=True,
                   nms_score_threshold=0.01,
                   nms_topk=-1,
                   nms_keep_topk=100,
                   nms_iou_threshold=0.45):
          """Build the network components.

          Args:
              num_classes (int): number of classes.
              backbone (str): 'ResNet50_vd_dcn', 'ResNet18_vd',
                  'MobileNetV3_large' or 'MobileNetV3_small'.
              anchors (list|None): anchor sizes; must be given together
                  with ``anchor_masks`` or both left as None, in which case
                  backbone-specific defaults are used.
              anchor_masks (list|None): anchor indices per output level.
              use_coord_conv (bool): ``coord_conv`` of the neck.
              use_iou_aware (bool): enables the IoU-aware loss and head.
              use_spp (bool): ``spp`` of the neck.
              use_drop_block (bool): ``drop_block`` of the neck.
              scale_x_y (float): forwarded to the loss and box decoder.
              ignore_threshold (float): ``ignore_thresh`` of the loss.
              label_smooth (bool): ``label_smooth`` of the loss.
              use_iou_loss (bool): enables the IoU loss term.
              use_matrix_nms (bool): MatrixNMS if True, else MultiClassNMS.
              nms_score_threshold (float): NMS score threshold.
              nms_topk (int): NMS ``nms_top_k``.
              nms_keep_topk (int): NMS ``keep_top_k``.
              nms_iou_threshold (float): used by MultiClassNMS only.

          Raises:
              ValueError: on an unsupported backbone, or when only one of
                  ``anchors`` / ``anchor_masks`` is given.
          """
          # Must remain the first statement so that locals() only holds the
          # call arguments.
          self.init_params = locals()

          supported_backbones = ('ResNet50_vd_dcn', 'ResNet18_vd',
                                 'MobileNetV3_large', 'MobileNetV3_small')
          if backbone not in supported_backbones:
              raise ValueError(
                  "backbone: {} is not supported. Please choose one of "
                  "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
                  format(backbone))
          self.backbone_name = backbone

          use_sync_bn = (paddlex.env_info['place'] == 'gpu' and
                         paddlex.env_info['num'] > 1 and
                         not os.environ.get('PADDLEX_EXPORT_STAGE'))
          norm_type = 'sync_bn' if use_sync_bn else 'bn'

          is_mobilenet = 'MobileNetV3' in backbone
          is_resnet18 = backbone == 'ResNet18_vd'

          if (anchors is None) != (anchor_masks is None):
              raise ValueError("Please define both anchors and anchor_masks.")
          if anchors is None:
              if is_mobilenet:
                  anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
                             [120, 195], [254, 235]]
                  anchor_masks = [[3, 4, 5], [0, 1, 2]]
              elif is_resnet18:
                  anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
                             [135, 169], [344, 319]]
                  anchor_masks = [[3, 4, 5], [0, 1, 2]]
              else:
                  anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
                             [62, 45], [59, 119], [116, 90], [156, 198],
                             [373, 326]]
                  anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

          # From here on `backbone` is rebound from the name string to the
          # backbone object returned by _get_backbone.
          if is_mobilenet:
              if backbone == 'MobileNetV3_large':
                  size, feature_maps = 'large', [13, 16]
              else:
                  size, feature_maps = 'small', [9, 12]
              backbone = self._get_backbone(
                  'MobileNetV3',
                  model_name=size,
                  norm_type=norm_type,
                  scale=1,
                  with_extra_blocks=False,
                  extra_block_filters=[],
                  feature_maps=feature_maps)
              downsample_ratios = [32, 16]
          elif is_resnet18:
              backbone = self._get_backbone(
                  'ResNet',
                  depth=18,
                  variant='d',
                  norm_type=norm_type,
                  return_idx=[2, 3],
                  freeze_at=-1,
                  freeze_norm=False,
                  norm_decay=0.)
              # NOTE(review): three ratios although return_idx selects two
              # levels -- kept as in the original; confirm whether
              # [32, 16] was intended.
              downsample_ratios = [32, 16, 8]
          else:
              backbone = self._get_backbone(
                  'ResNet',
                  variant='d',
                  norm_type=norm_type,
                  return_idx=[1, 2, 3],
                  dcn_v2_stages=[3],
                  freeze_at=-1,
                  freeze_norm=False,
                  norm_decay=0.)
              downsample_ratios = [32, 16, 8]

          neck = ppdet.modeling.PPYOLOFPN(
              norm_type=norm_type,
              in_channels=[shape.channels for shape in backbone.out_shape],
              coord_conv=use_coord_conv,
              drop_block=use_drop_block,
              spp=use_spp,
              conv_block_num=0 if (is_mobilenet or is_resnet18) else 2)

          iou_loss = None
          if use_iou_loss:
              iou_loss = ppdet.modeling.IouLoss(
                  loss_weight=2.5, loss_square=True)
          iou_aware_loss = None
          if use_iou_aware:
              iou_aware_loss = ppdet.modeling.IouAwareLoss(loss_weight=1.0)
          loss = ppdet.modeling.YOLOv3Loss(
              num_classes=num_classes,
              ignore_thresh=ignore_threshold,
              downsample=downsample_ratios,
              label_smooth=label_smooth,
              scale_x_y=scale_x_y,
              iou_loss=iou_loss,
              iou_aware_loss=iou_aware_loss)

          yolo_head = ppdet.modeling.YOLOv3Head(
              in_channels=[shape.channels for shape in neck.out_shape],
              anchors=anchors,
              anchor_masks=anchor_masks,
              num_classes=num_classes,
              loss=loss,
              iou_aware=use_iou_aware)

          if not use_matrix_nms:
              nms = ppdet.modeling.MultiClassNMS(
                  score_threshold=nms_score_threshold,
                  nms_top_k=nms_topk,
                  keep_top_k=nms_keep_topk,
                  nms_threshold=nms_iou_threshold)
          else:
              nms = ppdet.modeling.MatrixNMS(
                  keep_top_k=nms_keep_topk,
                  score_threshold=nms_score_threshold,
                  post_threshold=.05 if is_mobilenet else .01,
                  nms_top_k=nms_topk,
                  background_label=-1)

          box_decoder = ppdet.modeling.YOLOBox(
              num_classes=num_classes,
              conf_thresh=.005 if is_mobilenet else .01,
              scale_x_y=scale_x_y)
          post_process = ppdet.modeling.BBoxPostProcess(
              decode=box_decoder, nms=nms)

          # super(YOLOv3, self) starts the lookup after YOLOv3 in the MRO,
          # so YOLOv3.__init__ is not run.  The explicit two-argument form
          # also keeps __class__ out of this function's locals().
          super(YOLOv3, self).__init__(
              model_name='YOLOv3',
              num_classes=num_classes,
              backbone=backbone,
              neck=neck,
              yolo_head=yolo_head,
              post_process=post_process)
          self.anchors = anchors
          self.anchor_masks = anchor_masks
          self.downsample_ratios = downsample_ratios
          self.model_name = 'PPYOLO'
  class PPYOLOTiny(YOLOv3):
      """PP-YOLO Tiny detector assembled from ``ppdet.modeling`` components.

      Only ``__init__`` is defined here; everything else is inherited from
      ``YOLOv3``.
      """

      def __init__(self,
                   num_classes=80,
                   backbone='MobileNetV3',
                   anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
                            [60, 170], [220, 125], [128, 222], [264, 266]],
                   anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                   use_iou_aware=False,
                   use_spp=True,
                   use_drop_block=True,
                   scale_x_y=1.05,
                   ignore_threshold=0.5,
                   label_smooth=False,
                   use_iou_loss=True,
                   use_matrix_nms=False,
                   nms_score_threshold=0.005,
                   nms_topk=1000,
                   nms_keep_topk=100,
                   nms_iou_threshold=0.45):
          """Build the network components.

          Args:
              num_classes (int): number of classes.
              backbone (str): any value other than 'MobileNetV3' triggers a
                  warning; a MobileNetV3 backbone is built regardless.
              anchors (list): anchor sizes for the head.
              anchor_masks (list): anchor indices per output level.
              use_iou_aware (bool): enables the IoU-aware loss and head.
              use_spp (bool): ``spp`` of the neck.
              use_drop_block (bool): ``drop_block`` of the neck.
              scale_x_y (float): forwarded to the loss and box decoder.
              ignore_threshold (float): ``ignore_thresh`` of the loss.
              label_smooth (bool): ``label_smooth`` of the loss.
              use_iou_loss (bool): enables the IoU loss term.
              use_matrix_nms (bool): MatrixNMS if True, else MultiClassNMS.
              nms_score_threshold (float): NMS score threshold.
              nms_topk (int): NMS ``nms_top_k``.
              nms_keep_topk (int): NMS ``keep_top_k``.
              nms_iou_threshold (float): used by MultiClassNMS only.
          """
          # Must remain the first statement so that locals() only holds the
          # call arguments.
          self.init_params = locals()

          if backbone != 'MobileNetV3':
              logging.warning(
                  "PPYOLOTiny only supports MobileNetV3 as backbone. "
                  "Backbone is forcibly set to MobileNetV3.")
          self.backbone_name = 'MobileNetV3'

          use_sync_bn = (paddlex.env_info['place'] == 'gpu' and
                         paddlex.env_info['num'] > 1 and
                         not os.environ.get('PADDLEX_EXPORT_STAGE'))
          norm_type = 'sync_bn' if use_sync_bn else 'bn'

          # `backbone` is rebound to the backbone object here.
          backbone = self._get_backbone(
              'MobileNetV3',
              model_name='large',
              norm_type=norm_type,
              scale=.5,
              with_extra_blocks=False,
              extra_block_filters=[],
              feature_maps=[7, 13, 16])
          downsample_ratios = [32, 16, 8]

          neck = ppdet.modeling.PPYOLOTinyFPN(
              detection_block_channels=[160, 128, 96],
              in_channels=[shape.channels for shape in backbone.out_shape],
              spp=use_spp,
              drop_block=use_drop_block)

          iou_loss = None
          if use_iou_loss:
              iou_loss = ppdet.modeling.IouLoss(
                  loss_weight=2.5, loss_square=True)
          iou_aware_loss = None
          if use_iou_aware:
              iou_aware_loss = ppdet.modeling.IouAwareLoss(loss_weight=1.0)
          loss = ppdet.modeling.YOLOv3Loss(
              num_classes=num_classes,
              ignore_thresh=ignore_threshold,
              downsample=downsample_ratios,
              label_smooth=label_smooth,
              scale_x_y=scale_x_y,
              iou_loss=iou_loss,
              iou_aware_loss=iou_aware_loss)

          yolo_head = ppdet.modeling.YOLOv3Head(
              in_channels=[shape.channels for shape in neck.out_shape],
              anchors=anchors,
              anchor_masks=anchor_masks,
              num_classes=num_classes,
              loss=loss,
              iou_aware=use_iou_aware)

          if not use_matrix_nms:
              nms = ppdet.modeling.MultiClassNMS(
                  score_threshold=nms_score_threshold,
                  nms_top_k=nms_topk,
                  keep_top_k=nms_keep_topk,
                  nms_threshold=nms_iou_threshold)
          else:
              nms = ppdet.modeling.MatrixNMS(
                  keep_top_k=nms_keep_topk,
                  score_threshold=nms_score_threshold,
                  post_threshold=.05,
                  nms_top_k=nms_topk,
                  background_label=-1)

          box_decoder = ppdet.modeling.YOLOBox(
              num_classes=num_classes,
              conf_thresh=.005,
              downsample_ratio=32,
              clip_bbox=True,
              scale_x_y=scale_x_y)
          post_process = ppdet.modeling.BBoxPostProcess(
              decode=box_decoder, nms=nms)

          # super(YOLOv3, self) starts the lookup after YOLOv3 in the MRO,
          # so YOLOv3.__init__ is not run.  The explicit two-argument form
          # also keeps __class__ out of this function's locals().
          super(YOLOv3, self).__init__(
              model_name='YOLOv3',
              num_classes=num_classes,
              backbone=backbone,
              neck=neck,
              yolo_head=yolo_head,
              post_process=post_process)
          self.anchors = anchors
          self.anchor_masks = anchor_masks
          self.downsample_ratios = downsample_ratios
          self.model_name = 'PPYOLOTiny'
  class PPYOLOv2(YOLOv3):
      """PP-YOLOv2 detector assembled from ``ppdet.modeling`` components.

      Defines its own constructor and ``_get_test_inputs``; everything else
      is inherited from ``YOLOv3``.
      """

      def __init__(self,
                   num_classes=80,
                   backbone='ResNet50_vd_dcn',
                   anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                            [59, 119], [116, 90], [156, 198], [373, 326]],
                   anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                   use_iou_aware=True,
                   use_spp=True,
                   use_drop_block=True,
                   scale_x_y=1.05,
                   ignore_threshold=0.7,
                   label_smooth=False,
                   use_iou_loss=True,
                   use_matrix_nms=True,
                   nms_score_threshold=0.01,
                   nms_topk=-1,
                   nms_keep_topk=100,
                   nms_iou_threshold=0.45):
          """Build the network components.

          Args:
              num_classes (int): number of classes.
              backbone (str): 'ResNet50_vd_dcn' or 'ResNet101_vd_dcn'.
              anchors (list): anchor sizes for the head.
              anchor_masks (list): anchor indices per output level.
              use_iou_aware (bool): enables the IoU-aware loss and head.
              use_spp (bool): ``spp`` of the neck.
              use_drop_block (bool): ``drop_block`` of the neck.
              scale_x_y (float): forwarded to the loss and box decoder.
              ignore_threshold (float): ``ignore_thresh`` of the loss.
              label_smooth (bool): ``label_smooth`` of the loss.
              use_iou_loss (bool): enables the IoU loss term.
              use_matrix_nms (bool): MatrixNMS if True, else MultiClassNMS.
              nms_score_threshold (float): NMS score threshold.
              nms_topk (int): NMS ``nms_top_k``.
              nms_keep_topk (int): NMS ``keep_top_k``.
              nms_iou_threshold (float): used by MultiClassNMS only.

          Raises:
              ValueError: if ``backbone`` is not a supported name.
          """
          # Must remain the first statement so that locals() only holds the
          # call arguments.  The list defaults above are recorded here and
          # are never mutated in this method.
          self.init_params = locals()
          if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
              # The message lists exactly the names accepted above.
              raise ValueError(
                  "backbone: {} is not supported. Please choose one of "
                  "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
          self.backbone_name = backbone

          if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                  'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
              norm_type = 'sync_bn'
          else:
              norm_type = 'bn'

          # From here on `backbone` is rebound from the name string to the
          # backbone object returned by _get_backbone.
          if backbone == 'ResNet50_vd_dcn':
              backbone = self._get_backbone(
                  'ResNet',
                  variant='d',
                  norm_type=norm_type,
                  return_idx=[1, 2, 3],
                  dcn_v2_stages=[3],
                  freeze_at=-1,
                  freeze_norm=False,
                  norm_decay=0.)
              downsample_ratios = [32, 16, 8]
          elif backbone == 'ResNet101_vd_dcn':
              backbone = self._get_backbone(
                  'ResNet',
                  depth=101,
                  variant='d',
                  norm_type=norm_type,
                  return_idx=[1, 2, 3],
                  dcn_v2_stages=[3],
                  freeze_at=-1,
                  freeze_norm=False,
                  norm_decay=0.)
              downsample_ratios = [32, 16, 8]

          neck = ppdet.modeling.PPYOLOPAN(
              norm_type=norm_type,
              in_channels=[i.channels for i in backbone.out_shape],
              drop_block=use_drop_block,
              block_size=3,
              keep_prob=.9,
              spp=use_spp)
          loss = ppdet.modeling.YOLOv3Loss(
              num_classes=num_classes,
              ignore_thresh=ignore_threshold,
              downsample=downsample_ratios,
              label_smooth=label_smooth,
              scale_x_y=scale_x_y,
              iou_loss=ppdet.modeling.IouLoss(
                  loss_weight=2.5, loss_square=True)
              if use_iou_loss else None,
              iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
              if use_iou_aware else None)
          yolo_head = ppdet.modeling.YOLOv3Head(
              in_channels=[i.channels for i in neck.out_shape],
              anchors=anchors,
              anchor_masks=anchor_masks,
              num_classes=num_classes,
              loss=loss,
              iou_aware=use_iou_aware,
              iou_aware_factor=.5)

          if use_matrix_nms:
              nms = ppdet.modeling.MatrixNMS(
                  keep_top_k=nms_keep_topk,
                  score_threshold=nms_score_threshold,
                  post_threshold=.01,
                  nms_top_k=nms_topk,
                  background_label=-1)
          else:
              nms = ppdet.modeling.MultiClassNMS(
                  score_threshold=nms_score_threshold,
                  nms_top_k=nms_topk,
                  keep_top_k=nms_keep_topk,
                  nms_threshold=nms_iou_threshold)
          post_process = ppdet.modeling.BBoxPostProcess(
              decode=ppdet.modeling.YOLOBox(
                  num_classes=num_classes,
                  conf_thresh=.01,
                  downsample_ratio=32,
                  clip_bbox=True,
                  scale_x_y=scale_x_y),
              nms=nms)
          params = {
              'backbone': backbone,
              'neck': neck,
              'yolo_head': yolo_head,
              'post_process': post_process
          }
          # super(YOLOv3, self) starts the lookup after YOLOv3 in the MRO,
          # so YOLOv3.__init__ is not run.  The explicit two-argument form
          # also keeps __class__ out of this function's locals().
          super(YOLOv3, self).__init__(
              model_name='YOLOv3', num_classes=num_classes, **params)
          self.anchors = anchors
          self.anchor_masks = anchor_masks
          self.downsample_ratios = downsample_ratios
          self.model_name = 'PPYOLOv2'

      def _get_test_inputs(self, image_shape):
          """Return the input spec list for ``image_shape``.

          Args:
              image_shape (list|tuple|None): either ``[H, W]`` (expanded to
                  ``[None, 3, H, W]``) or a full shape whose last two
                  entries are height and width.  None falls back to
                  ``[None, 3, 608, 608]`` with a warning.

          Returns:
              list: one dict with ``InputSpec`` entries for 'image',
              'im_shape' and 'scale_factor'.

          Raises:
              Exception: if height or width is not a multiple of 32.
          """
          if image_shape is not None:
              if len(image_shape) == 2:
                  # list() so that a 2-tuple is accepted as well as a list.
                  image_shape = [None, 3] + list(image_shape)
              if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
                  raise Exception(
                      "Height and width in fixed_input_shape must be a multiple of 32, but recieved is {}.".
                      format(image_shape[-2:]))
              self._fix_transforms_shape(image_shape[-2:])
          else:
              logging.warning(
                  '[Important!!!] When exporting inference model for {},'.
                  format(self.__class__.__name__) +
                  ' if fixed_input_shape is not set, it will be forcibly set to [None, 3, 608, 608]. '
                  +
                  'Please check image shape after transforms is [3, 608, 608], if not, fixed_input_shape '
                  + 'should be specified manually.')
              image_shape = [None, 3, 608, 608]

          input_spec = [{
              "image": InputSpec(
                  shape=image_shape, name='image', dtype='float32'),
              "im_shape": InputSpec(
                  shape=[image_shape[0], 2], name='im_shape',
                  dtype='float32'),
              "scale_factor": InputSpec(
                  shape=[image_shape[0], 2],
                  name='scale_factor',
                  dtype='float32')
          }]
          return input_spec
  1331. class MaskRCNN(BaseDetector):
  def __init__(self,
               num_classes=80,
               backbone='ResNet50_vd',
               with_fpn=True,
               with_dcn=False,
               aspect_ratios=[0.5, 1.0, 2.0],
               anchor_sizes=[[32], [64], [128], [256], [512]],
               keep_top_k=100,
               nms_threshold=0.5,
               score_threshold=0.05,
               fpn_num_channels=256,
               rpn_batch_size_per_im=256,
               rpn_fg_fraction=0.5,
               test_pre_nms_top_n=None,
               test_post_nms_top_n=1000):
      """Build the Mask R-CNN network components.

      Args:
          num_classes (int): forwarded to the assigners, heads,
              post-processing and the parent initializer.
          backbone (str): one of 'ResNet50', 'ResNet50_vd',
              'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd'.
          with_fpn (bool): whether to add an FPN neck.  Forced to True
              (with a warning) for every backbone except 'ResNet50'.
          with_dcn (bool): selects ``dcn_v2_stages=[1, 2, 3]`` instead of
              ``[-1]``.  Not applied to 'ResNet50' built without FPN.
          aspect_ratios (list): anchor generator aspect ratios.
          anchor_sizes (list): anchor generator sizes.
          keep_top_k (int): forwarded to ``MultiClassNMS``.
          nms_threshold (float): forwarded to ``MultiClassNMS``.
          score_threshold (float): forwarded to ``MultiClassNMS``.
          fpn_num_channels (int): ``out_channel`` of the FPN.
          rpn_batch_size_per_im (int): RPN target assign batch size.
          rpn_fg_fraction (float): RPN target assign foreground fraction.
          test_pre_nms_top_n (int|None): test-time ``pre_nms_top_n``; None
              selects 1000 with FPN and 6000 without.
          test_post_nms_top_n (int): test-time ``post_nms_top_n``.

      Raises:
          ValueError: if ``backbone`` is not a supported name.
      """
      # Must remain the first statement so that locals() only holds the
      # call arguments.  The list defaults above are recorded here and are
      # never mutated in this method.
      self.init_params = locals()
      if backbone not in [
              'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
              'ResNet101_vd'
      ]:
          raise ValueError(
              "backbone: {} is not supported. Please choose one of "
              "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
              format(backbone))
      self.backbone_name = backbone + '_fpn' if with_fpn else backbone
      dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]

      # From here on `backbone` is rebound from the name string to the
      # backbone object returned by _get_backbone.
      if backbone == 'ResNet50':
          if with_fpn:
              backbone = self._get_backbone(
                  'ResNet',
                  norm_type='bn',
                  freeze_at=0,
                  return_idx=[0, 1, 2, 3],
                  num_stages=4,
                  dcn_v2_stages=dcn_v2_stages)
          else:
              if with_dcn:
                  logging.warning(
                      "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
                      format(backbone))
              # No dcn_v2_stages is passed on this path.
              backbone = self._get_backbone(
                  'ResNet',
                  norm_type='bn',
                  freeze_at=0,
                  return_idx=[2],
                  num_stages=3)
      elif 'ResNet50_vd' in backbone:
          if not with_fpn:
              logging.warning(
                  "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                  format(backbone))
              with_fpn = True
          backbone = self._get_backbone(
              'ResNet',
              variant='d',
              norm_type='bn',
              freeze_at=0,
              return_idx=[0, 1, 2, 3],
              num_stages=4,
              lr_mult_list=[0.05, 0.05, 0.1, 0.15]
              if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
              dcn_v2_stages=dcn_v2_stages)
      else:
          # Remaining supported names: 'ResNet101' and 'ResNet101_vd'.
          if not with_fpn:
              logging.warning(
                  "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                  format(backbone))
              with_fpn = True
          backbone = self._get_backbone(
              'ResNet',
              variant='d' if '_vd' in backbone else 'b',
              depth=101,
              norm_type='bn',
              freeze_at=0,
              return_idx=[0, 1, 2, 3],
              num_stages=4,
              dcn_v2_stages=dcn_v2_stages)

      # with_fpn may have been forced to True above after backbone_name was
      # recorded; add the suffix so the name reflects the network that is
      # actually built.
      if with_fpn and not self.backbone_name.endswith('_fpn'):
          self.backbone_name = self.backbone_name + '_fpn'

      rpn_in_channel = backbone.out_shape[0].channels
      if with_fpn:
          neck = ppdet.modeling.FPN(
              in_channels=[i.channels for i in backbone.out_shape],
              out_channel=fpn_num_channels,
              spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
          rpn_in_channel = neck.out_shape[0].channels
          anchor_generator_cfg = {
              'aspect_ratios': aspect_ratios,
              'anchor_sizes': anchor_sizes,
              'strides': [4, 8, 16, 32, 64]
          }
          train_proposal_cfg = {
              'min_size': 0.0,
              'nms_thresh': .7,
              'pre_nms_top_n': 2000,
              'post_nms_top_n': 1000,
              'topk_after_collect': True
          }
          test_proposal_cfg = {
              'min_size': 0.0,
              'nms_thresh': .7,
              'pre_nms_top_n': 1000
              if test_pre_nms_top_n is None else test_pre_nms_top_n,
              'post_nms_top_n': test_post_nms_top_n
          }
          bb_head = ppdet.modeling.TwoFCHead(
              in_channel=neck.out_shape[0].channels, out_channel=1024)
          bb_roi_extractor_cfg = {
              'resolution': 7,
              'spatial_scale': [1. / i.stride for i in neck.out_shape],
              'sampling_ratio': 0,
              'aligned': True
          }
          with_pool = False
          m_head = ppdet.modeling.MaskFeat(
              in_channel=neck.out_shape[0].channels,
              out_channel=256,
              num_convs=4)
          m_roi_extractor_cfg = {
              'resolution': 14,
              'spatial_scale': [1. / i.stride for i in neck.out_shape],
              'sampling_ratio': 0,
              'aligned': True
          }
          mask_assigner = MaskAssigner(
              num_classes=num_classes, mask_resolution=28)
          share_bbox_feat = False
      else:
          neck = None
          anchor_generator_cfg = {
              'aspect_ratios': aspect_ratios,
              'anchor_sizes': anchor_sizes,
              'strides': [16]
          }
          train_proposal_cfg = {
              'min_size': 0.0,
              'nms_thresh': .7,
              'pre_nms_top_n': 12000,
              'post_nms_top_n': 2000,
              'topk_after_collect': False
          }
          test_proposal_cfg = {
              'min_size': 0.0,
              'nms_thresh': .7,
              'pre_nms_top_n': 6000
              if test_pre_nms_top_n is None else test_pre_nms_top_n,
              'post_nms_top_n': test_post_nms_top_n
          }
          bb_head = ppdet.modeling.Res5Head()
          bb_roi_extractor_cfg = {
              'resolution': 14,
              'spatial_scale': [1. / i.stride for i in backbone.out_shape],
              'sampling_ratio': 0,
              'aligned': True
          }
          with_pool = True
          m_head = ppdet.modeling.MaskFeat(
              in_channel=bb_head.out_shape[0].channels,
              out_channel=256,
              num_convs=0)
          m_roi_extractor_cfg = {
              'resolution': 14,
              'spatial_scale': [1. / i.stride for i in backbone.out_shape],
              'sampling_ratio': 0,
              'aligned': True
          }
          mask_assigner = MaskAssigner(
              num_classes=num_classes, mask_resolution=14)
          share_bbox_feat = True

      rpn_target_assign_cfg = {
          'batch_size_per_im': rpn_batch_size_per_im,
          'fg_fraction': rpn_fg_fraction,
          'negative_overlap': .3,
          'positive_overlap': .7,
          'use_random': True
      }
      rpn_head = ppdet.modeling.RPNHead(
          anchor_generator=anchor_generator_cfg,
          rpn_target_assign=rpn_target_assign_cfg,
          train_proposal=train_proposal_cfg,
          test_proposal=test_proposal_cfg,
          in_channel=rpn_in_channel)
      bbox_assigner = BBoxAssigner(num_classes=num_classes)
      bbox_head = ppdet.modeling.BBoxHead(
          head=bb_head,
          in_channel=bb_head.out_shape[0].channels,
          roi_extractor=bb_roi_extractor_cfg,
          with_pool=with_pool,
          bbox_assigner=bbox_assigner,
          num_classes=num_classes)
      mask_head = ppdet.modeling.MaskHead(
          head=m_head,
          roi_extractor=m_roi_extractor_cfg,
          mask_assigner=mask_assigner,
          share_bbox_feat=share_bbox_feat,
          num_classes=num_classes)
      bbox_post_process = ppdet.modeling.BBoxPostProcess(
          num_classes=num_classes,
          decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
          nms=ppdet.modeling.MultiClassNMS(
              score_threshold=score_threshold,
              keep_top_k=keep_top_k,
              nms_threshold=nms_threshold))
      mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
      params = {
          'backbone': backbone,
          'neck': neck,
          'rpn_head': rpn_head,
          'bbox_head': bbox_head,
          'mask_head': mask_head,
          'bbox_post_process': bbox_post_process,
          'mask_post_process': mask_post_process
      }
      self.with_fpn = with_fpn
      # Explicit two-argument super(): the zero-argument form would add a
      # __class__ cell to this function's locals().
      super(MaskRCNN, self).__init__(
          model_name='MaskRCNN', num_classes=num_classes, **params)
  def _compose_batch_transform(self, transforms, mode='train'):
      """Assemble the batch-level transform pipeline for the given mode.

      A ``_BatchPadding`` op is always placed last; its ``pad_to_stride`` is
      32 when ``self.with_fpn`` is truthy and -1 otherwise. Every
      ``BatchRandomResize`` / ``BatchRandomResizeByShort`` found in
      ``transforms.transforms`` is deep-copied and pushed to the front of
      the pipeline, so these ops end up in reverse order of appearance.

      Args:
          transforms: Object exposing a ``transforms`` sequence of ops.
          mode: ``'train'`` disables batch collation; any other value
              enables it.

      Returns:
          The ``BatchCompose`` built from the collected ops.

      Raises:
          Exception: If a batch random-resize op is present while ``mode``
              is not ``'train'``.
      """
      training = mode == 'train'

      # Both modes use the same padding op; only collation differs.
      stride = 32 if self.with_fpn else -1
      default_batch_transforms = [_BatchPadding(pad_to_stride=stride)]
      collate_batch = not training

      batch_resize_ops = []
      for op in transforms.transforms:
          if not isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
              continue
          if not training:
              raise Exception(
                  "{} cannot be present in the {} transforms. ".format(
                      op.__class__.__name__, mode) +
                  "Please check the {} transforms.".format(mode))
          # Copy so the batch pipeline does not share state with the
          # op instance held by ``transforms``.
          batch_resize_ops.insert(0, copy.deepcopy(op))

      return BatchCompose(
          batch_resize_ops + default_batch_transforms,
          collate_batch=collate_batch)
  def _fix_transforms_shape(self, image_shape):
      """Rewrite the test transforms so they resize to ``image_shape`` and pad.

      Does nothing when ``self.test_transforms`` is missing or None.
      Otherwise a ``Resize(target_size=image_shape, keep_ratio=True,
      interp='CUBIC')`` op replaces the last ``ResizeByShort`` op. If there
      is no ``ResizeByShort``, the new op is inserted in front of the last
      ``Normalize`` op, or at the end when there is none. A ``Padding`` op
      with padding value ``[0., 0., 0.]`` is then appended in every case.

      Args:
          image_shape: Target size handed unchanged to ``Resize``.
      """
      test_transforms = getattr(self, 'test_transforms', None)
      if test_transforms is None:
          return

      ops = test_transforms.transforms
      resize_pos = -1  # -1 means no ResizeByShort was found
      normalize_pos = len(ops)
      for pos, op in enumerate(ops):
          op_name = op.__class__.__name__
          if op_name == 'ResizeByShort':
              resize_pos = pos
          if op_name == 'Normalize':
              normalize_pos = pos

      fixed_resize = Resize(
          target_size=image_shape, keep_ratio=True, interp='CUBIC')
      if resize_pos < 0:
          ops.insert(normalize_pos, fixed_resize)
      else:
          ops[resize_pos] = fixed_resize

      ops.append(Padding(im_padding_value=[0., 0., 0.]))
  1597. def _get_test_inputs(self, image_shape):
  1598. if image_shape is not None:
  1599. image_shape = self._check_image_shape(image_shape)
  1600. self._fix_transforms_shape(image_shape[-2:])
  1601. else:
  1602. image_shape = [None, 3, -1, -1]
  1603. if self.with_fpn:
  1604. self.test_transforms.transforms.append(
  1605. Padding(im_padding_value=[0., 0., 0.]))
  1606. return self._define_input_spec(image_shape)