detector.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import ppdet
  24. from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. del self.init_params['params']
  41. super(BaseDetector, self).__init__('detector')
  42. if not hasattr(ppdet.modeling, model_name):
  43. raise Exception("ERROR: There's no model named {}.".format(
  44. model_name))
  45. self.model_name = model_name
  46. self.num_classes = num_classes
  47. self.labels = None
  48. self.net = self.build_net(**params)
  49. def build_net(self, **params):
  50. with paddle.utils.unique_name.guard():
  51. net = ppdet.modeling.__dict__[self.model_name](**params)
  52. return net
  53. def get_test_inputs(self, image_shape):
  54. input_spec = [{
  55. "image": InputSpec(
  56. shape=[None, 3] + image_shape, name='image', dtype='float32'),
  57. "im_shape": InputSpec(
  58. shape=[None, 2], name='im_shape', dtype='float32'),
  59. "scale_factor": InputSpec(
  60. shape=[None, 2], name='scale_factor', dtype='float32')
  61. }]
  62. return input_spec
  63. def _get_backbone(self, backbone_name, **params):
  64. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  65. return backbone
  66. def run(self, net, inputs, mode):
  67. net_out = net(inputs)
  68. if mode in ['train', 'eval']:
  69. outputs = net_out
  70. else:
  71. for key in ['im_shape', 'scale_factor']:
  72. net_out[key] = inputs[key]
  73. outputs = dict()
  74. for key in net_out:
  75. outputs[key] = net_out[key].numpy()
  76. return outputs
  77. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  78. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  79. num_steps_each_epoch):
  80. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  81. values = [(lr_decay_gamma**i) * learning_rate
  82. for i in range(len(lr_decay_epochs) + 1)]
  83. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  84. boundaries=boundaries, values=values)
  85. if warmup_steps > 0:
  86. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  87. logging.error(
  88. "In function train(), parameters should satisfy: "
  89. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  90. exit=False)
  91. logging.error(
  92. "See this doc for more information: "
  93. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  94. exit=False)
  95. scheduler = paddle.optimizer.lr.LinearWarmup(
  96. learning_rate=scheduler,
  97. warmup_steps=warmup_steps,
  98. start_lr=warmup_start_lr,
  99. end_lr=learning_rate)
  100. optimizer = paddle.optimizer.Momentum(
  101. scheduler,
  102. momentum=.9,
  103. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  104. parameters=parameters)
  105. return optimizer
  106. def train(self,
  107. num_epochs,
  108. train_dataset,
  109. train_batch_size=64,
  110. eval_dataset=None,
  111. optimizer=None,
  112. save_interval_epochs=1,
  113. log_interval_steps=10,
  114. save_dir='output',
  115. pretrain_weights='IMAGENET',
  116. learning_rate=.001,
  117. warmup_steps=0,
  118. warmup_start_lr=0.0,
  119. lr_decay_epochs=(216, 243),
  120. lr_decay_gamma=0.1,
  121. metric=None,
  122. use_ema=False,
  123. early_stop=False,
  124. early_stop_patience=5,
  125. use_vdl=True):
  126. """
  127. Train the model.
  128. Args:
  129. num_epochs(int): The number of epochs.
  130. train_dataset(paddlex.dataset): Training dataset.
  131. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  132. eval_dataset(paddlex.dataset, optional):
  133. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  134. optimizer(paddle.optimizer.Optimizer or None, optional):
  135. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  136. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  137. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  138. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  139. pretrain_weights(str or None, optional):
  140. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  141. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  142. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  143. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  144. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  145. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  146. metric({'VOC', 'COCO', None}, optional):
  147. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  148. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  149. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  150. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  151. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  152. """
  153. if train_dataset.__class__.__name__ == 'VOCDetection':
  154. train_dataset.data_fields = {
  155. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  156. 'difficult'
  157. }
  158. elif train_dataset.__class__.__name__ == 'CocoDetection':
  159. if self.__class__.__name__ == 'MaskRCNN':
  160. train_dataset.data_fields = {
  161. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  162. 'gt_poly', 'is_crowd'
  163. }
  164. else:
  165. train_dataset.data_fields = {
  166. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  167. 'is_crowd'
  168. }
  169. if metric is None:
  170. if eval_dataset.__class__.__name__ == 'VOCDetection':
  171. self.metric = 'voc'
  172. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  173. self.metric = 'coco'
  174. else:
  175. assert metric.lower() in ['coco', 'voc'], \
  176. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  177. self.metric = metric.lower()
  178. train_dataset.batch_transforms = self._compose_batch_transform(
  179. train_dataset.transforms, mode='train')
  180. self.labels = train_dataset.labels
  181. # build optimizer if not defined
  182. if optimizer is None:
  183. num_steps_each_epoch = len(train_dataset) // train_batch_size
  184. self.optimizer = self.default_optimizer(
  185. parameters=self.net.parameters(),
  186. learning_rate=learning_rate,
  187. warmup_steps=warmup_steps,
  188. warmup_start_lr=warmup_start_lr,
  189. lr_decay_epochs=lr_decay_epochs,
  190. lr_decay_gamma=lr_decay_gamma,
  191. num_steps_each_epoch=num_steps_each_epoch)
  192. else:
  193. self.optimizer = optimizer
  194. # initiate weights
  195. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  196. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  197. [self.model_name, self.backbone_name])]:
  198. logging.warning(
  199. "Path of pretrain_weights('{}') does not exist!".format(
  200. pretrain_weights))
  201. pretrain_weights = det_pretrain_weights_dict['_'.join(
  202. [self.model_name, self.backbone_name])][0]
  203. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  204. "If don't want to use pretrain weights, "
  205. "set pretrain_weights to be None.".format(
  206. pretrain_weights))
  207. pretrained_dir = osp.join(save_dir, 'pretrain')
  208. self.net_initialize(
  209. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  210. if use_ema:
  211. ema = ExponentialMovingAverage(
  212. decay=.9998, model=self.net, use_thres_step=True)
  213. else:
  214. ema = None
  215. # start train loop
  216. self.train_loop(
  217. num_epochs=num_epochs,
  218. train_dataset=train_dataset,
  219. train_batch_size=train_batch_size,
  220. eval_dataset=eval_dataset,
  221. save_interval_epochs=save_interval_epochs,
  222. log_interval_steps=log_interval_steps,
  223. save_dir=save_dir,
  224. ema=ema,
  225. early_stop=early_stop,
  226. early_stop_patience=early_stop_patience,
  227. use_vdl=use_vdl)
  228. def quant_aware_train(self,
  229. num_epochs,
  230. train_dataset,
  231. train_batch_size=64,
  232. eval_dataset=None,
  233. optimizer=None,
  234. save_interval_epochs=1,
  235. log_interval_steps=10,
  236. save_dir='output',
  237. learning_rate=.00001,
  238. warmup_steps=0,
  239. warmup_start_lr=0.0,
  240. lr_decay_epochs=(216, 243),
  241. lr_decay_gamma=0.1,
  242. metric=None,
  243. use_ema=False,
  244. early_stop=False,
  245. early_stop_patience=5,
  246. use_vdl=True,
  247. quant_config=None):
  248. """
  249. Quantization-aware training.
  250. Args:
  251. num_epochs(int): The number of epochs.
  252. train_dataset(paddlex.dataset): Training dataset.
  253. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  254. eval_dataset(paddlex.dataset, optional):
  255. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  256. optimizer(paddle.optimizer.Optimizer or None, optional):
  257. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  258. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  259. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  260. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  261. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  262. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  263. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  264. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  265. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  266. metric({'VOC', 'COCO', None}, optional):
  267. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  268. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  269. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  270. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  271. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  272. infer_image_shape(List[int], optional): The shape of input images during inference process, in [w, h] format.
  273. If the shape of images is variable, set `infer_image_shape` to [-1, -1]. Defaults to [-1, -1].
  274. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  275. configuration will be used. Defaults to None.
  276. """
  277. self._prepare_qat(quant_config)
  278. self.train(
  279. num_epochs=num_epochs,
  280. train_dataset=train_dataset,
  281. train_batch_size=train_batch_size,
  282. eval_dataset=eval_dataset,
  283. optimizer=optimizer,
  284. save_interval_epochs=save_interval_epochs,
  285. log_interval_steps=log_interval_steps,
  286. save_dir=save_dir,
  287. pretrain_weights=None,
  288. learning_rate=learning_rate,
  289. warmup_steps=warmup_steps,
  290. warmup_start_lr=warmup_start_lr,
  291. lr_decay_epochs=lr_decay_epochs,
  292. lr_decay_gamma=lr_decay_gamma,
  293. metric=metric,
  294. use_ema=use_ema,
  295. early_stop=early_stop,
  296. early_stop_patience=early_stop_patience,
  297. use_vdl=use_vdl)
  298. def evaluate(self,
  299. eval_dataset,
  300. batch_size=1,
  301. metric=None,
  302. return_details=False):
  303. """
  304. Evaluate the model.
  305. Args:
  306. eval_dataset(paddlex.dataset): Evaluation dataset.
  307. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  308. metric({'VOC', 'COCO', None}, optional):
  309. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  310. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  311. Returns:
  312. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  313. """
  314. if eval_dataset.__class__.__name__ == 'VOCDetection':
  315. eval_dataset.data_fields = {
  316. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  317. 'difficult'
  318. }
  319. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  320. if self.__class__.__name__ == 'MaskRCNN':
  321. eval_dataset.data_fields = {
  322. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  323. 'gt_poly', 'is_crowd'
  324. }
  325. else:
  326. eval_dataset.data_fields = {
  327. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  328. 'is_crowd'
  329. }
  330. eval_dataset.batch_transforms = self._compose_batch_transform(
  331. eval_dataset.transforms, mode='eval')
  332. arrange_transforms(
  333. model_type=self.model_type,
  334. transforms=eval_dataset.transforms,
  335. mode='eval')
  336. self.net.eval()
  337. nranks = paddle.distributed.get_world_size()
  338. local_rank = paddle.distributed.get_rank()
  339. if nranks > 1:
  340. # Initialize parallel environment if not done.
  341. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  342. ):
  343. paddle.distributed.init_parallel_env()
  344. if batch_size > 1:
  345. logging.warning(
  346. "Detector only supports single card evaluation with batch_size=1 "
  347. "during evaluation, so batch_size is forcibly set to 1.")
  348. batch_size = 1
  349. if nranks < 2 or local_rank == 0:
  350. self.eval_data_loader = self.build_data_loader(
  351. eval_dataset, batch_size=batch_size, mode='eval')
  352. is_bbox_normalized = False
  353. if eval_dataset.batch_transforms is not None:
  354. is_bbox_normalized = any(
  355. isinstance(t, _NormalizeBox)
  356. for t in eval_dataset.batch_transforms.batch_transforms)
  357. if metric is None:
  358. if getattr(self, 'metric', None) is not None:
  359. if self.metric == 'voc':
  360. eval_metric = VOCMetric(
  361. labels=eval_dataset.labels,
  362. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  363. is_bbox_normalized=is_bbox_normalized,
  364. classwise=False)
  365. else:
  366. eval_metric = COCOMetric(
  367. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  368. classwise=False)
  369. else:
  370. if eval_dataset.__class__.__name__ == 'VOCDetection':
  371. eval_metric = VOCMetric(
  372. labels=eval_dataset.labels,
  373. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  374. is_bbox_normalized=is_bbox_normalized,
  375. classwise=False)
  376. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  377. eval_metric = COCOMetric(
  378. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  379. classwise=False)
  380. else:
  381. assert metric.lower() in ['coco', 'voc'], \
  382. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  383. if metric.lower() == 'coco':
  384. eval_metric = COCOMetric(
  385. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  386. classwise=False)
  387. else:
  388. eval_metric = VOCMetric(
  389. labels=eval_dataset.labels,
  390. is_bbox_normalized=is_bbox_normalized,
  391. classwise=False)
  392. scores = collections.OrderedDict()
  393. logging.info(
  394. "Start to evaluate(total_samples={}, total_steps={})...".
  395. format(eval_dataset.num_samples, eval_dataset.num_samples))
  396. with paddle.no_grad():
  397. for step, data in enumerate(self.eval_data_loader):
  398. outputs = self.run(self.net, data, 'eval')
  399. eval_metric.update(data, outputs)
  400. eval_metric.accumulate()
  401. self.eval_details = eval_metric.details
  402. scores.update(eval_metric.get())
  403. eval_metric.reset()
  404. if return_details:
  405. return scores, self.eval_details
  406. return scores
  407. def predict(self, img_file, transforms=None):
  408. """
  409. Do inference.
  410. Args:
  411. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  412. Image path or decoded image data in a BGR format, which also could constitute a list,
  413. meaning all images to be predicted as a mini-batch.
  414. transforms(paddlex.transforms.Compose or None, optional):
  415. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  416. Returns:
  417. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  418. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  419. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  420. category_id(int): the predicted category ID
  421. category(str): category name
  422. bbox(list): bounding box in [x, y, w, h] format
  423. score(str): confidence
  424. """
  425. if transforms is None and not hasattr(self, 'test_transforms'):
  426. raise Exception("transforms need to be defined, now is None.")
  427. if transforms is None:
  428. transforms = self.test_transforms
  429. if isinstance(img_file, (str, np.ndarray)):
  430. images = [img_file]
  431. else:
  432. images = img_file
  433. batch_samples = self._preprocess(images, transforms)
  434. self.net.eval()
  435. outputs = self.run(self.net, batch_samples, 'test')
  436. prediction = self._postprocess(outputs)
  437. if isinstance(img_file, (str, np.ndarray)):
  438. prediction = prediction[0]
  439. return prediction
  440. def _preprocess(self, images, transforms):
  441. arrange_transforms(
  442. model_type=self.model_type, transforms=transforms, mode='test')
  443. batch_samples = list()
  444. for im in images:
  445. sample = {'image': im}
  446. batch_samples.append(transforms(sample))
  447. batch_transforms = self._compose_batch_transform(transforms, 'test')
  448. batch_samples = batch_transforms(batch_samples)
  449. for k, v in batch_samples.items():
  450. batch_samples[k] = paddle.to_tensor(v)
  451. return batch_samples
  452. def _postprocess(self, batch_pred):
  453. infer_result = {}
  454. if 'bbox' in batch_pred:
  455. bboxes = batch_pred['bbox']
  456. bbox_nums = batch_pred['bbox_num']
  457. det_res = []
  458. k = 0
  459. for i in range(len(bbox_nums)):
  460. det_nums = bbox_nums[i]
  461. for j in range(det_nums):
  462. dt = bboxes[k]
  463. k = k + 1
  464. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  465. if int(num_id) < 0:
  466. continue
  467. category = self.labels[int(num_id)]
  468. w = xmax - xmin
  469. h = ymax - ymin
  470. bbox = [xmin, ymin, w, h]
  471. dt_res = {
  472. 'category_id': int(num_id),
  473. 'category': category,
  474. 'bbox': bbox,
  475. 'score': score
  476. }
  477. det_res.append(dt_res)
  478. infer_result['bbox'] = det_res
  479. if 'mask' in batch_pred:
  480. masks = batch_pred['mask']
  481. bboxes = batch_pred['bbox']
  482. mask_nums = batch_pred['bbox_num']
  483. seg_res = []
  484. k = 0
  485. for i in range(len(mask_nums)):
  486. det_nums = mask_nums[i]
  487. for j in range(det_nums):
  488. mask = masks[k].astype(np.uint8)
  489. score = float(bboxes[k][1])
  490. label = int(bboxes[k][0])
  491. k = k + 1
  492. if label == -1:
  493. continue
  494. category = self.labels[int(label)]
  495. import pycocotools.mask as mask_util
  496. rle = mask_util.encode(
  497. np.array(
  498. mask[:, :, None], order="F", dtype="uint8"))[0]
  499. if six.PY3:
  500. if 'counts' in rle:
  501. rle['counts'] = rle['counts'].decode("utf8")
  502. sg_res = {
  503. 'category': category,
  504. 'segmentation': rle,
  505. 'score': score
  506. }
  507. seg_res.append(sg_res)
  508. infer_result['mask'] = seg_res
  509. bbox_num = batch_pred['bbox_num']
  510. results = []
  511. start = 0
  512. for num in bbox_num:
  513. end = start + num
  514. curr_res = infer_result['bbox'][start:end]
  515. if 'mask' in infer_result:
  516. mask_res = infer_result['mask'][start:end]
  517. for box, mask in zip(curr_res, mask_res):
  518. box.update(mask)
  519. results.append(curr_res)
  520. start = end
  521. return results
  522. class YOLOv3(BaseDetector):
  523. def __init__(self,
  524. num_classes=80,
  525. backbone='MobileNetV1',
  526. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  527. [59, 119], [116, 90], [156, 198], [373, 326]],
  528. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  529. ignore_threshold=0.7,
  530. nms_score_threshold=0.01,
  531. nms_topk=1000,
  532. nms_keep_topk=100,
  533. nms_iou_threshold=0.45,
  534. label_smooth=False):
  535. self.init_params = locals()
  536. if backbone not in [
  537. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  538. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  539. ]:
  540. raise ValueError(
  541. "backbone: {} is not supported. Please choose one of "
  542. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  543. format(backbone))
  544. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  545. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  546. norm_type = 'sync_bn'
  547. else:
  548. norm_type = 'bn'
  549. self.backbone_name = backbone
  550. if 'MobileNetV1' in backbone:
  551. norm_type = 'bn'
  552. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  553. elif 'MobileNetV3' in backbone:
  554. backbone = self._get_backbone(
  555. 'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
  556. elif backbone == 'ResNet50_vd_dcn':
  557. backbone = self._get_backbone(
  558. 'ResNet',
  559. norm_type=norm_type,
  560. variant='d',
  561. return_idx=[1, 2, 3],
  562. dcn_v2_stages=[3],
  563. freeze_at=-1,
  564. freeze_norm=False)
  565. elif backbone == 'ResNet34':
  566. backbone = self._get_backbone(
  567. 'ResNet',
  568. depth=34,
  569. norm_type=norm_type,
  570. return_idx=[1, 2, 3],
  571. freeze_at=-1,
  572. freeze_norm=False,
  573. norm_decay=0.)
  574. else:
  575. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  576. neck = ppdet.modeling.YOLOv3FPN(
  577. norm_type=norm_type,
  578. in_channels=[i.channels for i in backbone.out_shape])
  579. loss = ppdet.modeling.YOLOv3Loss(
  580. num_classes=num_classes,
  581. ignore_thresh=ignore_threshold,
  582. label_smooth=label_smooth)
  583. yolo_head = ppdet.modeling.YOLOv3Head(
  584. in_channels=[i.channels for i in neck.out_shape],
  585. anchors=anchors,
  586. anchor_masks=anchor_masks,
  587. num_classes=num_classes,
  588. loss=loss)
  589. post_process = ppdet.modeling.BBoxPostProcess(
  590. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  591. nms=ppdet.modeling.MultiClassNMS(
  592. score_threshold=nms_score_threshold,
  593. nms_top_k=nms_topk,
  594. keep_top_k=nms_keep_topk,
  595. nms_threshold=nms_iou_threshold))
  596. params = {
  597. 'backbone': backbone,
  598. 'neck': neck,
  599. 'yolo_head': yolo_head,
  600. 'post_process': post_process
  601. }
  602. super(YOLOv3, self).__init__(
  603. model_name='YOLOv3', num_classes=num_classes, **params)
  604. self.anchors = anchors
  605. self.anchor_masks = anchor_masks
  606. def _compose_batch_transform(self, transforms, mode='train'):
  607. if mode == 'train':
  608. default_batch_transforms = [
  609. _BatchPadding(
  610. pad_to_stride=-1, pad_gt=False), _NormalizeBox(),
  611. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  612. _Gt2YoloTarget(
  613. anchor_masks=self.anchor_masks,
  614. anchors=self.anchors,
  615. downsample_ratios=getattr(self, 'downsample_ratios',
  616. [32, 16, 8]),
  617. num_classes=self.num_classes)
  618. ]
  619. else:
  620. default_batch_transforms = [
  621. _BatchPadding(
  622. pad_to_stride=-1, pad_gt=False)
  623. ]
  624. custom_batch_transforms = []
  625. for i, op in enumerate(transforms.transforms):
  626. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  627. if mode != 'train':
  628. raise Exception(
  629. "{} cannot be present in the {} transforms. ".format(
  630. op.__class__.__name__, mode) +
  631. "Please check the {} transforms.".format(mode))
  632. custom_batch_transforms.insert(0, copy.deepcopy(op))
  633. batch_transforms = BatchCompose(custom_batch_transforms +
  634. default_batch_transforms)
  635. return batch_transforms
  636. class FasterRCNN(BaseDetector):
  637. def __init__(self,
  638. num_classes=80,
  639. backbone='ResNet50',
  640. with_fpn=True,
  641. aspect_ratios=[0.5, 1.0, 2.0],
  642. anchor_sizes=[[32], [64], [128], [256], [512]],
  643. keep_top_k=100,
  644. nms_threshold=0.5,
  645. score_threshold=0.05,
  646. fpn_num_channels=256,
  647. rpn_batch_size_per_im=256,
  648. rpn_fg_fraction=0.5,
  649. test_pre_nms_top_n=None,
  650. test_post_nms_top_n=1000):
  651. self.init_params = locals()
  652. if backbone not in [
  653. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  654. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd'
  655. ]:
  656. raise ValueError(
  657. "backbone: {} is not supported. Please choose one of "
  658. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  659. "'ResNet101', 'ResNet101_vd')".format(backbone))
  660. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  661. if backbone == 'ResNet50_vd_ssld':
  662. if not with_fpn:
  663. logging.warning(
  664. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  665. format(backbone))
  666. with_fpn = True
  667. backbone = self._get_backbone(
  668. 'ResNet',
  669. variant='d',
  670. norm_type='bn',
  671. freeze_at=0,
  672. return_idx=[0, 1, 2, 3],
  673. num_stages=4,
  674. lr_mult_list=[0.05, 0.05, 0.1, 0.15])
  675. elif 'ResNet50' in backbone:
  676. if with_fpn:
  677. backbone = self._get_backbone(
  678. 'ResNet',
  679. variant='d' if '_vd' in backbone else 'b',
  680. norm_type='bn',
  681. freeze_at=0,
  682. return_idx=[0, 1, 2, 3],
  683. num_stages=4)
  684. else:
  685. backbone = self._get_backbone(
  686. 'ResNet',
  687. variant='d' if '_vd' in backbone else 'b',
  688. norm_type='bn',
  689. freeze_at=0,
  690. return_idx=[2],
  691. num_stages=3)
  692. elif 'ResNet34' in backbone:
  693. if not with_fpn:
  694. logging.warning(
  695. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  696. format(backbone))
  697. with_fpn = True
  698. backbone = self._get_backbone(
  699. 'ResNet',
  700. depth=34,
  701. variant='d' if 'vd' in backbone else 'b',
  702. norm_type='bn',
  703. freeze_at=0,
  704. return_idx=[0, 1, 2, 3],
  705. num_stages=4)
  706. else:
  707. if not with_fpn:
  708. logging.warning(
  709. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  710. format(backbone))
  711. with_fpn = True
  712. backbone = self._get_backbone(
  713. 'ResNet',
  714. depth=101,
  715. variant='d' if 'vd' in backbone else 'b',
  716. norm_type='bn',
  717. freeze_at=0,
  718. return_idx=[0, 1, 2, 3],
  719. num_stages=4)
  720. rpn_in_channel = backbone.out_shape[0].channels
  721. if with_fpn:
  722. neck = ppdet.modeling.FPN(
  723. in_channels=[i.channels for i in backbone.out_shape],
  724. out_channel=fpn_num_channels,
  725. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  726. rpn_in_channel = neck.out_shape[0].channels
  727. anchor_generator_cfg = {
  728. 'aspect_ratios': aspect_ratios,
  729. 'anchor_sizes': anchor_sizes,
  730. 'strides': [4, 8, 16, 32, 64]
  731. }
  732. train_proposal_cfg = {
  733. 'min_size': 0.0,
  734. 'nms_thresh': .7,
  735. 'pre_nms_top_n': 2000,
  736. 'post_nms_top_n': 1000,
  737. 'topk_after_collect': True
  738. }
  739. test_proposal_cfg = {
  740. 'min_size': 0.0,
  741. 'nms_thresh': .7,
  742. 'pre_nms_top_n': 1000
  743. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  744. 'post_nms_top_n': test_post_nms_top_n
  745. }
  746. head = ppdet.modeling.TwoFCHead(out_channel=1024)
  747. roi_extractor_cfg = {
  748. 'resolution': 7,
  749. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  750. 'sampling_ratio': 0,
  751. 'aligned': True
  752. }
  753. with_pool = False
  754. else:
  755. neck = None
  756. anchor_generator_cfg = {
  757. 'aspect_ratios': aspect_ratios,
  758. 'anchor_sizes': anchor_sizes,
  759. 'strides': [16]
  760. }
  761. train_proposal_cfg = {
  762. 'min_size': 0.0,
  763. 'nms_thresh': .7,
  764. 'pre_nms_top_n': 12000,
  765. 'post_nms_top_n': 2000,
  766. 'topk_after_collect': False
  767. }
  768. test_proposal_cfg = {
  769. 'min_size': 0.0,
  770. 'nms_thresh': .7,
  771. 'pre_nms_top_n': 6000
  772. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  773. 'post_nms_top_n': test_post_nms_top_n
  774. }
  775. head = ppdet.modeling.Res5Head()
  776. roi_extractor_cfg = {
  777. 'resolution': 14,
  778. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  779. 'sampling_ratio': 0,
  780. 'aligned': True
  781. }
  782. with_pool = True
  783. rpn_target_assign_cfg = {
  784. 'batch_size_per_im': rpn_batch_size_per_im,
  785. 'fg_fraction': rpn_fg_fraction,
  786. 'negative_overlap': .3,
  787. 'positive_overlap': .7,
  788. 'use_random': True
  789. }
  790. rpn_head = ppdet.modeling.RPNHead(
  791. anchor_generator=anchor_generator_cfg,
  792. rpn_target_assign=rpn_target_assign_cfg,
  793. train_proposal=train_proposal_cfg,
  794. test_proposal=test_proposal_cfg,
  795. in_channel=rpn_in_channel)
  796. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  797. bbox_head = ppdet.modeling.BBoxHead(
  798. head=head,
  799. in_channel=head.out_shape[0].channels,
  800. roi_extractor=roi_extractor_cfg,
  801. with_pool=with_pool,
  802. bbox_assigner=bbox_assigner,
  803. num_classes=num_classes)
  804. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  805. num_classes=num_classes,
  806. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  807. nms=ppdet.modeling.MultiClassNMS(
  808. score_threshold=score_threshold,
  809. keep_top_k=keep_top_k,
  810. nms_threshold=nms_threshold))
  811. params = {
  812. 'backbone': backbone,
  813. 'neck': neck,
  814. 'rpn_head': rpn_head,
  815. 'bbox_head': bbox_head,
  816. 'bbox_post_process': bbox_post_process
  817. }
  818. self.with_fpn = with_fpn
  819. super(FasterRCNN, self).__init__(
  820. model_name='FasterRCNN', num_classes=num_classes, **params)
  821. def _compose_batch_transform(self, transforms, mode='train'):
  822. if mode == 'train':
  823. default_batch_transforms = [
  824. _BatchPadding(
  825. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  826. ]
  827. else:
  828. default_batch_transforms = [
  829. _BatchPadding(
  830. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  831. ]
  832. custom_batch_transforms = []
  833. for i, op in enumerate(transforms.transforms):
  834. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  835. if mode != 'train':
  836. raise Exception(
  837. "{} cannot be present in the {} transforms. ".format(
  838. op.__class__.__name__, mode) +
  839. "Please check the {} transforms.".format(mode))
  840. custom_batch_transforms.insert(0, copy.deepcopy(op))
  841. batch_transforms = BatchCompose(custom_batch_transforms +
  842. default_batch_transforms)
  843. return batch_transforms
  844. class PPYOLO(YOLOv3):
  845. def __init__(self,
  846. num_classes=80,
  847. backbone='ResNet50_vd_dcn',
  848. anchors=None,
  849. anchor_masks=None,
  850. use_coord_conv=True,
  851. use_iou_aware=True,
  852. use_spp=True,
  853. use_drop_block=True,
  854. scale_x_y=1.05,
  855. ignore_threshold=0.7,
  856. label_smooth=False,
  857. use_iou_loss=True,
  858. use_matrix_nms=True,
  859. nms_score_threshold=0.01,
  860. nms_topk=-1,
  861. nms_keep_topk=100,
  862. nms_iou_threshold=0.45):
  863. self.init_params = locals()
  864. if backbone not in [
  865. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  866. 'MobileNetV3_small'
  867. ]:
  868. raise ValueError(
  869. "backbone: {} is not supported. Please choose one of "
  870. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  871. format(backbone))
  872. self.backbone_name = backbone
  873. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  874. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  875. norm_type = 'sync_bn'
  876. else:
  877. norm_type = 'bn'
  878. if anchors is None and anchor_masks is None:
  879. if 'MobileNetV3' in backbone:
  880. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  881. [120, 195], [254, 235]]
  882. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  883. elif backbone == 'ResNet50_vd_dcn':
  884. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  885. [59, 119], [116, 90], [156, 198], [373, 326]]
  886. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  887. else:
  888. anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
  889. [344, 319]]
  890. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  891. elif anchors is None or anchor_masks is None:
  892. raise ValueError("Please define both anchors and anchor_masks.")
  893. if backbone == 'ResNet50_vd_dcn':
  894. backbone = self._get_backbone(
  895. 'ResNet',
  896. variant='d',
  897. norm_type=norm_type,
  898. return_idx=[1, 2, 3],
  899. dcn_v2_stages=[3],
  900. freeze_at=-1,
  901. freeze_norm=False,
  902. norm_decay=0.)
  903. downsample_ratios = [32, 16, 8]
  904. elif backbone == 'ResNet18_vd':
  905. backbone = self._get_backbone(
  906. 'ResNet',
  907. depth=18,
  908. variant='d',
  909. norm_type=norm_type,
  910. return_idx=[2, 3],
  911. freeze_at=-1,
  912. freeze_norm=False,
  913. norm_decay=0.)
  914. downsample_ratios = [32, 16, 8]
  915. elif backbone == 'MobileNetV3_large':
  916. backbone = self._get_backbone(
  917. 'MobileNetV3',
  918. model_name='large',
  919. norm_type=norm_type,
  920. scale=1,
  921. with_extra_blocks=False,
  922. extra_block_filters=[],
  923. feature_maps=[13, 16])
  924. downsample_ratios = [32, 16]
  925. elif backbone == 'MobileNetV3_small':
  926. backbone = self._get_backbone(
  927. 'MobileNetV3',
  928. model_name='small',
  929. norm_type=norm_type,
  930. scale=1,
  931. with_extra_blocks=False,
  932. extra_block_filters=[],
  933. feature_maps=[9, 12])
  934. downsample_ratios = [32, 16]
  935. neck = ppdet.modeling.PPYOLOFPN(
  936. norm_type=norm_type,
  937. in_channels=[i.channels for i in backbone.out_shape],
  938. coord_conv=use_coord_conv,
  939. drop_block=use_drop_block,
  940. spp=use_spp,
  941. conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
  942. self.backbone_name == 'ResNet18_vd') else 2)
  943. loss = ppdet.modeling.YOLOv3Loss(
  944. num_classes=num_classes,
  945. ignore_thresh=ignore_threshold,
  946. downsample=downsample_ratios,
  947. label_smooth=label_smooth,
  948. scale_x_y=scale_x_y,
  949. iou_loss=ppdet.modeling.IouLoss(
  950. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  951. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  952. if use_iou_aware else None)
  953. yolo_head = ppdet.modeling.YOLOv3Head(
  954. in_channels=[i.channels for i in neck.out_shape],
  955. anchors=anchors,
  956. anchor_masks=anchor_masks,
  957. num_classes=num_classes,
  958. loss=loss,
  959. iou_aware=use_iou_aware)
  960. if use_matrix_nms:
  961. nms = ppdet.modeling.MatrixNMS(
  962. keep_top_k=nms_keep_topk,
  963. score_threshold=nms_score_threshold,
  964. post_threshold=.05
  965. if 'MobileNetV3' in self.backbone_name else .01,
  966. nms_top_k=nms_topk,
  967. background_label=-1)
  968. else:
  969. nms = ppdet.modeling.MultiClassNMS(
  970. score_threshold=nms_score_threshold,
  971. nms_top_k=nms_topk,
  972. keep_top_k=nms_keep_topk,
  973. nms_threshold=nms_iou_threshold)
  974. post_process = ppdet.modeling.BBoxPostProcess(
  975. decode=ppdet.modeling.YOLOBox(
  976. num_classes=num_classes,
  977. conf_thresh=.005
  978. if 'MobileNetV3' in self.backbone_name else .01,
  979. scale_x_y=scale_x_y),
  980. nms=nms)
  981. params = {
  982. 'backbone': backbone,
  983. 'neck': neck,
  984. 'yolo_head': yolo_head,
  985. 'post_process': post_process
  986. }
  987. super(YOLOv3, self).__init__(
  988. model_name='YOLOv3', num_classes=num_classes, **params)
  989. self.anchors = anchors
  990. self.anchor_masks = anchor_masks
  991. self.downsample_ratios = downsample_ratios
  992. self.model_name = 'PPYOLO'
  993. class PPYOLOTiny(YOLOv3):
  994. def __init__(self,
  995. num_classes=80,
  996. backbone='MobileNetV3',
  997. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  998. [60, 170], [220, 125], [128, 222], [264, 266]],
  999. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1000. use_iou_aware=False,
  1001. use_spp=True,
  1002. use_drop_block=True,
  1003. scale_x_y=1.05,
  1004. ignore_threshold=0.5,
  1005. label_smooth=False,
  1006. use_iou_loss=True,
  1007. use_matrix_nms=False,
  1008. nms_score_threshold=0.005,
  1009. nms_topk=1000,
  1010. nms_keep_topk=100,
  1011. nms_iou_threshold=0.45):
  1012. self.init_params = locals()
  1013. if backbone != 'MobileNetV3':
  1014. logging.warning(
  1015. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1016. "Backbone is forcibly set to MobileNetV3.")
  1017. self.backbone_name = 'MobileNetV3'
  1018. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1019. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1020. norm_type = 'sync_bn'
  1021. else:
  1022. norm_type = 'bn'
  1023. backbone = self._get_backbone(
  1024. 'MobileNetV3',
  1025. model_name='large',
  1026. norm_type=norm_type,
  1027. scale=.5,
  1028. with_extra_blocks=False,
  1029. extra_block_filters=[],
  1030. feature_maps=[7, 13, 16])
  1031. downsample_ratios = [32, 16, 8]
  1032. neck = ppdet.modeling.PPYOLOTinyFPN(
  1033. detection_block_channels=[160, 128, 96],
  1034. in_channels=[i.channels for i in backbone.out_shape],
  1035. spp=use_spp,
  1036. drop_block=use_drop_block)
  1037. loss = ppdet.modeling.YOLOv3Loss(
  1038. num_classes=num_classes,
  1039. ignore_thresh=ignore_threshold,
  1040. downsample=downsample_ratios,
  1041. label_smooth=label_smooth,
  1042. scale_x_y=scale_x_y,
  1043. iou_loss=ppdet.modeling.IouLoss(
  1044. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1045. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1046. if use_iou_aware else None)
  1047. yolo_head = ppdet.modeling.YOLOv3Head(
  1048. in_channels=[i.channels for i in neck.out_shape],
  1049. anchors=anchors,
  1050. anchor_masks=anchor_masks,
  1051. num_classes=num_classes,
  1052. loss=loss,
  1053. iou_aware=use_iou_aware)
  1054. if use_matrix_nms:
  1055. nms = ppdet.modeling.MatrixNMS(
  1056. keep_top_k=nms_keep_topk,
  1057. score_threshold=nms_score_threshold,
  1058. post_threshold=.05,
  1059. nms_top_k=nms_topk,
  1060. background_label=-1)
  1061. else:
  1062. nms = ppdet.modeling.MultiClassNMS(
  1063. score_threshold=nms_score_threshold,
  1064. nms_top_k=nms_topk,
  1065. keep_top_k=nms_keep_topk,
  1066. nms_threshold=nms_iou_threshold)
  1067. post_process = ppdet.modeling.BBoxPostProcess(
  1068. decode=ppdet.modeling.YOLOBox(
  1069. num_classes=num_classes,
  1070. conf_thresh=.005,
  1071. downsample_ratio=32,
  1072. clip_bbox=True,
  1073. scale_x_y=scale_x_y),
  1074. nms=nms)
  1075. params = {
  1076. 'backbone': backbone,
  1077. 'neck': neck,
  1078. 'yolo_head': yolo_head,
  1079. 'post_process': post_process
  1080. }
  1081. super(YOLOv3, self).__init__(
  1082. model_name='YOLOv3', num_classes=num_classes, **params)
  1083. self.anchors = anchors
  1084. self.anchor_masks = anchor_masks
  1085. self.downsample_ratios = downsample_ratios
  1086. self.num_max_boxes = 100
  1087. self.model_name = 'PPYOLOTiny'
  1088. class PPYOLOv2(YOLOv3):
  1089. def __init__(self,
  1090. num_classes=80,
  1091. backbone='ResNet50_vd_dcn',
  1092. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1093. [59, 119], [116, 90], [156, 198], [373, 326]],
  1094. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1095. use_iou_aware=True,
  1096. use_spp=True,
  1097. use_drop_block=True,
  1098. scale_x_y=1.05,
  1099. ignore_threshold=0.7,
  1100. label_smooth=False,
  1101. use_iou_loss=True,
  1102. use_matrix_nms=True,
  1103. nms_score_threshold=0.01,
  1104. nms_topk=-1,
  1105. nms_keep_topk=100,
  1106. nms_iou_threshold=0.45):
  1107. self.init_params = locals()
  1108. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1109. raise ValueError(
  1110. "backbone: {} is not supported. Please choose one of "
  1111. "('ResNet50_vd_dcn', 'ResNet18_vd')".format(backbone))
  1112. self.backbone_name = backbone
  1113. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1114. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1115. norm_type = 'sync_bn'
  1116. else:
  1117. norm_type = 'bn'
  1118. if backbone == 'ResNet50_vd_dcn':
  1119. backbone = self._get_backbone(
  1120. 'ResNet',
  1121. variant='d',
  1122. norm_type=norm_type,
  1123. return_idx=[1, 2, 3],
  1124. dcn_v2_stages=[3],
  1125. freeze_at=-1,
  1126. freeze_norm=False,
  1127. norm_decay=0.)
  1128. downsample_ratios = [32, 16, 8]
  1129. elif backbone == 'ResNet101_vd_dcn':
  1130. backbone = self._get_backbone(
  1131. 'ResNet',
  1132. depth=101,
  1133. variant='d',
  1134. norm_type=norm_type,
  1135. return_idx=[1, 2, 3],
  1136. dcn_v2_stages=[3],
  1137. freeze_at=-1,
  1138. freeze_norm=False,
  1139. norm_decay=0.)
  1140. downsample_ratios = [32, 16, 8]
  1141. neck = ppdet.modeling.PPYOLOPAN(
  1142. norm_type=norm_type,
  1143. in_channels=[i.channels for i in backbone.out_shape],
  1144. drop_block=use_drop_block,
  1145. block_size=3,
  1146. keep_prob=.9,
  1147. spp=use_spp)
  1148. loss = ppdet.modeling.YOLOv3Loss(
  1149. num_classes=num_classes,
  1150. ignore_thresh=ignore_threshold,
  1151. downsample=downsample_ratios,
  1152. label_smooth=label_smooth,
  1153. scale_x_y=scale_x_y,
  1154. iou_loss=ppdet.modeling.IouLoss(
  1155. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1156. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1157. if use_iou_aware else None)
  1158. yolo_head = ppdet.modeling.YOLOv3Head(
  1159. in_channels=[i.channels for i in neck.out_shape],
  1160. anchors=anchors,
  1161. anchor_masks=anchor_masks,
  1162. num_classes=num_classes,
  1163. loss=loss,
  1164. iou_aware=use_iou_aware,
  1165. iou_aware_factor=.5)
  1166. if use_matrix_nms:
  1167. nms = ppdet.modeling.MatrixNMS(
  1168. keep_top_k=nms_keep_topk,
  1169. score_threshold=nms_score_threshold,
  1170. post_threshold=.01,
  1171. nms_top_k=nms_topk,
  1172. background_label=-1)
  1173. else:
  1174. nms = ppdet.modeling.MultiClassNMS(
  1175. score_threshold=nms_score_threshold,
  1176. nms_top_k=nms_topk,
  1177. keep_top_k=nms_keep_topk,
  1178. nms_threshold=nms_iou_threshold)
  1179. post_process = ppdet.modeling.BBoxPostProcess(
  1180. decode=ppdet.modeling.YOLOBox(
  1181. num_classes=num_classes,
  1182. conf_thresh=.01,
  1183. downsample_ratio=32,
  1184. clip_bbox=True,
  1185. scale_x_y=scale_x_y),
  1186. nms=nms)
  1187. params = {
  1188. 'backbone': backbone,
  1189. 'neck': neck,
  1190. 'yolo_head': yolo_head,
  1191. 'post_process': post_process
  1192. }
  1193. super(YOLOv3, self).__init__(
  1194. model_name='YOLOv3', num_classes=num_classes, **params)
  1195. self.anchors = anchors
  1196. self.anchor_masks = anchor_masks
  1197. self.downsample_ratios = downsample_ratios
  1198. self.num_max_boxes = 100
  1199. self.model_name = 'PPYOLOv2'
  1200. class MaskRCNN(BaseDetector):
  1201. def __init__(self,
  1202. num_classes=80,
  1203. backbone='ResNet50_vd',
  1204. with_fpn=True,
  1205. aspect_ratios=[0.5, 1.0, 2.0],
  1206. anchor_sizes=[[32], [64], [128], [256], [512]],
  1207. keep_top_k=100,
  1208. nms_threshold=0.5,
  1209. score_threshold=0.05,
  1210. fpn_num_channels=256,
  1211. rpn_batch_size_per_im=256,
  1212. rpn_fg_fraction=0.5,
  1213. test_pre_nms_top_n=None,
  1214. test_post_nms_top_n=1000):
  1215. self.init_params = locals()
  1216. if backbone not in [
  1217. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1218. 'ResNet101_vd'
  1219. ]:
  1220. raise ValueError(
  1221. "backbone: {} is not supported. Please choose one of "
  1222. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1223. format(backbone))
  1224. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1225. if backbone == 'ResNet50':
  1226. if with_fpn:
  1227. backbone = self._get_backbone(
  1228. 'ResNet',
  1229. norm_type='bn',
  1230. freeze_at=0,
  1231. return_idx=[0, 1, 2, 3],
  1232. num_stages=4)
  1233. else:
  1234. backbone = self._get_backbone(
  1235. 'ResNet',
  1236. norm_type='bn',
  1237. freeze_at=0,
  1238. return_idx=[2],
  1239. num_stages=3)
  1240. elif 'ResNet50_vd' in backbone:
  1241. if not with_fpn:
  1242. logging.warning(
  1243. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1244. format(backbone))
  1245. with_fpn = True
  1246. backbone = self._get_backbone(
  1247. 'ResNet',
  1248. variant='d',
  1249. norm_type='bn',
  1250. freeze_at=0,
  1251. return_idx=[0, 1, 2, 3],
  1252. num_stages=4,
  1253. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1254. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
  1255. else:
  1256. if not with_fpn:
  1257. logging.warning(
  1258. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1259. format(backbone))
  1260. with_fpn = True
  1261. backbone = self._get_backbone(
  1262. 'ResNet',
  1263. variant='d' if '_vd' in backbone else 'b',
  1264. depth=101,
  1265. norm_type='bn',
  1266. freeze_at=0,
  1267. return_idx=[0, 1, 2, 3],
  1268. num_stages=4)
  1269. rpn_in_channel = backbone.out_shape[0].channels
  1270. if with_fpn:
  1271. neck = ppdet.modeling.FPN(
  1272. in_channels=[i.channels for i in backbone.out_shape],
  1273. out_channel=fpn_num_channels,
  1274. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  1275. rpn_in_channel = neck.out_shape[0].channels
  1276. anchor_generator_cfg = {
  1277. 'aspect_ratios': aspect_ratios,
  1278. 'anchor_sizes': anchor_sizes,
  1279. 'strides': [4, 8, 16, 32, 64]
  1280. }
  1281. train_proposal_cfg = {
  1282. 'min_size': 0.0,
  1283. 'nms_thresh': .7,
  1284. 'pre_nms_top_n': 2000,
  1285. 'post_nms_top_n': 1000,
  1286. 'topk_after_collect': True
  1287. }
  1288. test_proposal_cfg = {
  1289. 'min_size': 0.0,
  1290. 'nms_thresh': .7,
  1291. 'pre_nms_top_n': 1000
  1292. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1293. 'post_nms_top_n': test_post_nms_top_n
  1294. }
  1295. bb_head = ppdet.modeling.TwoFCHead(
  1296. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1297. bb_roi_extractor_cfg = {
  1298. 'resolution': 7,
  1299. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1300. 'sampling_ratio': 0,
  1301. 'aligned': True
  1302. }
  1303. with_pool = False
  1304. m_head = ppdet.modeling.MaskFeat(
  1305. in_channel=neck.out_shape[0].channels,
  1306. out_channel=256,
  1307. num_convs=4)
  1308. m_roi_extractor_cfg = {
  1309. 'resolution': 14,
  1310. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1311. 'sampling_ratio': 0,
  1312. 'aligned': True
  1313. }
  1314. mask_assigner = MaskAssigner(
  1315. num_classes=num_classes, mask_resolution=28)
  1316. share_bbox_feat = False
  1317. else:
  1318. neck = None
  1319. anchor_generator_cfg = {
  1320. 'aspect_ratios': aspect_ratios,
  1321. 'anchor_sizes': anchor_sizes,
  1322. 'strides': [16]
  1323. }
  1324. train_proposal_cfg = {
  1325. 'min_size': 0.0,
  1326. 'nms_thresh': .7,
  1327. 'pre_nms_top_n': 12000,
  1328. 'post_nms_top_n': 2000,
  1329. 'topk_after_collect': False
  1330. }
  1331. test_proposal_cfg = {
  1332. 'min_size': 0.0,
  1333. 'nms_thresh': .7,
  1334. 'pre_nms_top_n': 6000
  1335. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1336. 'post_nms_top_n': test_post_nms_top_n
  1337. }
  1338. bb_head = ppdet.modeling.Res5Head()
  1339. bb_roi_extractor_cfg = {
  1340. 'resolution': 14,
  1341. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1342. 'sampling_ratio': 0,
  1343. 'aligned': True
  1344. }
  1345. with_pool = True
  1346. m_head = ppdet.modeling.MaskFeat(
  1347. in_channel=bb_head.out_shape[0].channels,
  1348. out_channel=256,
  1349. num_convs=0)
  1350. m_roi_extractor_cfg = {
  1351. 'resolution': 14,
  1352. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1353. 'sampling_ratio': 0,
  1354. 'aligned': True
  1355. }
  1356. mask_assigner = MaskAssigner(
  1357. num_classes=num_classes, mask_resolution=14)
  1358. share_bbox_feat = True
  1359. rpn_target_assign_cfg = {
  1360. 'batch_size_per_im': rpn_batch_size_per_im,
  1361. 'fg_fraction': rpn_fg_fraction,
  1362. 'negative_overlap': .3,
  1363. 'positive_overlap': .7,
  1364. 'use_random': True
  1365. }
  1366. rpn_head = ppdet.modeling.RPNHead(
  1367. anchor_generator=anchor_generator_cfg,
  1368. rpn_target_assign=rpn_target_assign_cfg,
  1369. train_proposal=train_proposal_cfg,
  1370. test_proposal=test_proposal_cfg,
  1371. in_channel=rpn_in_channel)
  1372. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1373. bbox_head = ppdet.modeling.BBoxHead(
  1374. head=bb_head,
  1375. in_channel=bb_head.out_shape[0].channels,
  1376. roi_extractor=bb_roi_extractor_cfg,
  1377. with_pool=with_pool,
  1378. bbox_assigner=bbox_assigner,
  1379. num_classes=num_classes)
  1380. mask_head = ppdet.modeling.MaskHead(
  1381. head=m_head,
  1382. roi_extractor=m_roi_extractor_cfg,
  1383. mask_assigner=mask_assigner,
  1384. share_bbox_feat=share_bbox_feat,
  1385. num_classes=num_classes)
  1386. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1387. num_classes=num_classes,
  1388. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1389. nms=ppdet.modeling.MultiClassNMS(
  1390. score_threshold=score_threshold,
  1391. keep_top_k=keep_top_k,
  1392. nms_threshold=nms_threshold))
  1393. mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
  1394. params = {
  1395. 'backbone': backbone,
  1396. 'neck': neck,
  1397. 'rpn_head': rpn_head,
  1398. 'bbox_head': bbox_head,
  1399. 'mask_head': mask_head,
  1400. 'bbox_post_process': bbox_post_process,
  1401. 'mask_post_process': mask_post_process
  1402. }
  1403. self.with_fpn = with_fpn
  1404. super(MaskRCNN, self).__init__(
  1405. model_name='MaskRCNN', num_classes=num_classes, **params)
  1406. def _compose_batch_transform(self, transforms, mode='train'):
  1407. if mode == 'train':
  1408. default_batch_transforms = [
  1409. _BatchPadding(
  1410. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  1411. ]
  1412. else:
  1413. default_batch_transforms = [
  1414. _BatchPadding(
  1415. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  1416. ]
  1417. custom_batch_transforms = []
  1418. for i, op in enumerate(transforms.transforms):
  1419. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1420. if mode != 'train':
  1421. raise Exception(
  1422. "{} cannot be present in the {} transforms. ".format(
  1423. op.__class__.__name__, mode) +
  1424. "Please check the {} transforms.".format(mode))
  1425. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1426. batch_transforms = BatchCompose(custom_batch_transforms +
  1427. default_batch_transforms)
  1428. return batch_transforms