detector.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import ppdet
  24. from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. del self.init_params['params']
  41. super(BaseDetector, self).__init__('detector')
  42. if not hasattr(ppdet.modeling, model_name):
  43. raise Exception("ERROR: There's no model named {}.".format(
  44. model_name))
  45. self.model_name = model_name
  46. self.num_classes = num_classes
  47. self.labels = None
  48. self.net = self.build_net(**params)
  49. def build_net(self, **params):
  50. with paddle.utils.unique_name.guard():
  51. net = ppdet.modeling.__dict__[self.model_name](**params)
  52. return net
  53. def get_test_inputs(self, image_shape):
  54. input_spec = [{
  55. "image": InputSpec(
  56. shape=[None, 3] + image_shape, name='image', dtype='float32'),
  57. "im_shape": InputSpec(
  58. shape=[None, 2], name='im_shape', dtype='float32'),
  59. "scale_factor": InputSpec(
  60. shape=[None, 2], name='scale_factor', dtype='float32')
  61. }]
  62. return input_spec
  63. def _get_backbone(self, backbone_name, **params):
  64. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  65. return backbone
  66. def run(self, net, inputs, mode):
  67. net_out = net(inputs)
  68. if mode in ['train', 'eval']:
  69. outputs = net_out
  70. else:
  71. for key in ['im_shape', 'scale_factor']:
  72. net_out[key] = inputs[key]
  73. outputs = dict()
  74. for key in net_out:
  75. outputs[key] = net_out[key].numpy()
  76. return outputs
  77. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  78. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  79. num_steps_each_epoch):
  80. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  81. values = [(lr_decay_gamma**i) * learning_rate
  82. for i in range(len(lr_decay_epochs) + 1)]
  83. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  84. boundaries=boundaries, values=values)
  85. if warmup_steps > 0:
  86. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  87. logging.error(
  88. "In function train(), parameters should satisfy: "
  89. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  90. exit=False)
  91. logging.error(
  92. "See this doc for more information: "
  93. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  94. exit=False)
  95. scheduler = paddle.optimizer.lr.LinearWarmup(
  96. learning_rate=scheduler,
  97. warmup_steps=warmup_steps,
  98. start_lr=warmup_start_lr,
  99. end_lr=learning_rate)
  100. optimizer = paddle.optimizer.Momentum(
  101. scheduler,
  102. momentum=.9,
  103. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  104. parameters=parameters)
  105. return optimizer
  106. def train(self,
  107. num_epochs,
  108. train_dataset,
  109. train_batch_size=64,
  110. eval_dataset=None,
  111. optimizer=None,
  112. save_interval_epochs=1,
  113. log_interval_steps=10,
  114. save_dir='output',
  115. pretrain_weights='IMAGENET',
  116. learning_rate=.001,
  117. warmup_steps=0,
  118. warmup_start_lr=0.0,
  119. lr_decay_epochs=(216, 243),
  120. lr_decay_gamma=0.1,
  121. metric=None,
  122. use_ema=False,
  123. early_stop=False,
  124. early_stop_patience=5,
  125. use_vdl=True):
  126. """
  127. Train the model.
  128. Args:
  129. num_epochs(int): The number of epochs.
  130. train_dataset(paddlex.dataset): Training dataset.
  131. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  132. eval_dataset(paddlex.dataset, optional):
  133. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  134. optimizer(paddle.optimizer.Optimizer or None, optional):
  135. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  136. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  137. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  138. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  139. pretrain_weights(str or None, optional):
  140. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  141. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  142. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  143. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  144. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  145. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  146. metric({'VOC', 'COCO', None}, optional):
  147. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  148. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  149. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  150. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  151. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  152. """
  153. if train_dataset.__class__.__name__ == 'VOCDetection':
  154. train_dataset.data_fields = {
  155. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  156. 'difficult'
  157. }
  158. elif train_dataset.__class__.__name__ == 'CocoDetection':
  159. if self.__class__.__name__ == 'MaskRCNN':
  160. train_dataset.data_fields = {
  161. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  162. 'gt_poly', 'is_crowd'
  163. }
  164. else:
  165. train_dataset.data_fields = {
  166. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  167. 'is_crowd'
  168. }
  169. if metric is None:
  170. if eval_dataset.__class__.__name__ == 'VOCDetection':
  171. self.metric = 'voc'
  172. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  173. self.metric = 'coco'
  174. else:
  175. assert metric.lower() in ['coco', 'voc'], \
  176. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  177. self.metric = metric.lower()
  178. train_dataset.batch_transforms = self._compose_batch_transform(
  179. train_dataset.transforms, mode='train')
  180. self.labels = train_dataset.labels
  181. # build optimizer if not defined
  182. if optimizer is None:
  183. num_steps_each_epoch = len(train_dataset) // train_batch_size
  184. self.optimizer = self.default_optimizer(
  185. parameters=self.net.parameters(),
  186. learning_rate=learning_rate,
  187. warmup_steps=warmup_steps,
  188. warmup_start_lr=warmup_start_lr,
  189. lr_decay_epochs=lr_decay_epochs,
  190. lr_decay_gamma=lr_decay_gamma,
  191. num_steps_each_epoch=num_steps_each_epoch)
  192. else:
  193. self.optimizer = optimizer
  194. # initiate weights
  195. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  196. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  197. [self.model_name, self.backbone_name])]:
  198. logging.warning(
  199. "Path of pretrain_weights('{}') does not exist!".format(
  200. pretrain_weights))
  201. pretrain_weights = det_pretrain_weights_dict['_'.join(
  202. [self.model_name, self.backbone_name])][0]
  203. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  204. "If don't want to use pretrain weights, "
  205. "set pretrain_weights to be None.".format(
  206. pretrain_weights))
  207. pretrained_dir = osp.join(save_dir, 'pretrain')
  208. self.net_initialize(
  209. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  210. if use_ema:
  211. ema = ExponentialMovingAverage(
  212. decay=.9998, model=self.net, use_thres_step=True)
  213. else:
  214. ema = None
  215. # start train loop
  216. self.train_loop(
  217. num_epochs=num_epochs,
  218. train_dataset=train_dataset,
  219. train_batch_size=train_batch_size,
  220. eval_dataset=eval_dataset,
  221. save_interval_epochs=save_interval_epochs,
  222. log_interval_steps=log_interval_steps,
  223. save_dir=save_dir,
  224. ema=ema,
  225. early_stop=early_stop,
  226. early_stop_patience=early_stop_patience,
  227. use_vdl=use_vdl)
  228. def evaluate(self,
  229. eval_dataset,
  230. batch_size=1,
  231. metric=None,
  232. return_details=False):
  233. """
  234. Evaluate the model.
  235. Args:
  236. eval_dataset(paddlex.dataset): Evaluation dataset.
  237. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  238. metric({'VOC', 'COCO', None}, optional):
  239. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  240. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  241. Returns:
  242. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  243. """
  244. if eval_dataset.__class__.__name__ == 'VOCDetection':
  245. eval_dataset.data_fields = {
  246. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  247. 'difficult'
  248. }
  249. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  250. if self.__class__.__name__ == 'MaskRCNN':
  251. eval_dataset.data_fields = {
  252. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  253. 'gt_poly', 'is_crowd'
  254. }
  255. else:
  256. eval_dataset.data_fields = {
  257. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  258. 'is_crowd'
  259. }
  260. eval_dataset.batch_transforms = self._compose_batch_transform(
  261. eval_dataset.transforms, mode='eval')
  262. arrange_transforms(
  263. model_type=self.model_type,
  264. transforms=eval_dataset.transforms,
  265. mode='eval')
  266. self.net.eval()
  267. nranks = paddle.distributed.get_world_size()
  268. local_rank = paddle.distributed.get_rank()
  269. if nranks > 1:
  270. # Initialize parallel environment if not done.
  271. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  272. ):
  273. paddle.distributed.init_parallel_env()
  274. if batch_size > 1:
  275. logging.warning(
  276. "Detector only supports single card evaluation with batch_size=1 "
  277. "during evaluation, so batch_size is forcibly set to 1.")
  278. batch_size = 1
  279. if nranks < 2 or local_rank == 0:
  280. self.eval_data_loader = self.build_data_loader(
  281. eval_dataset, batch_size=batch_size, mode='eval')
  282. is_bbox_normalized = False
  283. if eval_dataset.batch_transforms is not None:
  284. is_bbox_normalized = any(
  285. isinstance(t, _NormalizeBox)
  286. for t in eval_dataset.batch_transforms.batch_transforms)
  287. if metric is None:
  288. if getattr(self, 'metric', None) is not None:
  289. if self.metric == 'voc':
  290. eval_metric = VOCMetric(
  291. labels=eval_dataset.labels,
  292. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  293. is_bbox_normalized=is_bbox_normalized,
  294. classwise=False)
  295. else:
  296. eval_metric = COCOMetric(
  297. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  298. classwise=False)
  299. else:
  300. if eval_dataset.__class__.__name__ == 'VOCDetection':
  301. eval_metric = VOCMetric(
  302. labels=eval_dataset.labels,
  303. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  304. is_bbox_normalized=is_bbox_normalized,
  305. classwise=False)
  306. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  307. eval_metric = COCOMetric(
  308. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  309. classwise=False)
  310. else:
  311. assert metric.lower() in ['coco', 'voc'], \
  312. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  313. if metric.lower() == 'coco':
  314. eval_metric = COCOMetric(
  315. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  316. classwise=False)
  317. else:
  318. eval_metric = VOCMetric(
  319. labels=eval_dataset.labels,
  320. is_bbox_normalized=is_bbox_normalized,
  321. classwise=False)
  322. scores = collections.OrderedDict()
  323. logging.info(
  324. "Start to evaluate(total_samples={}, total_steps={})...".
  325. format(eval_dataset.num_samples, eval_dataset.num_samples))
  326. with paddle.no_grad():
  327. for step, data in enumerate(self.eval_data_loader):
  328. outputs = self.run(self.net, data, 'eval')
  329. eval_metric.update(data, outputs)
  330. eval_metric.accumulate()
  331. self.eval_details = eval_metric.details
  332. scores.update(eval_metric.get())
  333. eval_metric.reset()
  334. if return_details:
  335. return scores, self.eval_details
  336. return scores
  337. def predict(self, img_file, transforms=None):
  338. """
  339. Do inference.
  340. Args:
  341. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  342. Image path or decoded image data in a BGR format, which also could constitute a list,
  343. meaning all images to be predicted as a mini-batch.
  344. transforms(paddlex.transforms.Compose or None, optional):
  345. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  346. Returns:
  347. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  348. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  349. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  350. category_id(int): the predicted category ID
  351. category(str): category name
  352. bbox(list): bounding box in [x, y, w, h] format
  353. score(str): confidence
  354. """
  355. if transforms is None and not hasattr(self, 'test_transforms'):
  356. raise Exception("transforms need to be defined, now is None.")
  357. if transforms is None:
  358. transforms = self.test_transforms
  359. if isinstance(img_file, (str, np.ndarray)):
  360. images = [img_file]
  361. else:
  362. images = img_file
  363. batch_samples = self._preprocess(images, transforms)
  364. self.net.eval()
  365. outputs = self.run(self.net, batch_samples, 'test')
  366. prediction = self._postprocess(outputs)
  367. if isinstance(img_file, (str, np.ndarray)):
  368. prediction = prediction[0]
  369. return prediction
  370. def _preprocess(self, images, transforms):
  371. arrange_transforms(
  372. model_type=self.model_type, transforms=transforms, mode='test')
  373. batch_samples = list()
  374. for im in images:
  375. sample = {'image': im}
  376. batch_samples.append(transforms(sample))
  377. batch_transforms = self._compose_batch_transform(transforms, 'test')
  378. batch_samples = batch_transforms(batch_samples)
  379. for k, v in batch_samples.items():
  380. batch_samples[k] = paddle.to_tensor(v)
  381. return batch_samples
  382. def _postprocess(self, batch_pred):
  383. infer_result = {}
  384. if 'bbox' in batch_pred:
  385. bboxes = batch_pred['bbox']
  386. bbox_nums = batch_pred['bbox_num']
  387. det_res = []
  388. k = 0
  389. for i in range(len(bbox_nums)):
  390. det_nums = bbox_nums[i]
  391. for j in range(det_nums):
  392. dt = bboxes[k]
  393. k = k + 1
  394. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  395. if int(num_id) < 0:
  396. continue
  397. category = self.labels[int(num_id)]
  398. w = xmax - xmin
  399. h = ymax - ymin
  400. bbox = [xmin, ymin, w, h]
  401. dt_res = {
  402. 'category_id': int(num_id),
  403. 'category': category,
  404. 'bbox': bbox,
  405. 'score': score
  406. }
  407. det_res.append(dt_res)
  408. infer_result['bbox'] = det_res
  409. if 'mask' in batch_pred:
  410. masks = batch_pred['mask']
  411. bboxes = batch_pred['bbox']
  412. mask_nums = batch_pred['bbox_num']
  413. seg_res = []
  414. k = 0
  415. for i in range(len(mask_nums)):
  416. det_nums = mask_nums[i]
  417. for j in range(det_nums):
  418. mask = masks[k].astype(np.uint8)
  419. score = float(bboxes[k][1])
  420. label = int(bboxes[k][0])
  421. k = k + 1
  422. if label == -1:
  423. continue
  424. category = self.labels[int(label)]
  425. import pycocotools.mask as mask_util
  426. rle = mask_util.encode(
  427. np.array(
  428. mask[:, :, None], order="F", dtype="uint8"))[0]
  429. if six.PY3:
  430. if 'counts' in rle:
  431. rle['counts'] = rle['counts'].decode("utf8")
  432. sg_res = {
  433. 'category': category,
  434. 'segmentation': rle,
  435. 'score': score
  436. }
  437. seg_res.append(sg_res)
  438. infer_result['mask'] = seg_res
  439. bbox_num = batch_pred['bbox_num']
  440. results = []
  441. start = 0
  442. for num in bbox_num:
  443. end = start + num
  444. curr_res = infer_result['bbox'][start:end]
  445. if 'mask' in infer_result:
  446. mask_res = infer_result['mask'][start:end]
  447. for box, mask in zip(curr_res, mask_res):
  448. box.update(mask)
  449. results.append(curr_res)
  450. start = end
  451. return results
  452. class YOLOv3(BaseDetector):
  453. def __init__(self,
  454. num_classes=80,
  455. backbone='MobileNetV1',
  456. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  457. [59, 119], [116, 90], [156, 198], [373, 326]],
  458. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  459. ignore_threshold=0.7,
  460. nms_score_threshold=0.01,
  461. nms_topk=1000,
  462. nms_keep_topk=100,
  463. nms_iou_threshold=0.45,
  464. label_smooth=False):
  465. self.init_params = locals()
  466. if backbone not in [
  467. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  468. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  469. ]:
  470. raise ValueError(
  471. "backbone: {} is not supported. Please choose one of "
  472. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  473. format(backbone))
  474. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  475. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  476. norm_type = 'sync_bn'
  477. else:
  478. norm_type = 'bn'
  479. self.backbone_name = backbone
  480. if 'MobileNetV1' in backbone:
  481. norm_type = 'bn'
  482. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  483. elif 'MobileNetV3' in backbone:
  484. backbone = self._get_backbone(
  485. 'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
  486. elif backbone == 'ResNet50_vd_dcn':
  487. backbone = self._get_backbone(
  488. 'ResNet',
  489. norm_type=norm_type,
  490. variant='d',
  491. return_idx=[1, 2, 3],
  492. dcn_v2_stages=[3],
  493. freeze_at=-1,
  494. freeze_norm=False)
  495. elif backbone == 'ResNet34':
  496. backbone = self._get_backbone(
  497. 'ResNet',
  498. depth=34,
  499. norm_type=norm_type,
  500. return_idx=[1, 2, 3],
  501. freeze_at=-1,
  502. freeze_norm=False,
  503. norm_decay=0.)
  504. else:
  505. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  506. neck = ppdet.modeling.YOLOv3FPN(
  507. norm_type=norm_type,
  508. in_channels=[i.channels for i in backbone.out_shape])
  509. loss = ppdet.modeling.YOLOv3Loss(
  510. num_classes=num_classes,
  511. ignore_thresh=ignore_threshold,
  512. label_smooth=label_smooth)
  513. yolo_head = ppdet.modeling.YOLOv3Head(
  514. in_channels=[i.channels for i in neck.out_shape],
  515. anchors=anchors,
  516. anchor_masks=anchor_masks,
  517. num_classes=num_classes,
  518. loss=loss)
  519. post_process = ppdet.modeling.BBoxPostProcess(
  520. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  521. nms=ppdet.modeling.MultiClassNMS(
  522. score_threshold=nms_score_threshold,
  523. nms_top_k=nms_topk,
  524. keep_top_k=nms_keep_topk,
  525. nms_threshold=nms_iou_threshold))
  526. params = {
  527. 'backbone': backbone,
  528. 'neck': neck,
  529. 'yolo_head': yolo_head,
  530. 'post_process': post_process
  531. }
  532. super(YOLOv3, self).__init__(
  533. model_name='YOLOv3', num_classes=num_classes, **params)
  534. self.anchors = anchors
  535. self.anchor_masks = anchor_masks
  536. def _compose_batch_transform(self, transforms, mode='train'):
  537. if mode == 'train':
  538. default_batch_transforms = [
  539. _BatchPadding(
  540. pad_to_stride=-1, pad_gt=False), _NormalizeBox(),
  541. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  542. _Gt2YoloTarget(
  543. anchor_masks=self.anchor_masks,
  544. anchors=self.anchors,
  545. downsample_ratios=getattr(self, 'downsample_ratios',
  546. [32, 16, 8]),
  547. num_classes=self.num_classes)
  548. ]
  549. else:
  550. default_batch_transforms = [
  551. _BatchPadding(
  552. pad_to_stride=-1, pad_gt=False)
  553. ]
  554. custom_batch_transforms = []
  555. for i, op in enumerate(transforms.transforms):
  556. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  557. if mode != 'train':
  558. raise Exception(
  559. "{} cannot be present in the {} transforms. ".format(
  560. op.__class__.__name__, mode) +
  561. "Please check the {} transforms.".format(mode))
  562. custom_batch_transforms.insert(0, copy.deepcopy(op))
  563. batch_transforms = BatchCompose(custom_batch_transforms +
  564. default_batch_transforms)
  565. return batch_transforms
  566. class FasterRCNN(BaseDetector):
  567. def __init__(self,
  568. num_classes=80,
  569. backbone='ResNet50',
  570. with_fpn=True,
  571. aspect_ratios=[0.5, 1.0, 2.0],
  572. anchor_sizes=[[32], [64], [128], [256], [512]],
  573. keep_top_k=100,
  574. nms_threshold=0.5,
  575. score_threshold=0.05,
  576. fpn_num_channels=256,
  577. rpn_batch_size_per_im=256,
  578. rpn_fg_fraction=0.5,
  579. test_pre_nms_top_n=None,
  580. test_post_nms_top_n=1000):
  581. self.init_params = locals()
  582. if backbone not in [
  583. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  584. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd'
  585. ]:
  586. raise ValueError(
  587. "backbone: {} is not supported. Please choose one of "
  588. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  589. "'ResNet101', 'ResNet101_vd')".format(backbone))
  590. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  591. if backbone == 'ResNet50_vd_ssld':
  592. if not with_fpn:
  593. logging.warning(
  594. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  595. format(backbone))
  596. with_fpn = True
  597. backbone = self._get_backbone(
  598. 'ResNet',
  599. variant='d',
  600. norm_type='bn',
  601. freeze_at=0,
  602. return_idx=[0, 1, 2, 3],
  603. num_stages=4,
  604. lr_mult_list=[0.05, 0.05, 0.1, 0.15])
  605. elif 'ResNet50' in backbone:
  606. if with_fpn:
  607. backbone = self._get_backbone(
  608. 'ResNet',
  609. variant='d' if '_vd' in backbone else 'b',
  610. norm_type='bn',
  611. freeze_at=0,
  612. return_idx=[0, 1, 2, 3],
  613. num_stages=4)
  614. else:
  615. backbone = self._get_backbone(
  616. 'ResNet',
  617. variant='d' if '_vd' in backbone else 'b',
  618. norm_type='bn',
  619. freeze_at=0,
  620. return_idx=[2],
  621. num_stages=3)
  622. elif 'ResNet34' in backbone:
  623. if not with_fpn:
  624. logging.warning(
  625. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  626. format(backbone))
  627. with_fpn = True
  628. backbone = self._get_backbone(
  629. 'ResNet',
  630. depth=34,
  631. variant='d' if 'vd' in backbone else 'b',
  632. norm_type='bn',
  633. freeze_at=0,
  634. return_idx=[0, 1, 2, 3],
  635. num_stages=4)
  636. else:
  637. if not with_fpn:
  638. logging.warning(
  639. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  640. format(backbone))
  641. with_fpn = True
  642. backbone = self._get_backbone(
  643. 'ResNet',
  644. depth=101,
  645. variant='d' if 'vd' in backbone else 'b',
  646. norm_type='bn',
  647. freeze_at=0,
  648. return_idx=[0, 1, 2, 3],
  649. num_stages=4)
  650. rpn_in_channel = backbone.out_shape[0].channels
  651. if with_fpn:
  652. neck = ppdet.modeling.FPN(
  653. in_channels=[i.channels for i in backbone.out_shape],
  654. out_channel=fpn_num_channels,
  655. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  656. rpn_in_channel = neck.out_shape[0].channels
  657. anchor_generator_cfg = {
  658. 'aspect_ratios': aspect_ratios,
  659. 'anchor_sizes': anchor_sizes,
  660. 'strides': [4, 8, 16, 32, 64]
  661. }
  662. train_proposal_cfg = {
  663. 'min_size': 0.0,
  664. 'nms_thresh': .7,
  665. 'pre_nms_top_n': 2000,
  666. 'post_nms_top_n': 1000,
  667. 'topk_after_collect': True
  668. }
  669. test_proposal_cfg = {
  670. 'min_size': 0.0,
  671. 'nms_thresh': .7,
  672. 'pre_nms_top_n': 1000
  673. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  674. 'post_nms_top_n': test_post_nms_top_n
  675. }
  676. head = ppdet.modeling.TwoFCHead(out_channel=1024)
  677. roi_extractor_cfg = {
  678. 'resolution': 7,
  679. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  680. 'sampling_ratio': 0,
  681. 'aligned': True
  682. }
  683. with_pool = False
  684. else:
  685. neck = None
  686. anchor_generator_cfg = {
  687. 'aspect_ratios': aspect_ratios,
  688. 'anchor_sizes': anchor_sizes,
  689. 'strides': [16]
  690. }
  691. train_proposal_cfg = {
  692. 'min_size': 0.0,
  693. 'nms_thresh': .7,
  694. 'pre_nms_top_n': 12000,
  695. 'post_nms_top_n': 2000,
  696. 'topk_after_collect': False
  697. }
  698. test_proposal_cfg = {
  699. 'min_size': 0.0,
  700. 'nms_thresh': .7,
  701. 'pre_nms_top_n': 6000
  702. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  703. 'post_nms_top_n': test_post_nms_top_n
  704. }
  705. head = ppdet.modeling.Res5Head()
  706. roi_extractor_cfg = {
  707. 'resolution': 14,
  708. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  709. 'sampling_ratio': 0,
  710. 'aligned': True
  711. }
  712. with_pool = True
  713. rpn_target_assign_cfg = {
  714. 'batch_size_per_im': rpn_batch_size_per_im,
  715. 'fg_fraction': rpn_fg_fraction,
  716. 'negative_overlap': .3,
  717. 'positive_overlap': .7,
  718. 'use_random': True
  719. }
  720. rpn_head = ppdet.modeling.RPNHead(
  721. anchor_generator=anchor_generator_cfg,
  722. rpn_target_assign=rpn_target_assign_cfg,
  723. train_proposal=train_proposal_cfg,
  724. test_proposal=test_proposal_cfg,
  725. in_channel=rpn_in_channel)
  726. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  727. bbox_head = ppdet.modeling.BBoxHead(
  728. head=head,
  729. in_channel=head.out_shape[0].channels,
  730. roi_extractor=roi_extractor_cfg,
  731. with_pool=with_pool,
  732. bbox_assigner=bbox_assigner,
  733. num_classes=num_classes)
  734. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  735. num_classes=num_classes,
  736. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  737. nms=ppdet.modeling.MultiClassNMS(
  738. score_threshold=score_threshold,
  739. keep_top_k=keep_top_k,
  740. nms_threshold=nms_threshold))
  741. params = {
  742. 'backbone': backbone,
  743. 'neck': neck,
  744. 'rpn_head': rpn_head,
  745. 'bbox_head': bbox_head,
  746. 'bbox_post_process': bbox_post_process
  747. }
  748. self.with_fpn = with_fpn
  749. super(FasterRCNN, self).__init__(
  750. model_name='FasterRCNN', num_classes=num_classes, **params)
  751. def _compose_batch_transform(self, transforms, mode='train'):
  752. if mode == 'train':
  753. default_batch_transforms = [
  754. _BatchPadding(
  755. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  756. ]
  757. else:
  758. default_batch_transforms = [
  759. _BatchPadding(
  760. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  761. ]
  762. custom_batch_transforms = []
  763. for i, op in enumerate(transforms.transforms):
  764. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  765. if mode != 'train':
  766. raise Exception(
  767. "{} cannot be present in the {} transforms. ".format(
  768. op.__class__.__name__, mode) +
  769. "Please check the {} transforms.".format(mode))
  770. custom_batch_transforms.insert(0, copy.deepcopy(op))
  771. batch_transforms = BatchCompose(custom_batch_transforms +
  772. default_batch_transforms)
  773. return batch_transforms
  774. class PPYOLO(YOLOv3):
  775. def __init__(self,
  776. num_classes=80,
  777. backbone='ResNet50_vd_dcn',
  778. anchors=None,
  779. anchor_masks=None,
  780. use_coord_conv=True,
  781. use_iou_aware=True,
  782. use_spp=True,
  783. use_drop_block=True,
  784. scale_x_y=1.05,
  785. ignore_threshold=0.7,
  786. label_smooth=False,
  787. use_iou_loss=True,
  788. use_matrix_nms=True,
  789. nms_score_threshold=0.01,
  790. nms_topk=-1,
  791. nms_keep_topk=100,
  792. nms_iou_threshold=0.45):
  793. self.init_params = locals()
  794. if backbone not in [
  795. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  796. 'MobileNetV3_small'
  797. ]:
  798. raise ValueError(
  799. "backbone: {} is not supported. Please choose one of "
  800. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  801. format(backbone))
  802. self.backbone_name = backbone
  803. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  804. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  805. norm_type = 'sync_bn'
  806. else:
  807. norm_type = 'bn'
  808. if anchors is None and anchor_masks is None:
  809. if 'MobileNetV3' in backbone:
  810. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  811. [120, 195], [254, 235]]
  812. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  813. elif backbone == 'ResNet50_vd_dcn':
  814. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  815. [59, 119], [116, 90], [156, 198], [373, 326]]
  816. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  817. else:
  818. anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
  819. [344, 319]]
  820. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  821. elif anchors is None or anchor_masks is None:
  822. raise ValueError("Please define both anchors and anchor_masks.")
  823. if backbone == 'ResNet50_vd_dcn':
  824. backbone = self._get_backbone(
  825. 'ResNet',
  826. variant='d',
  827. norm_type=norm_type,
  828. return_idx=[1, 2, 3],
  829. dcn_v2_stages=[3],
  830. freeze_at=-1,
  831. freeze_norm=False,
  832. norm_decay=0.)
  833. downsample_ratios = [32, 16, 8]
  834. elif backbone == 'ResNet18_vd':
  835. backbone = self._get_backbone(
  836. 'ResNet',
  837. depth=18,
  838. variant='d',
  839. norm_type=norm_type,
  840. return_idx=[2, 3],
  841. freeze_at=-1,
  842. freeze_norm=False,
  843. norm_decay=0.)
  844. downsample_ratios = [32, 16, 8]
  845. elif backbone == 'MobileNetV3_large':
  846. backbone = self._get_backbone(
  847. 'MobileNetV3',
  848. model_name='large',
  849. norm_type=norm_type,
  850. scale=1,
  851. with_extra_blocks=False,
  852. extra_block_filters=[],
  853. feature_maps=[13, 16])
  854. downsample_ratios = [32, 16]
  855. elif backbone == 'MobileNetV3_small':
  856. backbone = self._get_backbone(
  857. 'MobileNetV3',
  858. model_name='small',
  859. norm_type=norm_type,
  860. scale=1,
  861. with_extra_blocks=False,
  862. extra_block_filters=[],
  863. feature_maps=[9, 12])
  864. downsample_ratios = [32, 16]
  865. neck = ppdet.modeling.PPYOLOFPN(
  866. norm_type=norm_type,
  867. in_channels=[i.channels for i in backbone.out_shape],
  868. coord_conv=use_coord_conv,
  869. drop_block=use_drop_block,
  870. spp=use_spp,
  871. conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
  872. self.backbone_name == 'ResNet18_vd') else 2)
  873. loss = ppdet.modeling.YOLOv3Loss(
  874. num_classes=num_classes,
  875. ignore_thresh=ignore_threshold,
  876. downsample=downsample_ratios,
  877. label_smooth=label_smooth,
  878. scale_x_y=scale_x_y,
  879. iou_loss=ppdet.modeling.IouLoss(
  880. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  881. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  882. if use_iou_aware else None)
  883. yolo_head = ppdet.modeling.YOLOv3Head(
  884. in_channels=[i.channels for i in neck.out_shape],
  885. anchors=anchors,
  886. anchor_masks=anchor_masks,
  887. num_classes=num_classes,
  888. loss=loss,
  889. iou_aware=use_iou_aware)
  890. if use_matrix_nms:
  891. nms = ppdet.modeling.MatrixNMS(
  892. keep_top_k=nms_keep_topk,
  893. score_threshold=nms_score_threshold,
  894. post_threshold=.05
  895. if 'MobileNetV3' in self.backbone_name else .01,
  896. nms_top_k=nms_topk,
  897. background_label=-1)
  898. else:
  899. nms = ppdet.modeling.MultiClassNMS(
  900. score_threshold=nms_score_threshold,
  901. nms_top_k=nms_topk,
  902. keep_top_k=nms_keep_topk,
  903. nms_threshold=nms_iou_threshold)
  904. post_process = ppdet.modeling.BBoxPostProcess(
  905. decode=ppdet.modeling.YOLOBox(
  906. num_classes=num_classes,
  907. conf_thresh=.005
  908. if 'MobileNetV3' in self.backbone_name else .01,
  909. scale_x_y=scale_x_y),
  910. nms=nms)
  911. params = {
  912. 'backbone': backbone,
  913. 'neck': neck,
  914. 'yolo_head': yolo_head,
  915. 'post_process': post_process
  916. }
  917. super(YOLOv3, self).__init__(
  918. model_name='YOLOv3', num_classes=num_classes, **params)
  919. self.anchors = anchors
  920. self.anchor_masks = anchor_masks
  921. self.downsample_ratios = downsample_ratios
  922. self.model_name = 'PPYOLO'
class PPYOLOTiny(YOLOv3):
    """PP-YOLO Tiny detector with a fixed MobileNetV3 (large, scale 0.5) backbone.

    Subclasses ``YOLOv3`` but builds all modules itself and initializes
    through ``super(YOLOv3, self)``, skipping ``YOLOv3.__init__``.
    """

    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV3',
                 anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
                          [60, 170], [220, 125], [128, 222], [264, 266]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=False,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.5,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=False,
                 nms_score_threshold=0.005,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45):
        """Build a PP-YOLO Tiny model.

        Args:
            num_classes (int): Number of object classes.
            backbone (str): Ignored apart from a warning. The backbone is
                always MobileNetV3.
            anchors (list): Anchor sizes as 2-element pairs.
            anchor_masks (list): For each output level, the indices into
                ``anchors`` used at that level.
            use_iou_aware (bool): Enable the IoU-aware head branch and loss.
            use_spp (bool): Neck option ``spp``.
            use_drop_block (bool): Neck option ``drop_block``.
            scale_x_y (float): Passed to the loss and to box decoding.
            ignore_threshold (float): Loss ``ignore_thresh``.
            label_smooth (bool): Loss label smoothing.
            use_iou_loss (bool): Add an ``IouLoss`` term.
            use_matrix_nms (bool): Use ``MatrixNMS`` instead of
                ``MultiClassNMS``.
            nms_score_threshold (float): NMS score threshold.
            nms_topk (int): NMS ``nms_top_k``.
            nms_keep_topk (int): NMS ``keep_top_k``.
            nms_iou_threshold (float): IoU threshold, used only by
                ``MultiClassNMS``.
        """
        # NOTE(review): ``anchors`` / ``anchor_masks`` have list defaults that
        # are shared across calls and are stored as ``self.anchors`` /
        # ``self.anchor_masks``. Mutating them in place would change the
        # defaults for later instances.
        # Must stay the first statement: it snapshots only the constructor
        # arguments (plus ``self``).
        self.init_params = locals()
        if backbone != 'MobileNetV3':
            logging.warning(
                "PPYOLOTiny only supports MobileNetV3 as backbone. "
                "Backbone is forcibly set to MobileNetV3.")
        self.backbone_name = 'MobileNetV3'
        # Use synchronized BN for multi-GPU runs, except during export.
        if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
            norm_type = 'sync_bn'
        else:
            norm_type = 'bn'
        # Three feature maps are taken from the backbone, matching the three
        # strides and three anchor_masks groups used below.
        backbone = self._get_backbone(
            'MobileNetV3',
            model_name='large',
            norm_type=norm_type,
            scale=.5,
            with_extra_blocks=False,
            extra_block_filters=[],
            feature_maps=[7, 13, 16])
        downsample_ratios = [32, 16, 8]
        neck = ppdet.modeling.PPYOLOTinyFPN(
            detection_block_channels=[160, 128, 96],
            in_channels=[i.channels for i in backbone.out_shape],
            spp=use_spp,
            drop_block=use_drop_block)
        loss = ppdet.modeling.YOLOv3Loss(
            num_classes=num_classes,
            ignore_thresh=ignore_threshold,
            downsample=downsample_ratios,
            label_smooth=label_smooth,
            scale_x_y=scale_x_y,
            iou_loss=ppdet.modeling.IouLoss(
                loss_weight=2.5, loss_square=True) if use_iou_loss else None,
            iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
            if use_iou_aware else None)
        yolo_head = ppdet.modeling.YOLOv3Head(
            in_channels=[i.channels for i in neck.out_shape],
            anchors=anchors,
            anchor_masks=anchor_masks,
            num_classes=num_classes,
            loss=loss,
            iou_aware=use_iou_aware)
        if use_matrix_nms:
            # MatrixNMS takes no IoU threshold, so nms_iou_threshold is unused.
            nms = ppdet.modeling.MatrixNMS(
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                post_threshold=.05,
                nms_top_k=nms_topk,
                background_label=-1)
        else:
            nms = ppdet.modeling.MultiClassNMS(
                score_threshold=nms_score_threshold,
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                nms_threshold=nms_iou_threshold)
        post_process = ppdet.modeling.BBoxPostProcess(
            decode=ppdet.modeling.YOLOBox(
                num_classes=num_classes,
                conf_thresh=.005,
                downsample_ratio=32,
                clip_bbox=True,
                scale_x_y=scale_x_y),
            nms=nms)
        params = {
            'backbone': backbone,
            'neck': neck,
            'yolo_head': yolo_head,
            'post_process': post_process
        }
        # Skip YOLOv3.__init__ and initialize the next class in the MRO with
        # the modules built above.
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.num_max_boxes = 100
        self.model_name = 'PPYOLOTiny'
  1018. class PPYOLOv2(YOLOv3):
  1019. def __init__(self,
  1020. num_classes=80,
  1021. backbone='ResNet50_vd_dcn',
  1022. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1023. [59, 119], [116, 90], [156, 198], [373, 326]],
  1024. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1025. use_iou_aware=True,
  1026. use_spp=True,
  1027. use_drop_block=True,
  1028. scale_x_y=1.05,
  1029. ignore_threshold=0.7,
  1030. label_smooth=False,
  1031. use_iou_loss=True,
  1032. use_matrix_nms=True,
  1033. nms_score_threshold=0.01,
  1034. nms_topk=-1,
  1035. nms_keep_topk=100,
  1036. nms_iou_threshold=0.45):
  1037. self.init_params = locals()
  1038. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1039. raise ValueError(
  1040. "backbone: {} is not supported. Please choose one of "
  1041. "('ResNet50_vd_dcn', 'ResNet18_vd')".format(backbone))
  1042. self.backbone_name = backbone
  1043. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1044. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1045. norm_type = 'sync_bn'
  1046. else:
  1047. norm_type = 'bn'
  1048. if backbone == 'ResNet50_vd_dcn':
  1049. backbone = self._get_backbone(
  1050. 'ResNet',
  1051. variant='d',
  1052. norm_type=norm_type,
  1053. return_idx=[1, 2, 3],
  1054. dcn_v2_stages=[3],
  1055. freeze_at=-1,
  1056. freeze_norm=False,
  1057. norm_decay=0.)
  1058. downsample_ratios = [32, 16, 8]
  1059. elif backbone == 'ResNet101_vd_dcn':
  1060. backbone = self._get_backbone(
  1061. 'ResNet',
  1062. depth=101,
  1063. variant='d',
  1064. norm_type=norm_type,
  1065. return_idx=[1, 2, 3],
  1066. dcn_v2_stages=[3],
  1067. freeze_at=-1,
  1068. freeze_norm=False,
  1069. norm_decay=0.)
  1070. downsample_ratios = [32, 16, 8]
  1071. neck = ppdet.modeling.PPYOLOPAN(
  1072. norm_type=norm_type,
  1073. in_channels=[i.channels for i in backbone.out_shape],
  1074. drop_block=use_drop_block,
  1075. block_size=3,
  1076. keep_prob=.9,
  1077. spp=use_spp)
  1078. loss = ppdet.modeling.YOLOv3Loss(
  1079. num_classes=num_classes,
  1080. ignore_thresh=ignore_threshold,
  1081. downsample=downsample_ratios,
  1082. label_smooth=label_smooth,
  1083. scale_x_y=scale_x_y,
  1084. iou_loss=ppdet.modeling.IouLoss(
  1085. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1086. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1087. if use_iou_aware else None)
  1088. yolo_head = ppdet.modeling.YOLOv3Head(
  1089. in_channels=[i.channels for i in neck.out_shape],
  1090. anchors=anchors,
  1091. anchor_masks=anchor_masks,
  1092. num_classes=num_classes,
  1093. loss=loss,
  1094. iou_aware=use_iou_aware,
  1095. iou_aware_factor=.5)
  1096. if use_matrix_nms:
  1097. nms = ppdet.modeling.MatrixNMS(
  1098. keep_top_k=nms_keep_topk,
  1099. score_threshold=nms_score_threshold,
  1100. post_threshold=.01,
  1101. nms_top_k=nms_topk,
  1102. background_label=-1)
  1103. else:
  1104. nms = ppdet.modeling.MultiClassNMS(
  1105. score_threshold=nms_score_threshold,
  1106. nms_top_k=nms_topk,
  1107. keep_top_k=nms_keep_topk,
  1108. nms_threshold=nms_iou_threshold)
  1109. post_process = ppdet.modeling.BBoxPostProcess(
  1110. decode=ppdet.modeling.YOLOBox(
  1111. num_classes=num_classes,
  1112. conf_thresh=.01,
  1113. downsample_ratio=32,
  1114. clip_bbox=True,
  1115. scale_x_y=scale_x_y),
  1116. nms=nms)
  1117. params = {
  1118. 'backbone': backbone,
  1119. 'neck': neck,
  1120. 'yolo_head': yolo_head,
  1121. 'post_process': post_process
  1122. }
  1123. super(YOLOv3, self).__init__(
  1124. model_name='YOLOv3', num_classes=num_classes, **params)
  1125. self.anchors = anchors
  1126. self.anchor_masks = anchor_masks
  1127. self.downsample_ratios = downsample_ratios
  1128. self.num_max_boxes = 100
  1129. self.model_name = 'PPYOLOv2'
  1130. class MaskRCNN(BaseDetector):
  1131. def __init__(self,
  1132. num_classes=80,
  1133. backbone='ResNet50_vd',
  1134. with_fpn=True,
  1135. aspect_ratios=[0.5, 1.0, 2.0],
  1136. anchor_sizes=[[32], [64], [128], [256], [512]],
  1137. keep_top_k=100,
  1138. nms_threshold=0.5,
  1139. score_threshold=0.05,
  1140. fpn_num_channels=256,
  1141. rpn_batch_size_per_im=256,
  1142. rpn_fg_fraction=0.5,
  1143. test_pre_nms_top_n=None,
  1144. test_post_nms_top_n=1000):
  1145. self.init_params = locals()
  1146. if backbone not in [
  1147. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1148. 'ResNet101_vd'
  1149. ]:
  1150. raise ValueError(
  1151. "backbone: {} is not supported. Please choose one of "
  1152. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1153. format(backbone))
  1154. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1155. if backbone == 'ResNet50':
  1156. if with_fpn:
  1157. backbone = self._get_backbone(
  1158. 'ResNet',
  1159. norm_type='bn',
  1160. freeze_at=0,
  1161. return_idx=[0, 1, 2, 3],
  1162. num_stages=4)
  1163. else:
  1164. backbone = self._get_backbone(
  1165. 'ResNet',
  1166. norm_type='bn',
  1167. freeze_at=0,
  1168. return_idx=[2],
  1169. num_stages=3)
  1170. elif 'ResNet50_vd' in backbone:
  1171. if not with_fpn:
  1172. logging.warning(
  1173. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1174. format(backbone))
  1175. with_fpn = True
  1176. backbone = self._get_backbone(
  1177. 'ResNet',
  1178. variant='d',
  1179. norm_type='bn',
  1180. freeze_at=0,
  1181. return_idx=[0, 1, 2, 3],
  1182. num_stages=4,
  1183. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1184. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
  1185. else:
  1186. if not with_fpn:
  1187. logging.warning(
  1188. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1189. format(backbone))
  1190. with_fpn = True
  1191. backbone = self._get_backbone(
  1192. 'ResNet',
  1193. variant='d' if '_vd' in backbone else 'b',
  1194. depth=101,
  1195. norm_type='bn',
  1196. freeze_at=0,
  1197. return_idx=[0, 1, 2, 3],
  1198. num_stages=4)
  1199. rpn_in_channel = backbone.out_shape[0].channels
  1200. if with_fpn:
  1201. neck = ppdet.modeling.FPN(
  1202. in_channels=[i.channels for i in backbone.out_shape],
  1203. out_channel=fpn_num_channels,
  1204. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  1205. rpn_in_channel = neck.out_shape[0].channels
  1206. anchor_generator_cfg = {
  1207. 'aspect_ratios': aspect_ratios,
  1208. 'anchor_sizes': anchor_sizes,
  1209. 'strides': [4, 8, 16, 32, 64]
  1210. }
  1211. train_proposal_cfg = {
  1212. 'min_size': 0.0,
  1213. 'nms_thresh': .7,
  1214. 'pre_nms_top_n': 2000,
  1215. 'post_nms_top_n': 1000,
  1216. 'topk_after_collect': True
  1217. }
  1218. test_proposal_cfg = {
  1219. 'min_size': 0.0,
  1220. 'nms_thresh': .7,
  1221. 'pre_nms_top_n': 1000
  1222. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1223. 'post_nms_top_n': test_post_nms_top_n
  1224. }
  1225. bb_head = ppdet.modeling.TwoFCHead(
  1226. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1227. bb_roi_extractor_cfg = {
  1228. 'resolution': 7,
  1229. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1230. 'sampling_ratio': 0,
  1231. 'aligned': True
  1232. }
  1233. with_pool = False
  1234. m_head = ppdet.modeling.MaskFeat(
  1235. in_channel=neck.out_shape[0].channels,
  1236. out_channel=256,
  1237. num_convs=4)
  1238. m_roi_extractor_cfg = {
  1239. 'resolution': 14,
  1240. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1241. 'sampling_ratio': 0,
  1242. 'aligned': True
  1243. }
  1244. mask_assigner = MaskAssigner(
  1245. num_classes=num_classes, mask_resolution=28)
  1246. share_bbox_feat = False
  1247. else:
  1248. neck = None
  1249. anchor_generator_cfg = {
  1250. 'aspect_ratios': aspect_ratios,
  1251. 'anchor_sizes': anchor_sizes,
  1252. 'strides': [16]
  1253. }
  1254. train_proposal_cfg = {
  1255. 'min_size': 0.0,
  1256. 'nms_thresh': .7,
  1257. 'pre_nms_top_n': 12000,
  1258. 'post_nms_top_n': 2000,
  1259. 'topk_after_collect': False
  1260. }
  1261. test_proposal_cfg = {
  1262. 'min_size': 0.0,
  1263. 'nms_thresh': .7,
  1264. 'pre_nms_top_n': 6000
  1265. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1266. 'post_nms_top_n': test_post_nms_top_n
  1267. }
  1268. bb_head = ppdet.modeling.Res5Head()
  1269. bb_roi_extractor_cfg = {
  1270. 'resolution': 14,
  1271. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1272. 'sampling_ratio': 0,
  1273. 'aligned': True
  1274. }
  1275. with_pool = True
  1276. m_head = ppdet.modeling.MaskFeat(
  1277. in_channel=bb_head.out_shape[0].channels,
  1278. out_channel=256,
  1279. num_convs=0)
  1280. m_roi_extractor_cfg = {
  1281. 'resolution': 14,
  1282. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1283. 'sampling_ratio': 0,
  1284. 'aligned': True
  1285. }
  1286. mask_assigner = MaskAssigner(
  1287. num_classes=num_classes, mask_resolution=14)
  1288. share_bbox_feat = True
  1289. rpn_target_assign_cfg = {
  1290. 'batch_size_per_im': rpn_batch_size_per_im,
  1291. 'fg_fraction': rpn_fg_fraction,
  1292. 'negative_overlap': .3,
  1293. 'positive_overlap': .7,
  1294. 'use_random': True
  1295. }
  1296. rpn_head = ppdet.modeling.RPNHead(
  1297. anchor_generator=anchor_generator_cfg,
  1298. rpn_target_assign=rpn_target_assign_cfg,
  1299. train_proposal=train_proposal_cfg,
  1300. test_proposal=test_proposal_cfg,
  1301. in_channel=rpn_in_channel)
  1302. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1303. bbox_head = ppdet.modeling.BBoxHead(
  1304. head=bb_head,
  1305. in_channel=bb_head.out_shape[0].channels,
  1306. roi_extractor=bb_roi_extractor_cfg,
  1307. with_pool=with_pool,
  1308. bbox_assigner=bbox_assigner,
  1309. num_classes=num_classes)
  1310. mask_head = ppdet.modeling.MaskHead(
  1311. head=m_head,
  1312. roi_extractor=m_roi_extractor_cfg,
  1313. mask_assigner=mask_assigner,
  1314. share_bbox_feat=share_bbox_feat,
  1315. num_classes=num_classes)
  1316. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1317. num_classes=num_classes,
  1318. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1319. nms=ppdet.modeling.MultiClassNMS(
  1320. score_threshold=score_threshold,
  1321. keep_top_k=keep_top_k,
  1322. nms_threshold=nms_threshold))
  1323. mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
  1324. params = {
  1325. 'backbone': backbone,
  1326. 'neck': neck,
  1327. 'rpn_head': rpn_head,
  1328. 'bbox_head': bbox_head,
  1329. 'mask_head': mask_head,
  1330. 'bbox_post_process': bbox_post_process,
  1331. 'mask_post_process': mask_post_process
  1332. }
  1333. self.with_fpn = with_fpn
  1334. super(MaskRCNN, self).__init__(
  1335. model_name='MaskRCNN', num_classes=num_classes, **params)
  1336. def _compose_batch_transform(self, transforms, mode='train'):
  1337. if mode == 'train':
  1338. default_batch_transforms = [
  1339. _BatchPadding(
  1340. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  1341. ]
  1342. else:
  1343. default_batch_transforms = [
  1344. _BatchPadding(
  1345. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  1346. ]
  1347. custom_batch_transforms = []
  1348. for i, op in enumerate(transforms.transforms):
  1349. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1350. if mode != 'train':
  1351. raise Exception(
  1352. "{} cannot be present in the {} transforms. ".format(
  1353. op.__class__.__name__, mode) +
  1354. "Please check the {} transforms.".format(mode))
  1355. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1356. batch_transforms = BatchCompose(custom_batch_transforms +
  1357. default_batch_transforms)
  1358. return batch_transforms