detector.py 77 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import paddlex.ppdet as ppdet
  24. from paddlex.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  def __init__(self, model_name, num_classes=80, **params):
      """
      Set up the common detector state and, unless disabled, build the network.

      Args:
          model_name(str): Name of an attribute of `ppdet.modeling` to build.
          num_classes(int, optional): Number of classes. Defaults to 80.
          **params: Extra keyword arguments forwarded to `build_net`. The special
              key `with_net` (default True) controls whether the network is built.

      Raises:
          Exception: If `ppdet.modeling` has no attribute called `model_name`.
      """
      # Record the constructor arguments first: `locals()` must be captured
      # before any additional local name is bound in this frame.
      self.init_params.update(locals())
      self.init_params.pop('with_net', None)
      super(BaseDetector, self).__init__('detector')

      if not hasattr(ppdet.modeling, model_name):
          raise Exception(
              "ERROR: There's no model named {}.".format(model_name))

      self.model_name = model_name
      self.num_classes = num_classes
      self.labels = None

      build_requested = params.get('with_net', True)
      if build_requested:
          # `with_net` is a control flag, not a network argument.
          params.pop('with_net', None)
          self.net = self.build_net(**params)
  def build_net(self, **params):
      """
      Instantiate the architecture named by `self.model_name`.

      The class is looked up in `ppdet.modeling.__dict__` and constructed with
      `params` inside a `paddle.utils.unique_name.guard()` context.

      Returns:
          The constructed network object.
      """
      with paddle.utils.unique_name.guard():
          architecture = ppdet.modeling.__dict__[self.model_name]
          return architecture(**params)
  56. def _fix_transforms_shape(self, image_shape):
  57. raise NotImplementedError("_fix_transforms_shape: not implemented!")
  def _define_input_spec(self, image_shape):
      """
      Describe the network inputs as `InputSpec` objects.

      Args:
          image_shape(list): Shape used for the `image` input; its first element
              is reused as the leading dimension of `im_shape` and `scale_factor`.

      Returns:
          list: A one-element list holding a dict with the keys `image`,
              `im_shape` and `scale_factor`, all of dtype float32.
      """
      leading_dim = image_shape[0]
      image_spec = InputSpec(shape=image_shape, name='image', dtype='float32')
      im_shape_spec = InputSpec(
          shape=[leading_dim, 2], name='im_shape', dtype='float32')
      scale_factor_spec = InputSpec(
          shape=[leading_dim, 2], name='scale_factor', dtype='float32')
      return [{
          "image": image_spec,
          "im_shape": im_shape_spec,
          "scale_factor": scale_factor_spec,
      }]
  70. def _check_image_shape(self, image_shape):
  71. if len(image_shape) == 2:
  72. image_shape = [1, 3] + image_shape
  73. if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
  74. raise Exception(
  75. "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
  76. format(image_shape[-2:]))
  77. return image_shape
  78. def _get_test_inputs(self, image_shape):
  79. if image_shape is not None:
  80. image_shape = self._check_image_shape(image_shape)
  81. self._fix_transforms_shape(image_shape[-2:])
  82. else:
  83. image_shape = [None, 3, -1, -1]
  84. self.fixed_input_shape = image_shape
  85. return self._define_input_spec(image_shape)
  def _get_backbone(self, backbone_name, **params):
      """
      Construct the `ppdet.modeling` attribute called `backbone_name`.

      Args:
          backbone_name(str): Attribute name looked up on `ppdet.modeling`.
          **params: Keyword arguments passed to the constructor.

      Returns:
          The constructed backbone object.
      """
      backbone_cls = getattr(ppdet.modeling, backbone_name)
      return backbone_cls(**params)
  def run(self, net, inputs, mode):
      """
      Run `net` on `inputs` and shape the outputs according to `mode`.

      Args:
          net: Callable network.
          inputs: Value passed to `net` unchanged.
          mode(str): For 'train' and 'eval' the raw network output is returned;
              for any other mode each output value is converted with `.numpy()`.

      Returns:
          The raw network output, or a dict mapping each output key to the
          result of calling `.numpy()` on its value.
      """
      net_out = net(inputs)
      if mode in ('train', 'eval'):
          return net_out
      return {name: net_out[name].numpy() for name in net_out}
  def default_optimizer(self, parameters, learning_rate, warmup_steps,
                        warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
                        num_steps_each_epoch):
      """
      Build the default Momentum optimizer with a piecewise-decay schedule.

      The learning rate is multiplied by `lr_decay_gamma` at every epoch listed
      in `lr_decay_epochs` (converted to steps). When `warmup_steps > 0` the
      schedule is wrapped in a linear warm-up from `warmup_start_lr` to
      `learning_rate`.

      Args:
          parameters: Parameters handed to the optimizer.
          learning_rate(float): Base learning rate.
          warmup_steps(int): Number of warm-up steps; 0 disables warm-up.
          warmup_start_lr(float): Learning rate at the start of warm-up.
          lr_decay_epochs(list or tuple): Epoch milestones for decay.
          lr_decay_gamma(float): Decay factor applied at each milestone.
          num_steps_each_epoch(int): Steps per epoch, used to convert epochs
              into step boundaries.

      Returns:
          paddle.optimizer.Momentum: Optimizer with momentum 0.9 and L2 decay
              coefficient 1e-04.
      """
      boundaries = []
      for milestone in lr_decay_epochs:
          boundaries.append(milestone * num_steps_each_epoch)

      values = []
      for stage in range(len(lr_decay_epochs) + 1):
          values.append((lr_decay_gamma**stage) * learning_rate)

      scheduler = paddle.optimizer.lr.PiecewiseDecay(
          boundaries=boundaries, values=values)

      if warmup_steps > 0:
          first_decay_step = lr_decay_epochs[0] * num_steps_each_epoch
          if warmup_steps > first_decay_step:
              # Reported but not fatal: both messages are logged with
              # exit=False and the warm-up schedule is still built below.
              logging.error(
                  "In function train(), parameters should satisfy: "
                  "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
                  exit=False)
              logging.error(
                  "See this doc for more information: "
                  "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                  exit=False)
          scheduler = paddle.optimizer.lr.LinearWarmup(
              learning_rate=scheduler,
              warmup_steps=warmup_steps,
              start_lr=warmup_start_lr,
              end_lr=learning_rate)

      return paddle.optimizer.Momentum(
          scheduler,
          momentum=.9,
          weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
          parameters=parameters)
  def train(self,
            num_epochs,
            train_dataset,
            train_batch_size=64,
            eval_dataset=None,
            optimizer=None,
            save_interval_epochs=1,
            log_interval_steps=10,
            save_dir='output',
            pretrain_weights='IMAGENET',
            learning_rate=.001,
            warmup_steps=0,
            warmup_start_lr=0.0,
            lr_decay_epochs=(216, 243),
            lr_decay_gamma=0.1,
            metric=None,
            use_ema=False,
            early_stop=False,
            early_stop_patience=5,
            use_vdl=True,
            resume_checkpoint=None):
      """
      Train the model.

      Args:
          num_epochs(int): The number of epochs.
          train_dataset(paddlex.dataset): Training dataset.
          train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
          eval_dataset(paddlex.dataset, optional):
              Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
          optimizer(paddle.optimizer.Optimizer or None, optional):
              Optimizer used for training. If None, a default optimizer is used. Defaults to None.
          save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
          log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
          save_dir(str, optional): Directory to save the model. Defaults to 'output'.
          pretrain_weights(str or None, optional):
              None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
          learning_rate(float, optional): Learning rate for training. Defaults to .001.
          warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
          warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
          lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
          lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
          metric({'VOC', 'COCO', None}, optional):
              Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
          use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
          early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
          early_stop_patience(int, optional): Early stop patience. Defaults to 5.
          use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
          resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
              If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
              `pretrain_weights` can be set simultaneously, so `pretrain_weights=None` must be passed
              explicitly when resuming. Defaults to None.

      Raises:
          AssertionError: If `metric` is given but is neither 'COCO' nor 'VOC' (case-insensitive).
      """
      # RCNN-style models are restricted to batch size 1 when the dataset
      # holds fewer positive samples than files (i.e. has negative samples).
      if train_dataset.pos_num < len(
              train_dataset.file_list
      ) and train_batch_size != 1 and 'RCNN' in self.__class__.__name__:
          train_batch_size = 1
          logging.warning(
              "Training RCNN models with negative samples only support batch size equals to 1, "
              "`train_batch_size` is forcibly set to 1.")

      if self.status == 'Infer':
          logging.error(
              "Exported inference model does not support training.",
              exit=True)
      if pretrain_weights is not None and resume_checkpoint is not None:
          logging.error(
              "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
              exit=True)

      # Select the sample fields according to the dataset class.
      if train_dataset.__class__.__name__ == 'VOCDetection':
          train_dataset.data_fields = {
              'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
              'difficult'
          }
      elif train_dataset.__class__.__name__ == 'CocoDetection':
          if self.__class__.__name__ == 'MaskRCNN':
              train_dataset.data_fields = {
                  'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                  'gt_poly', 'is_crowd'
              }
          else:
              train_dataset.data_fields = {
                  'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                  'is_crowd'
              }

      if metric is None:
          # Infer the metric from the evaluation dataset class; nothing is set
          # when the class is neither of the two known ones (e.g. None).
          if eval_dataset.__class__.__name__ == 'VOCDetection':
              self.metric = 'voc'
          elif eval_dataset.__class__.__name__ == 'CocoDetection':
              self.metric = 'coco'
      else:
          assert metric.lower() in ['coco', 'voc'], \
              "Evaluation metric {} is not supported, please choose from 'COCO' and 'VOC'".format(
                  metric)
          self.metric = metric.lower()

      self.labels = train_dataset.labels
      self.num_max_boxes = train_dataset.num_max_boxes
      train_dataset.batch_transforms = self._compose_batch_transform(
          train_dataset.transforms, mode='train')

      # build optimizer if not defined
      if optimizer is None:
          num_steps_each_epoch = len(train_dataset) // train_batch_size
          self.optimizer = self.default_optimizer(
              parameters=self.net.parameters(),
              learning_rate=learning_rate,
              warmup_steps=warmup_steps,
              warmup_start_lr=warmup_start_lr,
              lr_decay_epochs=lr_decay_epochs,
              lr_decay_gamma=lr_decay_gamma,
              num_steps_each_epoch=num_steps_each_epoch)
      else:
          self.optimizer = optimizer

      # initiate weights
      if pretrain_weights is not None and not osp.exists(pretrain_weights):
          # Not a path on disk: treat it as a named weight set and fall back
          # to the first registered one when the name is unknown.
          if pretrain_weights not in det_pretrain_weights_dict['_'.join(
                  [self.model_name, self.backbone_name])]:
              logging.warning(
                  "Path of pretrain_weights('{}') does not exist!".format(
                      pretrain_weights))
              pretrain_weights = det_pretrain_weights_dict['_'.join(
                  [self.model_name, self.backbone_name])][0]
              logging.warning("Pretrain_weights is forcibly set to '{}'. "
                              "If you don't want to use pretrain weights, "
                              "set pretrain_weights to be None.".format(
                                  pretrain_weights))
      elif pretrain_weights is not None and osp.exists(pretrain_weights):
          if osp.splitext(pretrain_weights)[-1] != '.pdparams':
              logging.error(
                  "Invalid pretrain weights. Please specify a '.pdparams' file.",
                  exit=True)
      pretrained_dir = osp.join(save_dir, 'pretrain')
      self.net_initialize(
          pretrain_weights=pretrain_weights,
          save_dir=pretrained_dir,
          resume_checkpoint=resume_checkpoint)

      if use_ema:
          ema = ExponentialMovingAverage(
              decay=.9998, model=self.net, use_thres_step=True)
      else:
          ema = None

      # start train loop
      self.train_loop(
          num_epochs=num_epochs,
          train_dataset=train_dataset,
          train_batch_size=train_batch_size,
          eval_dataset=eval_dataset,
          save_interval_epochs=save_interval_epochs,
          log_interval_steps=log_interval_steps,
          save_dir=save_dir,
          ema=ema,
          early_stop=early_stop,
          early_stop_patience=early_stop_patience,
          use_vdl=use_vdl)
  def quant_aware_train(self,
                        num_epochs,
                        train_dataset,
                        train_batch_size=64,
                        eval_dataset=None,
                        optimizer=None,
                        save_interval_epochs=1,
                        log_interval_steps=10,
                        save_dir='output',
                        learning_rate=.00001,
                        warmup_steps=0,
                        warmup_start_lr=0.0,
                        lr_decay_epochs=(216, 243),
                        lr_decay_gamma=0.1,
                        metric=None,
                        use_ema=False,
                        early_stop=False,
                        early_stop_patience=5,
                        use_vdl=True,
                        resume_checkpoint=None,
                        quant_config=None):
      """
      Quantization-aware training.

      Prepares the model with `_prepare_qat(quant_config)` and then delegates to
      `train` with `pretrain_weights=None`.

      Args:
          num_epochs(int): The number of epochs.
          train_dataset(paddlex.dataset): Training dataset.
          train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
          eval_dataset(paddlex.dataset, optional):
              Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
          optimizer(paddle.optimizer.Optimizer or None, optional):
              Optimizer used for training. If None, a default optimizer is used. Defaults to None.
          save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
          log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
          save_dir(str, optional): Directory to save the model. Defaults to 'output'.
          learning_rate(float, optional): Learning rate for training. Defaults to .00001.
          warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
          warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
          lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
          lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
          metric({'VOC', 'COCO', None}, optional):
              Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
          use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
          early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
          early_stop_patience(int, optional): Early stop patience. Defaults to 5.
          use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
          resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
              from. If None, no training checkpoint will be resumed. Defaults to None.
          quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
              configuration will be used. Defaults to None.
      """
      self._prepare_qat(quant_config)

      train_kwargs = {
          'num_epochs': num_epochs,
          'train_dataset': train_dataset,
          'train_batch_size': train_batch_size,
          'eval_dataset': eval_dataset,
          'optimizer': optimizer,
          'save_interval_epochs': save_interval_epochs,
          'log_interval_steps': log_interval_steps,
          'save_dir': save_dir,
          # Pretrained weights are never loaded for quantization-aware training.
          'pretrain_weights': None,
          'learning_rate': learning_rate,
          'warmup_steps': warmup_steps,
          'warmup_start_lr': warmup_start_lr,
          'lr_decay_epochs': lr_decay_epochs,
          'lr_decay_gamma': lr_decay_gamma,
          'metric': metric,
          'use_ema': use_ema,
          'early_stop': early_stop,
          'early_stop_patience': early_stop_patience,
          'use_vdl': use_vdl,
          'resume_checkpoint': resume_checkpoint,
      }
      self.train(**train_kwargs)
  def evaluate(self,
               eval_dataset,
               batch_size=1,
               metric=None,
               return_details=False):
      """
      Evaluate the model.

      Args:
          eval_dataset(paddlex.dataset): Evaluation dataset.
          batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
              Values greater than 1 are reset to 1 with a warning.
          metric({'VOC', 'COCO', None}, optional):
              Evaluation metric. If None, a metric already stored on the model is reused; otherwise it is
              determined according to the dataset format. Defaults to None.
          return_details(bool, optional): Whether to return evaluation details. Defaults to False.

      Returns:
          collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
          When `return_details` is True, a tuple `(scores, eval_details)` is returned instead.
          Evaluation only runs when there is a single process or on rank 0; other ranks return None.

      Raises:
          AssertionError: If `metric` is given but is neither 'COCO' nor 'VOC' (case-insensitive).
      """
      if metric is None:
          if not hasattr(self, 'metric'):
              if eval_dataset.__class__.__name__ == 'VOCDetection':
                  self.metric = 'voc'
              elif eval_dataset.__class__.__name__ == 'CocoDetection':
                  self.metric = 'coco'
      else:
          assert metric.lower() in ['coco', 'voc'], \
              "Evaluation metric {} is not supported, please choose from 'COCO' and 'VOC'".format(
                  metric)
          self.metric = metric.lower()

      # Select the sample fields required by the chosen metric.
      if self.metric == 'voc':
          eval_dataset.data_fields = {
              'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
              'difficult'
          }
      elif self.metric == 'coco':
          if self.__class__.__name__ == 'MaskRCNN':
              eval_dataset.data_fields = {
                  'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                  'gt_poly', 'is_crowd'
              }
          else:
              eval_dataset.data_fields = {
                  'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                  'is_crowd'
              }

      eval_dataset.batch_transforms = self._compose_batch_transform(
          eval_dataset.transforms, mode='eval')
      arrange_transforms(
          model_type=self.model_type,
          transforms=eval_dataset.transforms,
          mode='eval')

      self.net.eval()
      nranks = paddle.distributed.get_world_size()
      local_rank = paddle.distributed.get_rank()
      if nranks > 1:
          # Initialize parallel environment if not done.
          if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
          ):
              paddle.distributed.init_parallel_env()

      if batch_size > 1:
          logging.warning(
              "Detector only supports single card evaluation with batch_size=1 "
              "during evaluation, so batch_size is forcibly set to 1.")
          batch_size = 1

      if nranks < 2 or local_rank == 0:
          self.eval_data_loader = self.build_data_loader(
              eval_dataset, batch_size=batch_size, mode='eval')

          is_bbox_normalized = False
          if eval_dataset.batch_transforms is not None:
              is_bbox_normalized = any(
                  isinstance(t, _NormalizeBox)
                  for t in eval_dataset.batch_transforms.batch_transforms)

          if self.metric == 'voc':
              eval_metric = VOCMetric(
                  labels=eval_dataset.labels,
                  coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                  is_bbox_normalized=is_bbox_normalized,
                  classwise=False)
          else:
              eval_metric = COCOMetric(
                  coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                  classwise=False)

          scores = collections.OrderedDict()
          # batch_size is 1 at this point, so steps equal samples.
          logging.info(
              "Start to evaluate(total_samples={}, total_steps={})...".
              format(eval_dataset.num_samples, eval_dataset.num_samples))
          with paddle.no_grad():
              for data in self.eval_data_loader:
                  outputs = self.run(self.net, data, 'eval')
                  eval_metric.update(data, outputs)
              eval_metric.accumulate()
              self.eval_details = eval_metric.details
              scores.update(eval_metric.get())
              eval_metric.reset()

          if return_details:
              return scores, self.eval_details
          return scores
  def predict(self, img_file, transforms=None):
      """
      Do inference.

      Args:
          img_file(List[np.ndarray or str], str or np.ndarray):
              Image path or decoded image data in a BGR format, which also could constitute a list,
              meaning all images to be predicted as a mini-batch.
          transforms(paddlex.transforms.Compose or None, optional):
              Transforms for inputs. If None, `self.test_transforms` is used. Defaults to None.

      Returns:
          If img_file is a string or np.array, the first element of the post-processed predictions,
          i.e. the list of result dicts for that single image.
          If img_file is a list, the full list of post-processed predictions, one entry per image.
          Each result dict holds the fields produced by `_postprocess`:
              category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
              category(str): category name
              bbox(list): bounding box in [x, y, w, h] format
              score: confidence
              mask: Only for instance segmentation task.

      Raises:
          Exception: If `transforms` is None and the model has no `test_transforms` attribute.
      """
      if transforms is None:
          if not hasattr(self, 'test_transforms'):
              raise Exception("transforms need to be defined, now is None.")
          transforms = self.test_transforms

      is_single = isinstance(img_file, (str, np.ndarray))
      images = [img_file] if is_single else img_file

      batch_samples = self._preprocess(images, transforms)
      self.net.eval()
      outputs = self.run(self.net, batch_samples, 'test')
      prediction = self._postprocess(outputs)

      # A single input is unwrapped from its one-element batch.
      return prediction[0] if is_single else prediction
  def _preprocess(self, images, transforms, to_tensor=True):
      """
      Apply per-sample and batch transforms to a list of images.

      Args:
          images(list): Images (paths or decoded arrays) forming one mini-batch.
          transforms: Callable transform pipeline applied to each `{'image': im}` sample.
          to_tensor(bool, optional): Whether to convert every entry of the batch
              with `paddle.to_tensor`. Defaults to True.

      Returns:
          The batch produced by the batch transforms, with its values converted
          to tensors when `to_tensor` is True.
      """
      arrange_transforms(
          model_type=self.model_type, transforms=transforms, mode='test')

      samples = [transforms({'image': image}) for image in images]

      batch_transforms = self._compose_batch_transform(transforms, 'test')
      batch = batch_transforms(samples)

      if to_tensor:
          for field in batch:
              batch[field] = paddle.to_tensor(batch[field])
      return batch
  489. def _postprocess(self, batch_pred):
  490. infer_result = {}
  491. if 'bbox' in batch_pred:
  492. bboxes = batch_pred['bbox']
  493. bbox_nums = batch_pred['bbox_num']
  494. det_res = []
  495. k = 0
  496. for i in range(len(bbox_nums)):
  497. det_nums = bbox_nums[i]
  498. for j in range(det_nums):
  499. dt = bboxes[k]
  500. k = k + 1
  501. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  502. if int(num_id) < 0:
  503. continue
  504. category = self.labels[int(num_id)]
  505. w = xmax - xmin
  506. h = ymax - ymin
  507. bbox = [xmin, ymin, w, h]
  508. dt_res = {
  509. 'category_id': int(num_id),
  510. 'category': category,
  511. 'bbox': bbox,
  512. 'score': score
  513. }
  514. det_res.append(dt_res)
  515. infer_result['bbox'] = det_res
  516. if 'mask' in batch_pred:
  517. masks = batch_pred['mask']
  518. bboxes = batch_pred['bbox']
  519. mask_nums = batch_pred['bbox_num']
  520. seg_res = []
  521. k = 0
  522. for i in range(len(mask_nums)):
  523. det_nums = mask_nums[i]
  524. for j in range(det_nums):
  525. mask = masks[k].astype(np.uint8)
  526. score = float(bboxes[k][1])
  527. label = int(bboxes[k][0])
  528. k = k + 1
  529. if label == -1:
  530. continue
  531. category = self.labels[int(label)]
  532. sg_res = {
  533. 'category_id': int(label),
  534. 'category': category,
  535. 'mask': mask.astype('uint8'),
  536. 'score': score
  537. }
  538. seg_res.append(sg_res)
  539. infer_result['mask'] = seg_res
  540. bbox_num = batch_pred['bbox_num']
  541. results = []
  542. start = 0
  543. for num in bbox_num:
  544. end = start + num
  545. curr_res = infer_result['bbox'][start:end]
  546. if 'mask' in infer_result:
  547. mask_res = infer_result['mask'][start:end]
  548. for box, mask in zip(curr_res, mask_res):
  549. box.update(mask)
  550. results.append(curr_res)
  551. start = end
  552. return results
  553. class YOLOv3(BaseDetector):
  def __init__(self,
               num_classes=80,
               backbone='MobileNetV1',
               anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                        [59, 119], [116, 90], [156, 198], [373, 326]],
               anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
               ignore_threshold=0.7,
               nms_score_threshold=0.01,
               nms_topk=1000,
               nms_keep_topk=100,
               nms_iou_threshold=0.45,
               label_smooth=False,
               **params):
      """
      Configure a YOLOv3 detector and, unless `with_net=False`, assemble its parts.

      Args:
          num_classes(int, optional): Number of classes. Defaults to 80.
          backbone(str, optional): One of 'MobileNetV1', 'MobileNetV1_ssld',
              'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn',
              'ResNet34'. Defaults to 'MobileNetV1'.
          anchors(list, optional): Anchor sizes passed to the YOLO head.
          anchor_masks(list, optional): Anchor index groups passed to the YOLO head.
          ignore_threshold(float, optional): Passed to the loss as `ignore_thresh`. Defaults to 0.7.
          nms_score_threshold(float, optional): NMS `score_threshold`. Defaults to 0.01.
          nms_topk(int, optional): NMS `nms_top_k`. Defaults to 1000.
          nms_keep_topk(int, optional): NMS `keep_top_k`. Defaults to 100.
          nms_iou_threshold(float, optional): NMS `nms_threshold`. Defaults to 0.45.
          label_smooth(bool, optional): Passed to the loss. Defaults to False.
          **params: Forwarded to `BaseDetector.__init__`; `with_net` (default True)
              controls whether the network components are built here.

      Raises:
          ValueError: If `backbone` is not one of the supported names.
      """
      # Must remain the first statement: it snapshots exactly the constructor
      # arguments, before any other local name exists in this frame.
      self.init_params = locals()

      if backbone not in ('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
                          'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn',
                          'ResNet34'):
          raise ValueError(
              "backbone: {} is not supported. Please choose one of "
              "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
              format(backbone))
      self.backbone_name = backbone

      if params.get('with_net', True):
          # Synchronized batch norm only on multi-GPU runs outside export.
          use_sync_bn = (paddlex.env_info['place'] == 'gpu' and
                         paddlex.env_info['num'] > 1 and
                         not os.environ.get('PADDLEX_EXPORT_STAGE'))
          norm_type = 'sync_bn' if use_sync_bn else 'bn'

          # Map the public backbone name to a ppdet class and its extra args.
          if 'MobileNetV1' in backbone:
              # MobileNetV1 variants always use plain batch norm; this also
              # applies to the neck built below.
              norm_type = 'bn'
              arch_name, arch_kwargs = 'MobileNet', {}
          elif 'MobileNetV3' in backbone:
              arch_name = 'MobileNetV3'
              arch_kwargs = {'feature_maps': [7, 13, 16]}
          elif backbone == 'ResNet50_vd_dcn':
              arch_name = 'ResNet'
              arch_kwargs = {
                  'variant': 'd',
                  'return_idx': [1, 2, 3],
                  'dcn_v2_stages': [3],
                  'freeze_at': -1,
                  'freeze_norm': False,
              }
          elif backbone == 'ResNet34':
              arch_name = 'ResNet'
              arch_kwargs = {
                  'depth': 34,
                  'return_idx': [1, 2, 3],
                  'freeze_at': -1,
                  'freeze_norm': False,
                  'norm_decay': 0.,
              }
          else:
              arch_name, arch_kwargs = 'DarkNet', {}
          backbone = self._get_backbone(
              arch_name, norm_type=norm_type, **arch_kwargs)

          neck = ppdet.modeling.YOLOv3FPN(
              norm_type=norm_type,
              in_channels=[shape.channels for shape in backbone.out_shape])
          loss = ppdet.modeling.YOLOv3Loss(
              num_classes=num_classes,
              ignore_thresh=ignore_threshold,
              label_smooth=label_smooth)
          yolo_head = ppdet.modeling.YOLOv3Head(
              in_channels=[shape.channels for shape in neck.out_shape],
              anchors=anchors,
              anchor_masks=anchor_masks,
              num_classes=num_classes,
              loss=loss)
          post_process = ppdet.modeling.BBoxPostProcess(
              decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
              nms=ppdet.modeling.MultiClassNMS(
                  score_threshold=nms_score_threshold,
                  nms_top_k=nms_topk,
                  keep_top_k=nms_keep_topk,
                  nms_threshold=nms_iou_threshold))
          params.update(
              backbone=backbone,
              neck=neck,
              yolo_head=yolo_head,
              post_process=post_process)

      super(YOLOv3, self).__init__(
          model_name='YOLOv3', num_classes=num_classes, **params)
      self.anchors = anchors
      self.anchor_masks = anchor_masks
  641. def _compose_batch_transform(self, transforms, mode='train'):
  642. if mode == 'train':
  643. default_batch_transforms = [
  644. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  645. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  646. _Gt2YoloTarget(
  647. anchor_masks=self.anchor_masks,
  648. anchors=self.anchors,
  649. downsample_ratios=getattr(self, 'downsample_ratios',
  650. [32, 16, 8]),
  651. num_classes=self.num_classes)
  652. ]
  653. else:
  654. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  655. if mode == 'eval' and self.metric == 'voc':
  656. collate_batch = False
  657. else:
  658. collate_batch = True
  659. custom_batch_transforms = []
  660. for i, op in enumerate(transforms.transforms):
  661. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  662. if mode != 'train':
  663. raise Exception(
  664. "{} cannot be present in the {} transforms. ".format(
  665. op.__class__.__name__, mode) +
  666. "Please check the {} transforms.".format(mode))
  667. custom_batch_transforms.insert(0, copy.deepcopy(op))
  668. batch_transforms = BatchCompose(
  669. custom_batch_transforms + default_batch_transforms,
  670. collate_batch=collate_batch)
  671. return batch_transforms
  672. def _fix_transforms_shape(self, image_shape):
  673. if hasattr(self, 'test_transforms'):
  674. if self.test_transforms is not None:
  675. has_resize_op = False
  676. resize_op_idx = -1
  677. normalize_op_idx = len(self.test_transforms.transforms)
  678. for idx, op in enumerate(self.test_transforms.transforms):
  679. name = op.__class__.__name__
  680. if name == 'Resize':
  681. has_resize_op = True
  682. resize_op_idx = idx
  683. if name == 'Normalize':
  684. normalize_op_idx = idx
  685. if not has_resize_op:
  686. self.test_transforms.transforms.insert(
  687. normalize_op_idx,
  688. Resize(
  689. target_size=image_shape, interp='CUBIC'))
  690. else:
  691. self.test_transforms.transforms[
  692. resize_op_idx].target_size = image_shape
  693. class FasterRCNN(BaseDetector):
  694. def __init__(self,
  695. num_classes=80,
  696. backbone='ResNet50',
  697. with_fpn=True,
  698. with_dcn=False,
  699. aspect_ratios=[0.5, 1.0, 2.0],
  700. anchor_sizes=[[32], [64], [128], [256], [512]],
  701. keep_top_k=100,
  702. nms_threshold=0.5,
  703. score_threshold=0.05,
  704. fpn_num_channels=256,
  705. rpn_batch_size_per_im=256,
  706. rpn_fg_fraction=0.5,
  707. test_pre_nms_top_n=None,
  708. test_post_nms_top_n=1000,
  709. **params):
  710. self.init_params = locals()
  711. if backbone not in [
  712. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  713. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  714. ]:
  715. raise ValueError(
  716. "backbone: {} is not supported. Please choose one of "
  717. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  718. "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
  719. self.backbone_name = backbone
  720. if params.get('with_net', True):
  721. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  722. if backbone == 'HRNet_W18':
  723. if not with_fpn:
  724. logging.warning(
  725. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  726. format(backbone))
  727. with_fpn = True
  728. if with_dcn:
  729. logging.warning(
  730. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  731. format(backbone))
  732. backbone = self._get_backbone(
  733. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  734. elif backbone == 'ResNet50_vd_ssld':
  735. if not with_fpn:
  736. logging.warning(
  737. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  738. format(backbone))
  739. with_fpn = True
  740. backbone = self._get_backbone(
  741. 'ResNet',
  742. variant='d',
  743. norm_type='bn',
  744. freeze_at=0,
  745. return_idx=[0, 1, 2, 3],
  746. num_stages=4,
  747. lr_mult_list=[0.05, 0.05, 0.1, 0.15],
  748. dcn_v2_stages=dcn_v2_stages)
  749. elif 'ResNet50' in backbone:
  750. if with_fpn:
  751. backbone = self._get_backbone(
  752. 'ResNet',
  753. variant='d' if '_vd' in backbone else 'b',
  754. norm_type='bn',
  755. freeze_at=0,
  756. return_idx=[0, 1, 2, 3],
  757. num_stages=4,
  758. dcn_v2_stages=dcn_v2_stages)
  759. else:
  760. if with_dcn:
  761. logging.warning(
  762. "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  763. format(backbone))
  764. backbone = self._get_backbone(
  765. 'ResNet',
  766. variant='d' if '_vd' in backbone else 'b',
  767. norm_type='bn',
  768. freeze_at=0,
  769. return_idx=[2],
  770. num_stages=3)
  771. elif 'ResNet34' in backbone:
  772. if not with_fpn:
  773. logging.warning(
  774. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  775. format(backbone))
  776. with_fpn = True
  777. backbone = self._get_backbone(
  778. 'ResNet',
  779. depth=34,
  780. variant='d' if 'vd' in backbone else 'b',
  781. norm_type='bn',
  782. freeze_at=0,
  783. return_idx=[0, 1, 2, 3],
  784. num_stages=4,
  785. dcn_v2_stages=dcn_v2_stages)
  786. else:
  787. if not with_fpn:
  788. logging.warning(
  789. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  790. format(backbone))
  791. with_fpn = True
  792. backbone = self._get_backbone(
  793. 'ResNet',
  794. depth=101,
  795. variant='d' if 'vd' in backbone else 'b',
  796. norm_type='bn',
  797. freeze_at=0,
  798. return_idx=[0, 1, 2, 3],
  799. num_stages=4,
  800. dcn_v2_stages=dcn_v2_stages)
  801. rpn_in_channel = backbone.out_shape[0].channels
  802. if with_fpn:
  803. self.backbone_name = self.backbone_name + '_fpn'
  804. if 'HRNet' in self.backbone_name:
  805. neck = ppdet.modeling.HRFPN(
  806. in_channels=[i.channels for i in backbone.out_shape],
  807. out_channel=fpn_num_channels,
  808. spatial_scales=[
  809. 1.0 / i.stride for i in backbone.out_shape
  810. ],
  811. share_conv=False)
  812. else:
  813. neck = ppdet.modeling.FPN(
  814. in_channels=[i.channels for i in backbone.out_shape],
  815. out_channel=fpn_num_channels,
  816. spatial_scales=[
  817. 1.0 / i.stride for i in backbone.out_shape
  818. ])
  819. rpn_in_channel = neck.out_shape[0].channels
  820. anchor_generator_cfg = {
  821. 'aspect_ratios': aspect_ratios,
  822. 'anchor_sizes': anchor_sizes,
  823. 'strides': [4, 8, 16, 32, 64]
  824. }
  825. train_proposal_cfg = {
  826. 'min_size': 0.0,
  827. 'nms_thresh': .7,
  828. 'pre_nms_top_n': 2000,
  829. 'post_nms_top_n': 1000,
  830. 'topk_after_collect': True
  831. }
  832. test_proposal_cfg = {
  833. 'min_size': 0.0,
  834. 'nms_thresh': .7,
  835. 'pre_nms_top_n': 1000
  836. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  837. 'post_nms_top_n': test_post_nms_top_n
  838. }
  839. head = ppdet.modeling.TwoFCHead(
  840. in_channel=neck.out_shape[0].channels, out_channel=1024)
  841. roi_extractor_cfg = {
  842. 'resolution': 7,
  843. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  844. 'sampling_ratio': 0,
  845. 'aligned': True
  846. }
  847. with_pool = False
  848. else:
  849. neck = None
  850. anchor_generator_cfg = {
  851. 'aspect_ratios': aspect_ratios,
  852. 'anchor_sizes': anchor_sizes,
  853. 'strides': [16]
  854. }
  855. train_proposal_cfg = {
  856. 'min_size': 0.0,
  857. 'nms_thresh': .7,
  858. 'pre_nms_top_n': 12000,
  859. 'post_nms_top_n': 2000,
  860. 'topk_after_collect': False
  861. }
  862. test_proposal_cfg = {
  863. 'min_size': 0.0,
  864. 'nms_thresh': .7,
  865. 'pre_nms_top_n': 6000
  866. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  867. 'post_nms_top_n': test_post_nms_top_n
  868. }
  869. head = ppdet.modeling.Res5Head()
  870. roi_extractor_cfg = {
  871. 'resolution': 14,
  872. 'spatial_scale':
  873. [1. / i.stride for i in backbone.out_shape],
  874. 'sampling_ratio': 0,
  875. 'aligned': True
  876. }
  877. with_pool = True
  878. rpn_target_assign_cfg = {
  879. 'batch_size_per_im': rpn_batch_size_per_im,
  880. 'fg_fraction': rpn_fg_fraction,
  881. 'negative_overlap': .3,
  882. 'positive_overlap': .7,
  883. 'use_random': True
  884. }
  885. rpn_head = ppdet.modeling.RPNHead(
  886. anchor_generator=anchor_generator_cfg,
  887. rpn_target_assign=rpn_target_assign_cfg,
  888. train_proposal=train_proposal_cfg,
  889. test_proposal=test_proposal_cfg,
  890. in_channel=rpn_in_channel)
  891. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  892. bbox_head = ppdet.modeling.BBoxHead(
  893. head=head,
  894. in_channel=head.out_shape[0].channels,
  895. roi_extractor=roi_extractor_cfg,
  896. with_pool=with_pool,
  897. bbox_assigner=bbox_assigner,
  898. num_classes=num_classes)
  899. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  900. num_classes=num_classes,
  901. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  902. nms=ppdet.modeling.MultiClassNMS(
  903. score_threshold=score_threshold,
  904. keep_top_k=keep_top_k,
  905. nms_threshold=nms_threshold))
  906. params.update({
  907. 'backbone': backbone,
  908. 'neck': neck,
  909. 'rpn_head': rpn_head,
  910. 'bbox_head': bbox_head,
  911. 'bbox_post_process': bbox_post_process
  912. })
  913. else:
  914. if backbone not in ['ResNet50', 'ResNet50_vd']:
  915. with_fpn = True
  916. self.with_fpn = with_fpn
  917. super(FasterRCNN, self).__init__(
  918. model_name='FasterRCNN', num_classes=num_classes, **params)
  919. def _compose_batch_transform(self, transforms, mode='train'):
  920. if mode == 'train':
  921. default_batch_transforms = [
  922. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  923. ]
  924. collate_batch = False
  925. else:
  926. default_batch_transforms = [
  927. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  928. ]
  929. collate_batch = True
  930. custom_batch_transforms = []
  931. for i, op in enumerate(transforms.transforms):
  932. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  933. if mode != 'train':
  934. raise Exception(
  935. "{} cannot be present in the {} transforms. ".format(
  936. op.__class__.__name__, mode) +
  937. "Please check the {} transforms.".format(mode))
  938. custom_batch_transforms.insert(0, copy.deepcopy(op))
  939. batch_transforms = BatchCompose(
  940. custom_batch_transforms + default_batch_transforms,
  941. collate_batch=collate_batch)
  942. return batch_transforms
  943. def _fix_transforms_shape(self, image_shape):
  944. if hasattr(self, 'test_transforms'):
  945. if self.test_transforms is not None:
  946. has_resize_op = False
  947. resize_op_idx = -1
  948. normalize_op_idx = len(self.test_transforms.transforms)
  949. for idx, op in enumerate(self.test_transforms.transforms):
  950. name = op.__class__.__name__
  951. if name == 'ResizeByShort':
  952. has_resize_op = True
  953. resize_op_idx = idx
  954. if name == 'Normalize':
  955. normalize_op_idx = idx
  956. if not has_resize_op:
  957. self.test_transforms.transforms.insert(
  958. normalize_op_idx,
  959. Resize(
  960. target_size=image_shape,
  961. keep_ratio=True,
  962. interp='CUBIC'))
  963. else:
  964. self.test_transforms.transforms[resize_op_idx] = Resize(
  965. target_size=image_shape,
  966. keep_ratio=True,
  967. interp='CUBIC')
  968. self.test_transforms.transforms.append(
  969. Padding(im_padding_value=[0., 0., 0.]))
  970. def _get_test_inputs(self, image_shape):
  971. if image_shape is not None:
  972. image_shape = self._check_image_shape(image_shape)
  973. self._fix_transforms_shape(image_shape[-2:])
  974. else:
  975. image_shape = [None, 3, -1, -1]
  976. if self.with_fpn:
  977. self.test_transforms.transforms.append(
  978. Padding(im_padding_value=[0., 0., 0.]))
  979. self.fixed_input_shape = image_shape
  980. return self._define_input_spec(image_shape)
  981. class PPYOLO(YOLOv3):
  982. def __init__(self,
  983. num_classes=80,
  984. backbone='ResNet50_vd_dcn',
  985. anchors=None,
  986. anchor_masks=None,
  987. use_coord_conv=True,
  988. use_iou_aware=True,
  989. use_spp=True,
  990. use_drop_block=True,
  991. scale_x_y=1.05,
  992. ignore_threshold=0.7,
  993. label_smooth=False,
  994. use_iou_loss=True,
  995. use_matrix_nms=True,
  996. nms_score_threshold=0.01,
  997. nms_topk=-1,
  998. nms_keep_topk=100,
  999. nms_iou_threshold=0.45,
  1000. **params):
  1001. self.init_params = locals()
  1002. if backbone not in [
  1003. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  1004. 'MobileNetV3_small'
  1005. ]:
  1006. raise ValueError(
  1007. "backbone: {} is not supported. Please choose one of "
  1008. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  1009. format(backbone))
  1010. self.backbone_name = backbone
  1011. self.downsample_ratios = [
  1012. 32, 16, 8
  1013. ] if backbone == 'ResNet50_vd_dcn' else [32, 16]
  1014. if params.get('with_net', True):
  1015. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1016. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1017. norm_type = 'sync_bn'
  1018. else:
  1019. norm_type = 'bn'
  1020. if anchors is None and anchor_masks is None:
  1021. if 'MobileNetV3' in backbone:
  1022. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  1023. [120, 195], [254, 235]]
  1024. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1025. elif backbone == 'ResNet50_vd_dcn':
  1026. anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
  1027. [62, 45], [59, 119], [116, 90], [156, 198],
  1028. [373, 326]]
  1029. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  1030. else:
  1031. anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
  1032. [135, 169], [344, 319]]
  1033. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1034. elif anchors is None or anchor_masks is None:
  1035. raise ValueError(
  1036. "Please define both anchors and anchor_masks.")
  1037. if backbone == 'ResNet50_vd_dcn':
  1038. backbone = self._get_backbone(
  1039. 'ResNet',
  1040. variant='d',
  1041. norm_type=norm_type,
  1042. return_idx=[1, 2, 3],
  1043. dcn_v2_stages=[3],
  1044. freeze_at=-1,
  1045. freeze_norm=False,
  1046. norm_decay=0.)
  1047. elif backbone == 'ResNet18_vd':
  1048. backbone = self._get_backbone(
  1049. 'ResNet',
  1050. depth=18,
  1051. variant='d',
  1052. norm_type=norm_type,
  1053. return_idx=[2, 3],
  1054. freeze_at=-1,
  1055. freeze_norm=False,
  1056. norm_decay=0.)
  1057. elif backbone == 'MobileNetV3_large':
  1058. backbone = self._get_backbone(
  1059. 'MobileNetV3',
  1060. model_name='large',
  1061. norm_type=norm_type,
  1062. scale=1,
  1063. with_extra_blocks=False,
  1064. extra_block_filters=[],
  1065. feature_maps=[13, 16])
  1066. elif backbone == 'MobileNetV3_small':
  1067. backbone = self._get_backbone(
  1068. 'MobileNetV3',
  1069. model_name='small',
  1070. norm_type=norm_type,
  1071. scale=1,
  1072. with_extra_blocks=False,
  1073. extra_block_filters=[],
  1074. feature_maps=[9, 12])
  1075. neck = ppdet.modeling.PPYOLOFPN(
  1076. norm_type=norm_type,
  1077. in_channels=[i.channels for i in backbone.out_shape],
  1078. coord_conv=use_coord_conv,
  1079. drop_block=use_drop_block,
  1080. spp=use_spp,
  1081. conv_block_num=0
  1082. if ('MobileNetV3' in self.backbone_name or
  1083. self.backbone_name == 'ResNet18_vd') else 2)
  1084. loss = ppdet.modeling.YOLOv3Loss(
  1085. num_classes=num_classes,
  1086. ignore_thresh=ignore_threshold,
  1087. downsample=self.downsample_ratios,
  1088. label_smooth=label_smooth,
  1089. scale_x_y=scale_x_y,
  1090. iou_loss=ppdet.modeling.IouLoss(
  1091. loss_weight=2.5, loss_square=True)
  1092. if use_iou_loss else None,
  1093. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1094. if use_iou_aware else None)
  1095. yolo_head = ppdet.modeling.YOLOv3Head(
  1096. in_channels=[i.channels for i in neck.out_shape],
  1097. anchors=anchors,
  1098. anchor_masks=anchor_masks,
  1099. num_classes=num_classes,
  1100. loss=loss,
  1101. iou_aware=use_iou_aware)
  1102. if use_matrix_nms:
  1103. nms = ppdet.modeling.MatrixNMS(
  1104. keep_top_k=nms_keep_topk,
  1105. score_threshold=nms_score_threshold,
  1106. post_threshold=.05
  1107. if 'MobileNetV3' in self.backbone_name else .01,
  1108. nms_top_k=nms_topk,
  1109. background_label=-1)
  1110. else:
  1111. nms = ppdet.modeling.MultiClassNMS(
  1112. score_threshold=nms_score_threshold,
  1113. nms_top_k=nms_topk,
  1114. keep_top_k=nms_keep_topk,
  1115. nms_threshold=nms_iou_threshold)
  1116. post_process = ppdet.modeling.BBoxPostProcess(
  1117. decode=ppdet.modeling.YOLOBox(
  1118. num_classes=num_classes,
  1119. conf_thresh=.005
  1120. if 'MobileNetV3' in self.backbone_name else .01,
  1121. scale_x_y=scale_x_y),
  1122. nms=nms)
  1123. params.update({
  1124. 'backbone': backbone,
  1125. 'neck': neck,
  1126. 'yolo_head': yolo_head,
  1127. 'post_process': post_process
  1128. })
  1129. super(YOLOv3, self).__init__(
  1130. model_name='YOLOv3', num_classes=num_classes, **params)
  1131. self.anchors = anchors
  1132. self.anchor_masks = anchor_masks
  1133. self.model_name = 'PPYOLO'
  1134. def _get_test_inputs(self, image_shape):
  1135. if image_shape is not None:
  1136. image_shape = self._check_image_shape(image_shape)
  1137. self._fix_transforms_shape(image_shape[-2:])
  1138. else:
  1139. image_shape = [None, 3, 608, 608]
  1140. if hasattr(self, 'test_transforms'):
  1141. if self.test_transforms is not None:
  1142. for idx, op in enumerate(self.test_transforms.transforms):
  1143. name = op.__class__.__name__
  1144. if name == 'Resize':
  1145. image_shape = [None, 3] + list(
  1146. self.test_transforms.transforms[
  1147. idx].target_size)
  1148. logging.warning(
  1149. '[Important!!!] When exporting inference model for {},'.format(
  1150. self.__class__.__name__) +
  1151. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1152. format(image_shape) +
  1153. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1154. format(image_shape[1:]) + 'should be specified manually.')
  1155. self.fixed_input_shape = image_shape
  1156. return self._define_input_spec(image_shape)
  1157. class PPYOLOTiny(YOLOv3):
  1158. def __init__(self,
  1159. num_classes=80,
  1160. backbone='MobileNetV3',
  1161. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1162. [60, 170], [220, 125], [128, 222], [264, 266]],
  1163. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1164. use_iou_aware=False,
  1165. use_spp=True,
  1166. use_drop_block=True,
  1167. scale_x_y=1.05,
  1168. ignore_threshold=0.5,
  1169. label_smooth=False,
  1170. use_iou_loss=True,
  1171. use_matrix_nms=False,
  1172. nms_score_threshold=0.005,
  1173. nms_topk=1000,
  1174. nms_keep_topk=100,
  1175. nms_iou_threshold=0.45,
  1176. **params):
  1177. self.init_params = locals()
  1178. if backbone != 'MobileNetV3':
  1179. logging.warning(
  1180. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1181. "Backbone is forcibly set to MobileNetV3.")
  1182. self.backbone_name = 'MobileNetV3'
  1183. self.downsample_ratios = [32, 16, 8]
  1184. if params.get('with_net', True):
  1185. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1186. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1187. norm_type = 'sync_bn'
  1188. else:
  1189. norm_type = 'bn'
  1190. backbone = self._get_backbone(
  1191. 'MobileNetV3',
  1192. model_name='large',
  1193. norm_type=norm_type,
  1194. scale=.5,
  1195. with_extra_blocks=False,
  1196. extra_block_filters=[],
  1197. feature_maps=[7, 13, 16])
  1198. neck = ppdet.modeling.PPYOLOTinyFPN(
  1199. detection_block_channels=[160, 128, 96],
  1200. in_channels=[i.channels for i in backbone.out_shape],
  1201. spp=use_spp,
  1202. drop_block=use_drop_block)
  1203. loss = ppdet.modeling.YOLOv3Loss(
  1204. num_classes=num_classes,
  1205. ignore_thresh=ignore_threshold,
  1206. downsample=self.downsample_ratios,
  1207. label_smooth=label_smooth,
  1208. scale_x_y=scale_x_y,
  1209. iou_loss=ppdet.modeling.IouLoss(
  1210. loss_weight=2.5, loss_square=True)
  1211. if use_iou_loss else None,
  1212. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1213. if use_iou_aware else None)
  1214. yolo_head = ppdet.modeling.YOLOv3Head(
  1215. in_channels=[i.channels for i in neck.out_shape],
  1216. anchors=anchors,
  1217. anchor_masks=anchor_masks,
  1218. num_classes=num_classes,
  1219. loss=loss,
  1220. iou_aware=use_iou_aware)
  1221. if use_matrix_nms:
  1222. nms = ppdet.modeling.MatrixNMS(
  1223. keep_top_k=nms_keep_topk,
  1224. score_threshold=nms_score_threshold,
  1225. post_threshold=.05,
  1226. nms_top_k=nms_topk,
  1227. background_label=-1)
  1228. else:
  1229. nms = ppdet.modeling.MultiClassNMS(
  1230. score_threshold=nms_score_threshold,
  1231. nms_top_k=nms_topk,
  1232. keep_top_k=nms_keep_topk,
  1233. nms_threshold=nms_iou_threshold)
  1234. post_process = ppdet.modeling.BBoxPostProcess(
  1235. decode=ppdet.modeling.YOLOBox(
  1236. num_classes=num_classes,
  1237. conf_thresh=.005,
  1238. downsample_ratio=32,
  1239. clip_bbox=True,
  1240. scale_x_y=scale_x_y),
  1241. nms=nms)
  1242. params.update({
  1243. 'backbone': backbone,
  1244. 'neck': neck,
  1245. 'yolo_head': yolo_head,
  1246. 'post_process': post_process
  1247. })
  1248. super(YOLOv3, self).__init__(
  1249. model_name='YOLOv3', num_classes=num_classes, **params)
  1250. self.anchors = anchors
  1251. self.anchor_masks = anchor_masks
  1252. self.model_name = 'PPYOLOTiny'
  1253. def _get_test_inputs(self, image_shape):
  1254. if image_shape is not None:
  1255. image_shape = self._check_image_shape(image_shape)
  1256. self._fix_transforms_shape(image_shape[-2:])
  1257. else:
  1258. image_shape = [None, 3, 320, 320]
  1259. if hasattr(self, 'test_transforms'):
  1260. if self.test_transforms is not None:
  1261. for idx, op in enumerate(self.test_transforms.transforms):
  1262. name = op.__class__.__name__
  1263. if name == 'Resize':
  1264. image_shape = [None, 3] + list(
  1265. self.test_transforms.transforms[
  1266. idx].target_size)
  1267. logging.warning(
  1268. '[Important!!!] When exporting inference model for {},'.format(
  1269. self.__class__.__name__) +
  1270. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1271. format(image_shape) +
  1272. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1273. format(image_shape[1:]) + 'should be specified manually.')
  1274. self.fixed_input_shape = image_shape
  1275. return self._define_input_spec(image_shape)
  1276. class PPYOLOv2(YOLOv3):
  1277. def __init__(self,
  1278. num_classes=80,
  1279. backbone='ResNet50_vd_dcn',
  1280. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1281. [59, 119], [116, 90], [156, 198], [373, 326]],
  1282. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1283. use_iou_aware=True,
  1284. use_spp=True,
  1285. use_drop_block=True,
  1286. scale_x_y=1.05,
  1287. ignore_threshold=0.7,
  1288. label_smooth=False,
  1289. use_iou_loss=True,
  1290. use_matrix_nms=True,
  1291. nms_score_threshold=0.01,
  1292. nms_topk=-1,
  1293. nms_keep_topk=100,
  1294. nms_iou_threshold=0.45,
  1295. **params):
  1296. self.init_params = locals()
  1297. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1298. raise ValueError(
  1299. "backbone: {} is not supported. Please choose one of "
  1300. "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
  1301. self.backbone_name = backbone
  1302. self.downsample_ratios = [32, 16, 8]
  1303. if params.get('with_net', True):
  1304. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1305. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1306. norm_type = 'sync_bn'
  1307. else:
  1308. norm_type = 'bn'
  1309. if backbone == 'ResNet50_vd_dcn':
  1310. backbone = self._get_backbone(
  1311. 'ResNet',
  1312. variant='d',
  1313. norm_type=norm_type,
  1314. return_idx=[1, 2, 3],
  1315. dcn_v2_stages=[3],
  1316. freeze_at=-1,
  1317. freeze_norm=False,
  1318. norm_decay=0.)
  1319. elif backbone == 'ResNet101_vd_dcn':
  1320. backbone = self._get_backbone(
  1321. 'ResNet',
  1322. depth=101,
  1323. variant='d',
  1324. norm_type=norm_type,
  1325. return_idx=[1, 2, 3],
  1326. dcn_v2_stages=[3],
  1327. freeze_at=-1,
  1328. freeze_norm=False,
  1329. norm_decay=0.)
  1330. neck = ppdet.modeling.PPYOLOPAN(
  1331. norm_type=norm_type,
  1332. in_channels=[i.channels for i in backbone.out_shape],
  1333. drop_block=use_drop_block,
  1334. block_size=3,
  1335. keep_prob=.9,
  1336. spp=use_spp)
  1337. loss = ppdet.modeling.YOLOv3Loss(
  1338. num_classes=num_classes,
  1339. ignore_thresh=ignore_threshold,
  1340. downsample=self.downsample_ratios,
  1341. label_smooth=label_smooth,
  1342. scale_x_y=scale_x_y,
  1343. iou_loss=ppdet.modeling.IouLoss(
  1344. loss_weight=2.5, loss_square=True)
  1345. if use_iou_loss else None,
  1346. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1347. if use_iou_aware else None)
  1348. yolo_head = ppdet.modeling.YOLOv3Head(
  1349. in_channels=[i.channels for i in neck.out_shape],
  1350. anchors=anchors,
  1351. anchor_masks=anchor_masks,
  1352. num_classes=num_classes,
  1353. loss=loss,
  1354. iou_aware=use_iou_aware,
  1355. iou_aware_factor=.5)
  1356. if use_matrix_nms:
  1357. nms = ppdet.modeling.MatrixNMS(
  1358. keep_top_k=nms_keep_topk,
  1359. score_threshold=nms_score_threshold,
  1360. post_threshold=.01,
  1361. nms_top_k=nms_topk,
  1362. background_label=-1)
  1363. else:
  1364. nms = ppdet.modeling.MultiClassNMS(
  1365. score_threshold=nms_score_threshold,
  1366. nms_top_k=nms_topk,
  1367. keep_top_k=nms_keep_topk,
  1368. nms_threshold=nms_iou_threshold)
  1369. post_process = ppdet.modeling.BBoxPostProcess(
  1370. decode=ppdet.modeling.YOLOBox(
  1371. num_classes=num_classes,
  1372. conf_thresh=.01,
  1373. downsample_ratio=32,
  1374. clip_bbox=True,
  1375. scale_x_y=scale_x_y),
  1376. nms=nms)
  1377. params.update({
  1378. 'backbone': backbone,
  1379. 'neck': neck,
  1380. 'yolo_head': yolo_head,
  1381. 'post_process': post_process
  1382. })
  1383. super(YOLOv3, self).__init__(
  1384. model_name='YOLOv3', num_classes=num_classes, **params)
  1385. self.anchors = anchors
  1386. self.anchor_masks = anchor_masks
  1387. self.model_name = 'PPYOLOv2'
  1388. def _get_test_inputs(self, image_shape):
  1389. if image_shape is not None:
  1390. image_shape = self._check_image_shape(image_shape)
  1391. self._fix_transforms_shape(image_shape[-2:])
  1392. else:
  1393. image_shape = [None, 3, 640, 640]
  1394. if hasattr(self, 'test_transforms'):
  1395. if self.test_transforms is not None:
  1396. for idx, op in enumerate(self.test_transforms.transforms):
  1397. name = op.__class__.__name__
  1398. if name == 'Resize':
  1399. image_shape = [None, 3] + list(
  1400. self.test_transforms.transforms[
  1401. idx].target_size)
  1402. logging.warning(
  1403. '[Important!!!] When exporting inference model for {},'.format(
  1404. self.__class__.__name__) +
  1405. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1406. format(image_shape) +
  1407. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1408. format(image_shape[1:]) + 'should be specified manually.')
  1409. self.fixed_input_shape = image_shape
  1410. return self._define_input_spec(image_shape)
  1411. class MaskRCNN(BaseDetector):
  1412. def __init__(self,
  1413. num_classes=80,
  1414. backbone='ResNet50_vd',
  1415. with_fpn=True,
  1416. with_dcn=False,
  1417. aspect_ratios=[0.5, 1.0, 2.0],
  1418. anchor_sizes=[[32], [64], [128], [256], [512]],
  1419. keep_top_k=100,
  1420. nms_threshold=0.5,
  1421. score_threshold=0.05,
  1422. fpn_num_channels=256,
  1423. rpn_batch_size_per_im=256,
  1424. rpn_fg_fraction=0.5,
  1425. test_pre_nms_top_n=None,
  1426. test_post_nms_top_n=1000,
  1427. **params):
  1428. self.init_params = locals()
  1429. if backbone not in [
  1430. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1431. 'ResNet101_vd'
  1432. ]:
  1433. raise ValueError(
  1434. "backbone: {} is not supported. Please choose one of "
  1435. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1436. format(backbone))
  1437. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1438. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  1439. if params.get('with_net', True):
  1440. if backbone == 'ResNet50':
  1441. if with_fpn:
  1442. backbone = self._get_backbone(
  1443. 'ResNet',
  1444. norm_type='bn',
  1445. freeze_at=0,
  1446. return_idx=[0, 1, 2, 3],
  1447. num_stages=4,
  1448. dcn_v2_stages=dcn_v2_stages)
  1449. else:
  1450. if with_dcn:
  1451. logging.warning(
  1452. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1453. format(backbone))
  1454. backbone = self._get_backbone(
  1455. 'ResNet',
  1456. norm_type='bn',
  1457. freeze_at=0,
  1458. return_idx=[2],
  1459. num_stages=3)
  1460. elif 'ResNet50_vd' in backbone:
  1461. if not with_fpn:
  1462. logging.warning(
  1463. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1464. format(backbone))
  1465. with_fpn = True
  1466. backbone = self._get_backbone(
  1467. 'ResNet',
  1468. variant='d',
  1469. norm_type='bn',
  1470. freeze_at=0,
  1471. return_idx=[0, 1, 2, 3],
  1472. num_stages=4,
  1473. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1474. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
  1475. dcn_v2_stages=dcn_v2_stages)
  1476. else:
  1477. if not with_fpn:
  1478. logging.warning(
  1479. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1480. format(backbone))
  1481. with_fpn = True
  1482. backbone = self._get_backbone(
  1483. 'ResNet',
  1484. variant='d' if '_vd' in backbone else 'b',
  1485. depth=101,
  1486. norm_type='bn',
  1487. freeze_at=0,
  1488. return_idx=[0, 1, 2, 3],
  1489. num_stages=4,
  1490. dcn_v2_stages=dcn_v2_stages)
  1491. rpn_in_channel = backbone.out_shape[0].channels
  1492. if with_fpn:
  1493. neck = ppdet.modeling.FPN(
  1494. in_channels=[i.channels for i in backbone.out_shape],
  1495. out_channel=fpn_num_channels,
  1496. spatial_scales=[
  1497. 1.0 / i.stride for i in backbone.out_shape
  1498. ])
  1499. rpn_in_channel = neck.out_shape[0].channels
  1500. anchor_generator_cfg = {
  1501. 'aspect_ratios': aspect_ratios,
  1502. 'anchor_sizes': anchor_sizes,
  1503. 'strides': [4, 8, 16, 32, 64]
  1504. }
  1505. train_proposal_cfg = {
  1506. 'min_size': 0.0,
  1507. 'nms_thresh': .7,
  1508. 'pre_nms_top_n': 2000,
  1509. 'post_nms_top_n': 1000,
  1510. 'topk_after_collect': True
  1511. }
  1512. test_proposal_cfg = {
  1513. 'min_size': 0.0,
  1514. 'nms_thresh': .7,
  1515. 'pre_nms_top_n': 1000
  1516. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1517. 'post_nms_top_n': test_post_nms_top_n
  1518. }
  1519. bb_head = ppdet.modeling.TwoFCHead(
  1520. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1521. bb_roi_extractor_cfg = {
  1522. 'resolution': 7,
  1523. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1524. 'sampling_ratio': 0,
  1525. 'aligned': True
  1526. }
  1527. with_pool = False
  1528. m_head = ppdet.modeling.MaskFeat(
  1529. in_channel=neck.out_shape[0].channels,
  1530. out_channel=256,
  1531. num_convs=4)
  1532. m_roi_extractor_cfg = {
  1533. 'resolution': 14,
  1534. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1535. 'sampling_ratio': 0,
  1536. 'aligned': True
  1537. }
  1538. mask_assigner = MaskAssigner(
  1539. num_classes=num_classes, mask_resolution=28)
  1540. share_bbox_feat = False
  1541. else:
  1542. neck = None
  1543. anchor_generator_cfg = {
  1544. 'aspect_ratios': aspect_ratios,
  1545. 'anchor_sizes': anchor_sizes,
  1546. 'strides': [16]
  1547. }
  1548. train_proposal_cfg = {
  1549. 'min_size': 0.0,
  1550. 'nms_thresh': .7,
  1551. 'pre_nms_top_n': 12000,
  1552. 'post_nms_top_n': 2000,
  1553. 'topk_after_collect': False
  1554. }
  1555. test_proposal_cfg = {
  1556. 'min_size': 0.0,
  1557. 'nms_thresh': .7,
  1558. 'pre_nms_top_n': 6000
  1559. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1560. 'post_nms_top_n': test_post_nms_top_n
  1561. }
  1562. bb_head = ppdet.modeling.Res5Head()
  1563. bb_roi_extractor_cfg = {
  1564. 'resolution': 14,
  1565. 'spatial_scale':
  1566. [1. / i.stride for i in backbone.out_shape],
  1567. 'sampling_ratio': 0,
  1568. 'aligned': True
  1569. }
  1570. with_pool = True
  1571. m_head = ppdet.modeling.MaskFeat(
  1572. in_channel=bb_head.out_shape[0].channels,
  1573. out_channel=256,
  1574. num_convs=0)
  1575. m_roi_extractor_cfg = {
  1576. 'resolution': 14,
  1577. 'spatial_scale':
  1578. [1. / i.stride for i in backbone.out_shape],
  1579. 'sampling_ratio': 0,
  1580. 'aligned': True
  1581. }
  1582. mask_assigner = MaskAssigner(
  1583. num_classes=num_classes, mask_resolution=14)
  1584. share_bbox_feat = True
  1585. rpn_target_assign_cfg = {
  1586. 'batch_size_per_im': rpn_batch_size_per_im,
  1587. 'fg_fraction': rpn_fg_fraction,
  1588. 'negative_overlap': .3,
  1589. 'positive_overlap': .7,
  1590. 'use_random': True
  1591. }
  1592. rpn_head = ppdet.modeling.RPNHead(
  1593. anchor_generator=anchor_generator_cfg,
  1594. rpn_target_assign=rpn_target_assign_cfg,
  1595. train_proposal=train_proposal_cfg,
  1596. test_proposal=test_proposal_cfg,
  1597. in_channel=rpn_in_channel)
  1598. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1599. bbox_head = ppdet.modeling.BBoxHead(
  1600. head=bb_head,
  1601. in_channel=bb_head.out_shape[0].channels,
  1602. roi_extractor=bb_roi_extractor_cfg,
  1603. with_pool=with_pool,
  1604. bbox_assigner=bbox_assigner,
  1605. num_classes=num_classes)
  1606. mask_head = ppdet.modeling.MaskHead(
  1607. head=m_head,
  1608. roi_extractor=m_roi_extractor_cfg,
  1609. mask_assigner=mask_assigner,
  1610. share_bbox_feat=share_bbox_feat,
  1611. num_classes=num_classes)
  1612. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1613. num_classes=num_classes,
  1614. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1615. nms=ppdet.modeling.MultiClassNMS(
  1616. score_threshold=score_threshold,
  1617. keep_top_k=keep_top_k,
  1618. nms_threshold=nms_threshold))
  1619. mask_post_process = ppdet.modeling.MaskPostProcess(
  1620. binary_thresh=.5)
  1621. params.update({
  1622. 'backbone': backbone,
  1623. 'neck': neck,
  1624. 'rpn_head': rpn_head,
  1625. 'bbox_head': bbox_head,
  1626. 'mask_head': mask_head,
  1627. 'bbox_post_process': bbox_post_process,
  1628. 'mask_post_process': mask_post_process
  1629. })
  1630. self.with_fpn = with_fpn
  1631. super(MaskRCNN, self).__init__(
  1632. model_name='MaskRCNN', num_classes=num_classes, **params)
  def _compose_batch_transform(self, transforms, mode='train'):
      """Assemble the batch-level transform pipeline for ``mode``.

      Batch random-resize ops found in ``transforms`` are deep-copied and
      placed, in reverse order of appearance, ahead of a default
      ``_BatchPadding`` op.

      Args:
          transforms: Object whose ``transforms`` attribute is an iterable
              of transform ops.
          mode (str): Pipeline mode. 'train' is the only mode in which
              batch random-resize ops are accepted. Defaults to 'train'.

      Returns:
          BatchCompose: The composed batch transforms. ``collate_batch`` is
          False for 'train' and True for every other mode.

      Raises:
          Exception: If a ``BatchRandomResize`` or
              ``BatchRandomResizeByShort`` op is present while ``mode`` is
              not 'train'.
      """
      is_train = mode == 'train'

      # The padding op is the same in every mode; only the stride depends
      # on whether an FPN neck is in use.
      stride = 32 if self.with_fpn else -1
      padding_ops = [_BatchPadding(pad_to_stride=stride)]

      resize_types = (BatchRandomResize, BatchRandomResizeByShort)
      resize_ops = []
      for op in transforms.transforms:
          if not isinstance(op, resize_types):
              continue
          if not is_train:
              raise Exception(
                  "{} cannot be present in the {} transforms. ".format(
                      op.__class__.__name__, mode) +
                  "Please check the {} transforms.".format(mode))
          resize_ops.append(copy.deepcopy(op))
      # Later ops come first, matching a prepend-on-each-match ordering.
      resize_ops.reverse()

      return BatchCompose(
          resize_ops + padding_ops, collate_batch=not is_train)
  def _fix_transforms_shape(self, image_shape):
      """Pin the test-time resize op to a fixed ``image_shape``.

      Does nothing when the instance has no ``test_transforms`` attribute
      or when that attribute is None. Otherwise the op list is edited in
      place:

      * the last ``ResizeByShort`` op is replaced by
        ``Resize(target_size=image_shape, keep_ratio=True, interp='CUBIC')``;
      * if there is no ``ResizeByShort`` op, that ``Resize`` is inserted
        just before the last ``Normalize`` op, or at the end of the list
        when no ``Normalize`` op exists;
      * a ``Padding`` op with a zero padding value is appended last.

      Args:
          image_shape: Value forwarded as ``target_size`` to ``Resize``.
      """
      pipeline = getattr(self, 'test_transforms', None)
      if pipeline is None:
          return

      ops = pipeline.transforms
      resize_at = None
      normalize_at = len(ops)  # fall back to "end of list"
      for pos, op in enumerate(ops):
          cls_name = op.__class__.__name__
          if cls_name == 'ResizeByShort':
              resize_at = pos
          elif cls_name == 'Normalize':
              normalize_at = pos

      fixed_resize = Resize(
          target_size=image_shape, keep_ratio=True, interp='CUBIC')
      if resize_at is None:
          ops.insert(normalize_at, fixed_resize)
      else:
          ops[resize_at] = fixed_resize

      ops.append(Padding(im_padding_value=[0., 0., 0.]))
  def _get_test_inputs(self, image_shape):
      """Record the test input shape and return the matching input spec.

      With ``image_shape`` set to None the shape ``[None, 3, -1, -1]`` is
      used, and a zero-valued ``Padding`` op is appended to the test
      transforms when ``self.with_fpn`` is truthy. Otherwise the shape is
      passed through ``self._check_image_shape`` and its last two entries
      are handed to ``self._fix_transforms_shape``.

      Args:
          image_shape: Requested input shape, or None.

      Returns:
          Whatever ``self._define_input_spec`` returns for the resolved
          shape, which is also stored on ``self.fixed_input_shape``.
      """
      if image_shape is None:
          input_shape = [None, 3, -1, -1]
          if self.with_fpn:
              self.test_transforms.transforms.append(
                  Padding(im_padding_value=[0., 0., 0.]))
      else:
          input_shape = self._check_image_shape(image_shape)
          self._fix_transforms_shape(input_shape[-2:])

      self.fixed_input_shape = input_shape
      return self._define_input_spec(input_shape)