detector.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import ppdet
  24. from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. super(BaseDetector, self).__init__('detector')
  41. if not hasattr(ppdet.modeling, model_name):
  42. raise Exception("ERROR: There's no model named {}.".format(
  43. model_name))
  44. self.model_name = model_name
  45. self.num_classes = num_classes
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. def build_net(self, **params):
  49. with paddle.utils.unique_name.guard():
  50. net = ppdet.modeling.__dict__[self.model_name](**params)
  51. return net
  52. def _fix_transforms_shape(self, image_shape):
  53. pass
  54. def _get_test_inputs(self, image_shape):
  55. if image_shape is not None:
  56. if len(image_shape) == 2:
  57. image_shape = [None, 3] + image_shape
  58. if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
  59. raise Exception(
  60. "Height and width in fixed_input_shape must be a multiple of 32, but recieved is {}.".
  61. format(image_shape[-2:]))
  62. self._fix_transforms_shape(image_shape[-2:])
  63. else:
  64. image_shape = [None, 3, -1, -1]
  65. input_spec = [{
  66. "image": InputSpec(
  67. shape=image_shape, name='image', dtype='float32'),
  68. "im_shape": InputSpec(
  69. shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
  70. "scale_factor": InputSpec(
  71. shape=[image_shape[0], 2],
  72. name='scale_factor',
  73. dtype='float32')
  74. }]
  75. return input_spec
  76. def _get_backbone(self, backbone_name, **params):
  77. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  78. return backbone
  79. def run(self, net, inputs, mode):
  80. net_out = net(inputs)
  81. if mode in ['train', 'eval']:
  82. outputs = net_out
  83. else:
  84. for key in ['im_shape', 'scale_factor']:
  85. net_out[key] = inputs[key]
  86. outputs = dict()
  87. for key in net_out:
  88. outputs[key] = net_out[key].numpy()
  89. return outputs
  90. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  91. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  92. num_steps_each_epoch):
  93. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  94. values = [(lr_decay_gamma**i) * learning_rate
  95. for i in range(len(lr_decay_epochs) + 1)]
  96. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  97. boundaries=boundaries, values=values)
  98. if warmup_steps > 0:
  99. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  100. logging.error(
  101. "In function train(), parameters should satisfy: "
  102. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  103. exit=False)
  104. logging.error(
  105. "See this doc for more information: "
  106. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  107. exit=False)
  108. scheduler = paddle.optimizer.lr.LinearWarmup(
  109. learning_rate=scheduler,
  110. warmup_steps=warmup_steps,
  111. start_lr=warmup_start_lr,
  112. end_lr=learning_rate)
  113. optimizer = paddle.optimizer.Momentum(
  114. scheduler,
  115. momentum=.9,
  116. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  117. parameters=parameters)
  118. return optimizer
  119. def train(self,
  120. num_epochs,
  121. train_dataset,
  122. train_batch_size=64,
  123. eval_dataset=None,
  124. optimizer=None,
  125. save_interval_epochs=1,
  126. log_interval_steps=10,
  127. save_dir='output',
  128. pretrain_weights='IMAGENET',
  129. learning_rate=.001,
  130. warmup_steps=0,
  131. warmup_start_lr=0.0,
  132. lr_decay_epochs=(216, 243),
  133. lr_decay_gamma=0.1,
  134. metric=None,
  135. use_ema=False,
  136. early_stop=False,
  137. early_stop_patience=5,
  138. use_vdl=True):
  139. """
  140. Train the model.
  141. Args:
  142. num_epochs(int): The number of epochs.
  143. train_dataset(paddlex.dataset): Training dataset.
  144. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  145. eval_dataset(paddlex.dataset, optional):
  146. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  147. optimizer(paddle.optimizer.Optimizer or None, optional):
  148. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  149. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  150. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  151. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  152. pretrain_weights(str or None, optional):
  153. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  154. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  155. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  156. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  157. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  158. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  159. metric({'VOC', 'COCO', None}, optional):
  160. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  161. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  162. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  163. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  164. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  165. """
  166. if train_dataset.__class__.__name__ == 'VOCDetection':
  167. train_dataset.data_fields = {
  168. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  169. 'difficult'
  170. }
  171. elif train_dataset.__class__.__name__ == 'CocoDetection':
  172. if self.__class__.__name__ == 'MaskRCNN':
  173. train_dataset.data_fields = {
  174. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  175. 'gt_poly', 'is_crowd'
  176. }
  177. else:
  178. train_dataset.data_fields = {
  179. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  180. 'is_crowd'
  181. }
  182. if metric is None:
  183. if eval_dataset.__class__.__name__ == 'VOCDetection':
  184. self.metric = 'voc'
  185. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  186. self.metric = 'coco'
  187. else:
  188. assert metric.lower() in ['coco', 'voc'], \
  189. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  190. self.metric = metric.lower()
  191. train_dataset.batch_transforms = self._compose_batch_transform(
  192. train_dataset.transforms, mode='train')
  193. self.labels = train_dataset.labels
  194. # build optimizer if not defined
  195. if optimizer is None:
  196. num_steps_each_epoch = len(train_dataset) // train_batch_size
  197. self.optimizer = self.default_optimizer(
  198. parameters=self.net.parameters(),
  199. learning_rate=learning_rate,
  200. warmup_steps=warmup_steps,
  201. warmup_start_lr=warmup_start_lr,
  202. lr_decay_epochs=lr_decay_epochs,
  203. lr_decay_gamma=lr_decay_gamma,
  204. num_steps_each_epoch=num_steps_each_epoch)
  205. else:
  206. self.optimizer = optimizer
  207. # initiate weights
  208. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  209. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  210. [self.model_name, self.backbone_name])]:
  211. logging.warning(
  212. "Path of pretrain_weights('{}') does not exist!".format(
  213. pretrain_weights))
  214. pretrain_weights = det_pretrain_weights_dict['_'.join(
  215. [self.model_name, self.backbone_name])][0]
  216. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  217. "If don't want to use pretrain weights, "
  218. "set pretrain_weights to be None.".format(
  219. pretrain_weights))
  220. pretrained_dir = osp.join(save_dir, 'pretrain')
  221. self.net_initialize(
  222. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  223. if use_ema:
  224. ema = ExponentialMovingAverage(
  225. decay=.9998, model=self.net, use_thres_step=True)
  226. else:
  227. ema = None
  228. # start train loop
  229. self.train_loop(
  230. num_epochs=num_epochs,
  231. train_dataset=train_dataset,
  232. train_batch_size=train_batch_size,
  233. eval_dataset=eval_dataset,
  234. save_interval_epochs=save_interval_epochs,
  235. log_interval_steps=log_interval_steps,
  236. save_dir=save_dir,
  237. ema=ema,
  238. early_stop=early_stop,
  239. early_stop_patience=early_stop_patience,
  240. use_vdl=use_vdl)
  241. def evaluate(self,
  242. eval_dataset,
  243. batch_size=1,
  244. metric=None,
  245. return_details=False):
  246. """
  247. Evaluate the model.
  248. Args:
  249. eval_dataset(paddlex.dataset): Evaluation dataset.
  250. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  251. metric({'VOC', 'COCO', None}, optional):
  252. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  253. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  254. Returns:
  255. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  256. """
  257. if eval_dataset.__class__.__name__ == 'VOCDetection':
  258. eval_dataset.data_fields = {
  259. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  260. 'difficult'
  261. }
  262. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  263. if self.__class__.__name__ == 'MaskRCNN':
  264. eval_dataset.data_fields = {
  265. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  266. 'gt_poly', 'is_crowd'
  267. }
  268. else:
  269. eval_dataset.data_fields = {
  270. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  271. 'is_crowd'
  272. }
  273. eval_dataset.batch_transforms = self._compose_batch_transform(
  274. eval_dataset.transforms, mode='eval')
  275. arrange_transforms(
  276. model_type=self.model_type,
  277. transforms=eval_dataset.transforms,
  278. mode='eval')
  279. self.net.eval()
  280. nranks = paddle.distributed.get_world_size()
  281. local_rank = paddle.distributed.get_rank()
  282. if nranks > 1:
  283. # Initialize parallel environment if not done.
  284. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  285. ):
  286. paddle.distributed.init_parallel_env()
  287. if batch_size > 1:
  288. logging.warning(
  289. "Detector only supports single card evaluation with batch_size=1 "
  290. "during evaluation, so batch_size is forcibly set to 1.")
  291. batch_size = 1
  292. if nranks < 2 or local_rank == 0:
  293. self.eval_data_loader = self.build_data_loader(
  294. eval_dataset, batch_size=batch_size, mode='eval')
  295. is_bbox_normalized = False
  296. if eval_dataset.batch_transforms is not None:
  297. is_bbox_normalized = any(
  298. isinstance(t, _NormalizeBox)
  299. for t in eval_dataset.batch_transforms.batch_transforms)
  300. if metric is None:
  301. if getattr(self, 'metric', None) is not None:
  302. if self.metric == 'voc':
  303. eval_metric = VOCMetric(
  304. labels=eval_dataset.labels,
  305. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  306. is_bbox_normalized=is_bbox_normalized,
  307. classwise=False)
  308. else:
  309. eval_metric = COCOMetric(
  310. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  311. classwise=False)
  312. else:
  313. if eval_dataset.__class__.__name__ == 'VOCDetection':
  314. eval_metric = VOCMetric(
  315. labels=eval_dataset.labels,
  316. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  317. is_bbox_normalized=is_bbox_normalized,
  318. classwise=False)
  319. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  320. eval_metric = COCOMetric(
  321. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  322. classwise=False)
  323. else:
  324. assert metric.lower() in ['coco', 'voc'], \
  325. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  326. if metric.lower() == 'coco':
  327. eval_metric = COCOMetric(
  328. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  329. classwise=False)
  330. else:
  331. eval_metric = VOCMetric(
  332. labels=eval_dataset.labels,
  333. is_bbox_normalized=is_bbox_normalized,
  334. classwise=False)
  335. scores = collections.OrderedDict()
  336. logging.info(
  337. "Start to evaluate(total_samples={}, total_steps={})...".
  338. format(eval_dataset.num_samples, eval_dataset.num_samples))
  339. with paddle.no_grad():
  340. for step, data in enumerate(self.eval_data_loader):
  341. outputs = self.run(self.net, data, 'eval')
  342. eval_metric.update(data, outputs)
  343. eval_metric.accumulate()
  344. self.eval_details = eval_metric.details
  345. scores.update(eval_metric.get())
  346. eval_metric.reset()
  347. if return_details:
  348. return scores, self.eval_details
  349. return scores
  350. def predict(self, img_file, transforms=None):
  351. """
  352. Do inference.
  353. Args:
  354. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  355. Image path or decoded image data in a BGR format, which also could constitute a list,
  356. meaning all images to be predicted as a mini-batch.
  357. transforms(paddlex.transforms.Compose or None, optional):
  358. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  359. Returns:
  360. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  361. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  362. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  363. category_id(int): the predicted category ID
  364. category(str): category name
  365. bbox(list): bounding box in [x, y, w, h] format
  366. score(str): confidence
  367. """
  368. if transforms is None and not hasattr(self, 'test_transforms'):
  369. raise Exception("transforms need to be defined, now is None.")
  370. if transforms is None:
  371. transforms = self.test_transforms
  372. if isinstance(img_file, (str, np.ndarray)):
  373. images = [img_file]
  374. else:
  375. images = img_file
  376. batch_samples = self._preprocess(images, transforms)
  377. self.net.eval()
  378. outputs = self.run(self.net, batch_samples, 'test')
  379. prediction = self._postprocess(outputs)
  380. if isinstance(img_file, (str, np.ndarray)):
  381. prediction = prediction[0]
  382. return prediction
  383. def _preprocess(self, images, transforms):
  384. arrange_transforms(
  385. model_type=self.model_type, transforms=transforms, mode='test')
  386. batch_samples = list()
  387. for im in images:
  388. sample = {'image': im}
  389. batch_samples.append(transforms(sample))
  390. batch_transforms = self._compose_batch_transform(transforms, 'test')
  391. batch_samples = batch_transforms(batch_samples)
  392. for k, v in batch_samples.items():
  393. batch_samples[k] = paddle.to_tensor(v)
  394. return batch_samples
  395. def _postprocess(self, batch_pred):
  396. infer_result = {}
  397. if 'bbox' in batch_pred:
  398. bboxes = batch_pred['bbox']
  399. bbox_nums = batch_pred['bbox_num']
  400. det_res = []
  401. k = 0
  402. for i in range(len(bbox_nums)):
  403. det_nums = bbox_nums[i]
  404. for j in range(det_nums):
  405. dt = bboxes[k]
  406. k = k + 1
  407. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  408. if int(num_id) < 0:
  409. continue
  410. category = self.labels[int(num_id)]
  411. w = xmax - xmin
  412. h = ymax - ymin
  413. bbox = [xmin, ymin, w, h]
  414. dt_res = {
  415. 'category_id': int(num_id),
  416. 'category': category,
  417. 'bbox': bbox,
  418. 'score': score
  419. }
  420. det_res.append(dt_res)
  421. infer_result['bbox'] = det_res
  422. if 'mask' in batch_pred:
  423. masks = batch_pred['mask']
  424. bboxes = batch_pred['bbox']
  425. mask_nums = batch_pred['bbox_num']
  426. seg_res = []
  427. k = 0
  428. for i in range(len(mask_nums)):
  429. det_nums = mask_nums[i]
  430. for j in range(det_nums):
  431. mask = masks[k].astype(np.uint8)
  432. score = float(bboxes[k][1])
  433. label = int(bboxes[k][0])
  434. k = k + 1
  435. if label == -1:
  436. continue
  437. category = self.labels[int(label)]
  438. import pycocotools.mask as mask_util
  439. rle = mask_util.encode(
  440. np.array(
  441. mask[:, :, None], order="F", dtype="uint8"))[0]
  442. if six.PY3:
  443. if 'counts' in rle:
  444. rle['counts'] = rle['counts'].decode("utf8")
  445. sg_res = {
  446. 'category': category,
  447. 'segmentation': rle,
  448. 'score': score
  449. }
  450. seg_res.append(sg_res)
  451. infer_result['mask'] = seg_res
  452. bbox_num = batch_pred['bbox_num']
  453. results = []
  454. start = 0
  455. for num in bbox_num:
  456. end = start + num
  457. curr_res = infer_result['bbox'][start:end]
  458. if 'mask' in infer_result:
  459. mask_res = infer_result['mask'][start:end]
  460. for box, mask in zip(curr_res, mask_res):
  461. box.update(mask)
  462. results.append(curr_res)
  463. start = end
  464. return results
  465. class YOLOv3(BaseDetector):
  466. def __init__(self,
  467. num_classes=80,
  468. backbone='MobileNetV1',
  469. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  470. [59, 119], [116, 90], [156, 198], [373, 326]],
  471. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  472. ignore_threshold=0.7,
  473. nms_score_threshold=0.01,
  474. nms_topk=1000,
  475. nms_keep_topk=100,
  476. nms_iou_threshold=0.45,
  477. label_smooth=False):
  478. self.init_params = locals()
  479. if backbone not in [
  480. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  481. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  482. ]:
  483. raise ValueError(
  484. "backbone: {} is not supported. Please choose one of "
  485. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  486. format(backbone))
  487. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  488. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  489. norm_type = 'sync_bn'
  490. else:
  491. norm_type = 'bn'
  492. self.backbone_name = backbone
  493. if 'MobileNetV1' in backbone:
  494. norm_type = 'bn'
  495. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  496. elif 'MobileNetV3' in backbone:
  497. backbone = self._get_backbone(
  498. 'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
  499. elif backbone == 'ResNet50_vd_dcn':
  500. backbone = self._get_backbone(
  501. 'ResNet',
  502. norm_type=norm_type,
  503. variant='d',
  504. return_idx=[1, 2, 3],
  505. dcn_v2_stages=[3],
  506. freeze_at=-1,
  507. freeze_norm=False)
  508. elif backbone == 'ResNet34':
  509. backbone = self._get_backbone(
  510. 'ResNet',
  511. depth=34,
  512. norm_type=norm_type,
  513. return_idx=[1, 2, 3],
  514. freeze_at=-1,
  515. freeze_norm=False,
  516. norm_decay=0.)
  517. else:
  518. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  519. neck = ppdet.modeling.YOLOv3FPN(
  520. norm_type=norm_type,
  521. in_channels=[i.channels for i in backbone.out_shape])
  522. loss = ppdet.modeling.YOLOv3Loss(
  523. num_classes=num_classes,
  524. ignore_thresh=ignore_threshold,
  525. label_smooth=label_smooth)
  526. yolo_head = ppdet.modeling.YOLOv3Head(
  527. in_channels=[i.channels for i in neck.out_shape],
  528. anchors=anchors,
  529. anchor_masks=anchor_masks,
  530. num_classes=num_classes,
  531. loss=loss)
  532. post_process = ppdet.modeling.BBoxPostProcess(
  533. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  534. nms=ppdet.modeling.MultiClassNMS(
  535. score_threshold=nms_score_threshold,
  536. nms_top_k=nms_topk,
  537. keep_top_k=nms_keep_topk,
  538. nms_threshold=nms_iou_threshold))
  539. params = {
  540. 'backbone': backbone,
  541. 'neck': neck,
  542. 'yolo_head': yolo_head,
  543. 'post_process': post_process
  544. }
  545. super(YOLOv3, self).__init__(
  546. model_name='YOLOv3', num_classes=num_classes, **params)
  547. self.anchors = anchors
  548. self.anchor_masks = anchor_masks
  549. def _compose_batch_transform(self, transforms, mode='train'):
  550. if mode == 'train':
  551. default_batch_transforms = [
  552. _BatchPadding(
  553. pad_to_stride=-1, pad_gt=False), _NormalizeBox(),
  554. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  555. _Gt2YoloTarget(
  556. anchor_masks=self.anchor_masks,
  557. anchors=self.anchors,
  558. downsample_ratios=getattr(self, 'downsample_ratios',
  559. [32, 16, 8]),
  560. num_classes=self.num_classes)
  561. ]
  562. else:
  563. default_batch_transforms = [
  564. _BatchPadding(
  565. pad_to_stride=-1, pad_gt=False)
  566. ]
  567. custom_batch_transforms = []
  568. for i, op in enumerate(transforms.transforms):
  569. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  570. if mode != 'train':
  571. raise Exception(
  572. "{} cannot be present in the {} transforms. ".format(
  573. op.__class__.__name__, mode) +
  574. "Please check the {} transforms.".format(mode))
  575. custom_batch_transforms.insert(0, copy.deepcopy(op))
  576. batch_transforms = BatchCompose(custom_batch_transforms +
  577. default_batch_transforms)
  578. return batch_transforms
  579. def _fix_transforms_shape(self, image_shape):
  580. if hasattr(self, 'test_transforms'):
  581. if self.test_transforms is not None:
  582. has_resize_op = False
  583. resize_op_idx = -1
  584. normalize_op_idx = len(self.test_transforms.transforms)
  585. for idx, op in enumerate(self.test_transforms.transforms):
  586. name = op.__class__.__name__
  587. if name == 'Resize':
  588. has_resize_op = True
  589. resize_op_idx = idx
  590. if name == 'Normalize':
  591. normalize_op_idx = idx
  592. if not has_resize_op:
  593. self.test_transforms.transforms.insert(
  594. normalize_op_idx,
  595. Resize(
  596. target_size=image_shape, interp='CUBIC'))
  597. else:
  598. self.test_transforms.transforms[
  599. resize_op_idx].target_size = image_shape
  600. class FasterRCNN(BaseDetector):
  601. def __init__(self,
  602. num_classes=80,
  603. backbone='ResNet50',
  604. with_fpn=True,
  605. aspect_ratios=[0.5, 1.0, 2.0],
  606. anchor_sizes=[[32], [64], [128], [256], [512]],
  607. keep_top_k=100,
  608. nms_threshold=0.5,
  609. score_threshold=0.05,
  610. fpn_num_channels=256,
  611. rpn_batch_size_per_im=256,
  612. rpn_fg_fraction=0.5,
  613. test_pre_nms_top_n=None,
  614. test_post_nms_top_n=1000):
  615. self.init_params = locals()
  616. if backbone not in [
  617. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  618. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd'
  619. ]:
  620. raise ValueError(
  621. "backbone: {} is not supported. Please choose one of "
  622. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  623. "'ResNet101', 'ResNet101_vd')".format(backbone))
  624. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  625. if backbone == 'ResNet50_vd_ssld':
  626. if not with_fpn:
  627. logging.warning(
  628. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  629. format(backbone))
  630. with_fpn = True
  631. backbone = self._get_backbone(
  632. 'ResNet',
  633. variant='d',
  634. norm_type='bn',
  635. freeze_at=0,
  636. return_idx=[0, 1, 2, 3],
  637. num_stages=4,
  638. lr_mult_list=[0.05, 0.05, 0.1, 0.15])
  639. elif 'ResNet50' in backbone:
  640. if with_fpn:
  641. backbone = self._get_backbone(
  642. 'ResNet',
  643. variant='d' if '_vd' in backbone else 'b',
  644. norm_type='bn',
  645. freeze_at=0,
  646. return_idx=[0, 1, 2, 3],
  647. num_stages=4)
  648. else:
  649. backbone = self._get_backbone(
  650. 'ResNet',
  651. variant='d' if '_vd' in backbone else 'b',
  652. norm_type='bn',
  653. freeze_at=0,
  654. return_idx=[2],
  655. num_stages=3)
  656. elif 'ResNet34' in backbone:
  657. if not with_fpn:
  658. logging.warning(
  659. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  660. format(backbone))
  661. with_fpn = True
  662. backbone = self._get_backbone(
  663. 'ResNet',
  664. depth=34,
  665. variant='d' if 'vd' in backbone else 'b',
  666. norm_type='bn',
  667. freeze_at=0,
  668. return_idx=[0, 1, 2, 3],
  669. num_stages=4)
  670. else:
  671. if not with_fpn:
  672. logging.warning(
  673. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  674. format(backbone))
  675. with_fpn = True
  676. backbone = self._get_backbone(
  677. 'ResNet',
  678. depth=101,
  679. variant='d' if 'vd' in backbone else 'b',
  680. norm_type='bn',
  681. freeze_at=0,
  682. return_idx=[0, 1, 2, 3],
  683. num_stages=4)
  684. rpn_in_channel = backbone.out_shape[0].channels
  685. if with_fpn:
  686. neck = ppdet.modeling.FPN(
  687. in_channels=[i.channels for i in backbone.out_shape],
  688. out_channel=fpn_num_channels,
  689. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  690. rpn_in_channel = neck.out_shape[0].channels
  691. anchor_generator_cfg = {
  692. 'aspect_ratios': aspect_ratios,
  693. 'anchor_sizes': anchor_sizes,
  694. 'strides': [4, 8, 16, 32, 64]
  695. }
  696. train_proposal_cfg = {
  697. 'min_size': 0.0,
  698. 'nms_thresh': .7,
  699. 'pre_nms_top_n': 2000,
  700. 'post_nms_top_n': 1000,
  701. 'topk_after_collect': True
  702. }
  703. test_proposal_cfg = {
  704. 'min_size': 0.0,
  705. 'nms_thresh': .7,
  706. 'pre_nms_top_n': 1000
  707. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  708. 'post_nms_top_n': test_post_nms_top_n
  709. }
  710. head = ppdet.modeling.TwoFCHead(out_channel=1024)
  711. roi_extractor_cfg = {
  712. 'resolution': 7,
  713. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  714. 'sampling_ratio': 0,
  715. 'aligned': True
  716. }
  717. with_pool = False
  718. else:
  719. neck = None
  720. anchor_generator_cfg = {
  721. 'aspect_ratios': aspect_ratios,
  722. 'anchor_sizes': anchor_sizes,
  723. 'strides': [16]
  724. }
  725. train_proposal_cfg = {
  726. 'min_size': 0.0,
  727. 'nms_thresh': .7,
  728. 'pre_nms_top_n': 12000,
  729. 'post_nms_top_n': 2000,
  730. 'topk_after_collect': False
  731. }
  732. test_proposal_cfg = {
  733. 'min_size': 0.0,
  734. 'nms_thresh': .7,
  735. 'pre_nms_top_n': 6000
  736. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  737. 'post_nms_top_n': test_post_nms_top_n
  738. }
  739. head = ppdet.modeling.Res5Head()
  740. roi_extractor_cfg = {
  741. 'resolution': 14,
  742. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  743. 'sampling_ratio': 0,
  744. 'aligned': True
  745. }
  746. with_pool = True
  747. rpn_target_assign_cfg = {
  748. 'batch_size_per_im': rpn_batch_size_per_im,
  749. 'fg_fraction': rpn_fg_fraction,
  750. 'negative_overlap': .3,
  751. 'positive_overlap': .7,
  752. 'use_random': True
  753. }
  754. rpn_head = ppdet.modeling.RPNHead(
  755. anchor_generator=anchor_generator_cfg,
  756. rpn_target_assign=rpn_target_assign_cfg,
  757. train_proposal=train_proposal_cfg,
  758. test_proposal=test_proposal_cfg,
  759. in_channel=rpn_in_channel)
  760. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  761. bbox_head = ppdet.modeling.BBoxHead(
  762. head=head,
  763. in_channel=head.out_shape[0].channels,
  764. roi_extractor=roi_extractor_cfg,
  765. with_pool=with_pool,
  766. bbox_assigner=bbox_assigner,
  767. num_classes=num_classes)
  768. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  769. num_classes=num_classes,
  770. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  771. nms=ppdet.modeling.MultiClassNMS(
  772. score_threshold=score_threshold,
  773. keep_top_k=keep_top_k,
  774. nms_threshold=nms_threshold))
  775. params = {
  776. 'backbone': backbone,
  777. 'neck': neck,
  778. 'rpn_head': rpn_head,
  779. 'bbox_head': bbox_head,
  780. 'bbox_post_process': bbox_post_process
  781. }
  782. self.with_fpn = with_fpn
  783. super(FasterRCNN, self).__init__(
  784. model_name='FasterRCNN', num_classes=num_classes, **params)
  785. def _compose_batch_transform(self, transforms, mode='train'):
  786. if mode == 'train':
  787. default_batch_transforms = [
  788. _BatchPadding(
  789. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  790. ]
  791. else:
  792. default_batch_transforms = [
  793. _BatchPadding(
  794. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  795. ]
  796. custom_batch_transforms = []
  797. for i, op in enumerate(transforms.transforms):
  798. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  799. if mode != 'train':
  800. raise Exception(
  801. "{} cannot be present in the {} transforms. ".format(
  802. op.__class__.__name__, mode) +
  803. "Please check the {} transforms.".format(mode))
  804. custom_batch_transforms.insert(0, copy.deepcopy(op))
  805. batch_transforms = BatchCompose(custom_batch_transforms +
  806. default_batch_transforms)
  807. return batch_transforms
  808. def _fix_transforms_shape(self, image_shape):
  809. if hasattr(self, 'test_transforms'):
  810. if self.test_transforms is not None:
  811. has_resize_op = False
  812. resize_op_idx = -1
  813. normalize_op_idx = len(self.test_transforms.transforms)
  814. for idx, op in enumerate(self.test_transforms.transforms):
  815. name = op.__class__.__name__
  816. if name == 'ResizeByShort':
  817. has_resize_op = True
  818. resize_op_idx = idx
  819. if name == 'Normalize':
  820. normalize_op_idx = idx
  821. if not has_resize_op:
  822. self.test_transforms.transforms.insert(
  823. normalize_op_idx,
  824. Resize(
  825. target_size=image_shape,
  826. keep_ratio=True,
  827. interp='CUBIC'))
  828. else:
  829. self.test_transforms.transforms[resize_op_idx] = Resize(
  830. target_size=image_shape,
  831. keep_ratio=True,
  832. interp='CUBIC')
  833. self.test_transforms.transforms.append(
  834. Padding(im_padding_value=[0., 0., 0.]))
  835. class PPYOLO(YOLOv3):
  836. def __init__(self,
  837. num_classes=80,
  838. backbone='ResNet50_vd_dcn',
  839. anchors=None,
  840. anchor_masks=None,
  841. use_coord_conv=True,
  842. use_iou_aware=True,
  843. use_spp=True,
  844. use_drop_block=True,
  845. scale_x_y=1.05,
  846. ignore_threshold=0.7,
  847. label_smooth=False,
  848. use_iou_loss=True,
  849. use_matrix_nms=True,
  850. nms_score_threshold=0.01,
  851. nms_topk=-1,
  852. nms_keep_topk=100,
  853. nms_iou_threshold=0.45):
  854. self.init_params = locals()
  855. if backbone not in [
  856. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  857. 'MobileNetV3_small'
  858. ]:
  859. raise ValueError(
  860. "backbone: {} is not supported. Please choose one of "
  861. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  862. format(backbone))
  863. self.backbone_name = backbone
  864. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  865. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  866. norm_type = 'sync_bn'
  867. else:
  868. norm_type = 'bn'
  869. if anchors is None and anchor_masks is None:
  870. if 'MobileNetV3' in backbone:
  871. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  872. [120, 195], [254, 235]]
  873. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  874. elif backbone == 'ResNet50_vd_dcn':
  875. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  876. [59, 119], [116, 90], [156, 198], [373, 326]]
  877. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  878. else:
  879. anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
  880. [344, 319]]
  881. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  882. elif anchors is None or anchor_masks is None:
  883. raise ValueError("Please define both anchors and anchor_masks.")
  884. if backbone == 'ResNet50_vd_dcn':
  885. backbone = self._get_backbone(
  886. 'ResNet',
  887. variant='d',
  888. norm_type=norm_type,
  889. return_idx=[1, 2, 3],
  890. dcn_v2_stages=[3],
  891. freeze_at=-1,
  892. freeze_norm=False,
  893. norm_decay=0.)
  894. downsample_ratios = [32, 16, 8]
  895. elif backbone == 'ResNet18_vd':
  896. backbone = self._get_backbone(
  897. 'ResNet',
  898. depth=18,
  899. variant='d',
  900. norm_type=norm_type,
  901. return_idx=[2, 3],
  902. freeze_at=-1,
  903. freeze_norm=False,
  904. norm_decay=0.)
  905. downsample_ratios = [32, 16, 8]
  906. elif backbone == 'MobileNetV3_large':
  907. backbone = self._get_backbone(
  908. 'MobileNetV3',
  909. model_name='large',
  910. norm_type=norm_type,
  911. scale=1,
  912. with_extra_blocks=False,
  913. extra_block_filters=[],
  914. feature_maps=[13, 16])
  915. downsample_ratios = [32, 16]
  916. elif backbone == 'MobileNetV3_small':
  917. backbone = self._get_backbone(
  918. 'MobileNetV3',
  919. model_name='small',
  920. norm_type=norm_type,
  921. scale=1,
  922. with_extra_blocks=False,
  923. extra_block_filters=[],
  924. feature_maps=[9, 12])
  925. downsample_ratios = [32, 16]
  926. neck = ppdet.modeling.PPYOLOFPN(
  927. norm_type=norm_type,
  928. in_channels=[i.channels for i in backbone.out_shape],
  929. coord_conv=use_coord_conv,
  930. drop_block=use_drop_block,
  931. spp=use_spp,
  932. conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
  933. self.backbone_name == 'ResNet18_vd') else 2)
  934. loss = ppdet.modeling.YOLOv3Loss(
  935. num_classes=num_classes,
  936. ignore_thresh=ignore_threshold,
  937. downsample=downsample_ratios,
  938. label_smooth=label_smooth,
  939. scale_x_y=scale_x_y,
  940. iou_loss=ppdet.modeling.IouLoss(
  941. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  942. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  943. if use_iou_aware else None)
  944. yolo_head = ppdet.modeling.YOLOv3Head(
  945. in_channels=[i.channels for i in neck.out_shape],
  946. anchors=anchors,
  947. anchor_masks=anchor_masks,
  948. num_classes=num_classes,
  949. loss=loss,
  950. iou_aware=use_iou_aware)
  951. if use_matrix_nms:
  952. nms = ppdet.modeling.MatrixNMS(
  953. keep_top_k=nms_keep_topk,
  954. score_threshold=nms_score_threshold,
  955. post_threshold=.05
  956. if 'MobileNetV3' in self.backbone_name else .01,
  957. nms_top_k=nms_topk,
  958. background_label=-1)
  959. else:
  960. nms = ppdet.modeling.MultiClassNMS(
  961. score_threshold=nms_score_threshold,
  962. nms_top_k=nms_topk,
  963. keep_top_k=nms_keep_topk,
  964. nms_threshold=nms_iou_threshold)
  965. post_process = ppdet.modeling.BBoxPostProcess(
  966. decode=ppdet.modeling.YOLOBox(
  967. num_classes=num_classes,
  968. conf_thresh=.005
  969. if 'MobileNetV3' in self.backbone_name else .01,
  970. scale_x_y=scale_x_y),
  971. nms=nms)
  972. params = {
  973. 'backbone': backbone,
  974. 'neck': neck,
  975. 'yolo_head': yolo_head,
  976. 'post_process': post_process
  977. }
  978. super(YOLOv3, self).__init__(
  979. model_name='YOLOv3', num_classes=num_classes, **params)
  980. self.anchors = anchors
  981. self.anchor_masks = anchor_masks
  982. self.downsample_ratios = downsample_ratios
  983. self.model_name = 'PPYOLO'
  984. class PPYOLOTiny(YOLOv3):
  985. def __init__(self,
  986. num_classes=80,
  987. backbone='MobileNetV3',
  988. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  989. [60, 170], [220, 125], [128, 222], [264, 266]],
  990. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  991. use_iou_aware=False,
  992. use_spp=True,
  993. use_drop_block=True,
  994. scale_x_y=1.05,
  995. ignore_threshold=0.5,
  996. label_smooth=False,
  997. use_iou_loss=True,
  998. use_matrix_nms=False,
  999. nms_score_threshold=0.005,
  1000. nms_topk=1000,
  1001. nms_keep_topk=100,
  1002. nms_iou_threshold=0.45):
  1003. self.init_params = locals()
  1004. if backbone != 'MobileNetV3':
  1005. logging.warning(
  1006. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1007. "Backbone is forcibly set to MobileNetV3.")
  1008. self.backbone_name = 'MobileNetV3'
  1009. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1010. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1011. norm_type = 'sync_bn'
  1012. else:
  1013. norm_type = 'bn'
  1014. backbone = self._get_backbone(
  1015. 'MobileNetV3',
  1016. model_name='large',
  1017. norm_type=norm_type,
  1018. scale=.5,
  1019. with_extra_blocks=False,
  1020. extra_block_filters=[],
  1021. feature_maps=[7, 13, 16])
  1022. downsample_ratios = [32, 16, 8]
  1023. neck = ppdet.modeling.PPYOLOTinyFPN(
  1024. detection_block_channels=[160, 128, 96],
  1025. in_channels=[i.channels for i in backbone.out_shape],
  1026. spp=use_spp,
  1027. drop_block=use_drop_block)
  1028. loss = ppdet.modeling.YOLOv3Loss(
  1029. num_classes=num_classes,
  1030. ignore_thresh=ignore_threshold,
  1031. downsample=downsample_ratios,
  1032. label_smooth=label_smooth,
  1033. scale_x_y=scale_x_y,
  1034. iou_loss=ppdet.modeling.IouLoss(
  1035. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1036. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1037. if use_iou_aware else None)
  1038. yolo_head = ppdet.modeling.YOLOv3Head(
  1039. in_channels=[i.channels for i in neck.out_shape],
  1040. anchors=anchors,
  1041. anchor_masks=anchor_masks,
  1042. num_classes=num_classes,
  1043. loss=loss,
  1044. iou_aware=use_iou_aware)
  1045. if use_matrix_nms:
  1046. nms = ppdet.modeling.MatrixNMS(
  1047. keep_top_k=nms_keep_topk,
  1048. score_threshold=nms_score_threshold,
  1049. post_threshold=.05,
  1050. nms_top_k=nms_topk,
  1051. background_label=-1)
  1052. else:
  1053. nms = ppdet.modeling.MultiClassNMS(
  1054. score_threshold=nms_score_threshold,
  1055. nms_top_k=nms_topk,
  1056. keep_top_k=nms_keep_topk,
  1057. nms_threshold=nms_iou_threshold)
  1058. post_process = ppdet.modeling.BBoxPostProcess(
  1059. decode=ppdet.modeling.YOLOBox(
  1060. num_classes=num_classes,
  1061. conf_thresh=.005,
  1062. downsample_ratio=32,
  1063. clip_bbox=True,
  1064. scale_x_y=scale_x_y),
  1065. nms=nms)
  1066. params = {
  1067. 'backbone': backbone,
  1068. 'neck': neck,
  1069. 'yolo_head': yolo_head,
  1070. 'post_process': post_process
  1071. }
  1072. super(YOLOv3, self).__init__(
  1073. model_name='YOLOv3', num_classes=num_classes, **params)
  1074. self.anchors = anchors
  1075. self.anchor_masks = anchor_masks
  1076. self.downsample_ratios = downsample_ratios
  1077. self.num_max_boxes = 100
  1078. self.model_name = 'PPYOLOTiny'
  1079. class PPYOLOv2(YOLOv3):
  1080. def __init__(self,
  1081. num_classes=80,
  1082. backbone='ResNet50_vd_dcn',
  1083. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1084. [59, 119], [116, 90], [156, 198], [373, 326]],
  1085. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1086. use_iou_aware=True,
  1087. use_spp=True,
  1088. use_drop_block=True,
  1089. scale_x_y=1.05,
  1090. ignore_threshold=0.7,
  1091. label_smooth=False,
  1092. use_iou_loss=True,
  1093. use_matrix_nms=True,
  1094. nms_score_threshold=0.01,
  1095. nms_topk=-1,
  1096. nms_keep_topk=100,
  1097. nms_iou_threshold=0.45):
  1098. self.init_params = locals()
  1099. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1100. raise ValueError(
  1101. "backbone: {} is not supported. Please choose one of "
  1102. "('ResNet50_vd_dcn', 'ResNet18_vd')".format(backbone))
  1103. self.backbone_name = backbone
  1104. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1105. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1106. norm_type = 'sync_bn'
  1107. else:
  1108. norm_type = 'bn'
  1109. if backbone == 'ResNet50_vd_dcn':
  1110. backbone = self._get_backbone(
  1111. 'ResNet',
  1112. variant='d',
  1113. norm_type=norm_type,
  1114. return_idx=[1, 2, 3],
  1115. dcn_v2_stages=[3],
  1116. freeze_at=-1,
  1117. freeze_norm=False,
  1118. norm_decay=0.)
  1119. downsample_ratios = [32, 16, 8]
  1120. elif backbone == 'ResNet101_vd_dcn':
  1121. backbone = self._get_backbone(
  1122. 'ResNet',
  1123. depth=101,
  1124. variant='d',
  1125. norm_type=norm_type,
  1126. return_idx=[1, 2, 3],
  1127. dcn_v2_stages=[3],
  1128. freeze_at=-1,
  1129. freeze_norm=False,
  1130. norm_decay=0.)
  1131. downsample_ratios = [32, 16, 8]
  1132. neck = ppdet.modeling.PPYOLOPAN(
  1133. norm_type=norm_type,
  1134. in_channels=[i.channels for i in backbone.out_shape],
  1135. drop_block=use_drop_block,
  1136. block_size=3,
  1137. keep_prob=.9,
  1138. spp=use_spp)
  1139. loss = ppdet.modeling.YOLOv3Loss(
  1140. num_classes=num_classes,
  1141. ignore_thresh=ignore_threshold,
  1142. downsample=downsample_ratios,
  1143. label_smooth=label_smooth,
  1144. scale_x_y=scale_x_y,
  1145. iou_loss=ppdet.modeling.IouLoss(
  1146. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1147. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1148. if use_iou_aware else None)
  1149. yolo_head = ppdet.modeling.YOLOv3Head(
  1150. in_channels=[i.channels for i in neck.out_shape],
  1151. anchors=anchors,
  1152. anchor_masks=anchor_masks,
  1153. num_classes=num_classes,
  1154. loss=loss,
  1155. iou_aware=use_iou_aware,
  1156. iou_aware_factor=.5)
  1157. if use_matrix_nms:
  1158. nms = ppdet.modeling.MatrixNMS(
  1159. keep_top_k=nms_keep_topk,
  1160. score_threshold=nms_score_threshold,
  1161. post_threshold=.01,
  1162. nms_top_k=nms_topk,
  1163. background_label=-1)
  1164. else:
  1165. nms = ppdet.modeling.MultiClassNMS(
  1166. score_threshold=nms_score_threshold,
  1167. nms_top_k=nms_topk,
  1168. keep_top_k=nms_keep_topk,
  1169. nms_threshold=nms_iou_threshold)
  1170. post_process = ppdet.modeling.BBoxPostProcess(
  1171. decode=ppdet.modeling.YOLOBox(
  1172. num_classes=num_classes,
  1173. conf_thresh=.01,
  1174. downsample_ratio=32,
  1175. clip_bbox=True,
  1176. scale_x_y=scale_x_y),
  1177. nms=nms)
  1178. params = {
  1179. 'backbone': backbone,
  1180. 'neck': neck,
  1181. 'yolo_head': yolo_head,
  1182. 'post_process': post_process
  1183. }
  1184. super(YOLOv3, self).__init__(
  1185. model_name='YOLOv3', num_classes=num_classes, **params)
  1186. self.anchors = anchors
  1187. self.anchor_masks = anchor_masks
  1188. self.downsample_ratios = downsample_ratios
  1189. self.num_max_boxes = 100
  1190. self.model_name = 'PPYOLOv2'
  1191. class MaskRCNN(BaseDetector):
  1192. def __init__(self,
  1193. num_classes=80,
  1194. backbone='ResNet50_vd',
  1195. with_fpn=True,
  1196. aspect_ratios=[0.5, 1.0, 2.0],
  1197. anchor_sizes=[[32], [64], [128], [256], [512]],
  1198. keep_top_k=100,
  1199. nms_threshold=0.5,
  1200. score_threshold=0.05,
  1201. fpn_num_channels=256,
  1202. rpn_batch_size_per_im=256,
  1203. rpn_fg_fraction=0.5,
  1204. test_pre_nms_top_n=None,
  1205. test_post_nms_top_n=1000):
  1206. self.init_params = locals()
  1207. if backbone not in [
  1208. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1209. 'ResNet101_vd'
  1210. ]:
  1211. raise ValueError(
  1212. "backbone: {} is not supported. Please choose one of "
  1213. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1214. format(backbone))
  1215. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1216. if backbone == 'ResNet50':
  1217. if with_fpn:
  1218. backbone = self._get_backbone(
  1219. 'ResNet',
  1220. norm_type='bn',
  1221. freeze_at=0,
  1222. return_idx=[0, 1, 2, 3],
  1223. num_stages=4)
  1224. else:
  1225. backbone = self._get_backbone(
  1226. 'ResNet',
  1227. norm_type='bn',
  1228. freeze_at=0,
  1229. return_idx=[2],
  1230. num_stages=3)
  1231. elif 'ResNet50_vd' in backbone:
  1232. if not with_fpn:
  1233. logging.warning(
  1234. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1235. format(backbone))
  1236. with_fpn = True
  1237. backbone = self._get_backbone(
  1238. 'ResNet',
  1239. variant='d',
  1240. norm_type='bn',
  1241. freeze_at=0,
  1242. return_idx=[0, 1, 2, 3],
  1243. num_stages=4,
  1244. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1245. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
  1246. else:
  1247. if not with_fpn:
  1248. logging.warning(
  1249. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1250. format(backbone))
  1251. with_fpn = True
  1252. backbone = self._get_backbone(
  1253. 'ResNet',
  1254. variant='d' if '_vd' in backbone else 'b',
  1255. depth=101,
  1256. norm_type='bn',
  1257. freeze_at=0,
  1258. return_idx=[0, 1, 2, 3],
  1259. num_stages=4)
  1260. rpn_in_channel = backbone.out_shape[0].channels
  1261. if with_fpn:
  1262. neck = ppdet.modeling.FPN(
  1263. in_channels=[i.channels for i in backbone.out_shape],
  1264. out_channel=fpn_num_channels,
  1265. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  1266. rpn_in_channel = neck.out_shape[0].channels
  1267. anchor_generator_cfg = {
  1268. 'aspect_ratios': aspect_ratios,
  1269. 'anchor_sizes': anchor_sizes,
  1270. 'strides': [4, 8, 16, 32, 64]
  1271. }
  1272. train_proposal_cfg = {
  1273. 'min_size': 0.0,
  1274. 'nms_thresh': .7,
  1275. 'pre_nms_top_n': 2000,
  1276. 'post_nms_top_n': 1000,
  1277. 'topk_after_collect': True
  1278. }
  1279. test_proposal_cfg = {
  1280. 'min_size': 0.0,
  1281. 'nms_thresh': .7,
  1282. 'pre_nms_top_n': 1000
  1283. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1284. 'post_nms_top_n': test_post_nms_top_n
  1285. }
  1286. bb_head = ppdet.modeling.TwoFCHead(
  1287. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1288. bb_roi_extractor_cfg = {
  1289. 'resolution': 7,
  1290. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1291. 'sampling_ratio': 0,
  1292. 'aligned': True
  1293. }
  1294. with_pool = False
  1295. m_head = ppdet.modeling.MaskFeat(
  1296. in_channel=neck.out_shape[0].channels,
  1297. out_channel=256,
  1298. num_convs=4)
  1299. m_roi_extractor_cfg = {
  1300. 'resolution': 14,
  1301. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1302. 'sampling_ratio': 0,
  1303. 'aligned': True
  1304. }
  1305. mask_assigner = MaskAssigner(
  1306. num_classes=num_classes, mask_resolution=28)
  1307. share_bbox_feat = False
  1308. else:
  1309. neck = None
  1310. anchor_generator_cfg = {
  1311. 'aspect_ratios': aspect_ratios,
  1312. 'anchor_sizes': anchor_sizes,
  1313. 'strides': [16]
  1314. }
  1315. train_proposal_cfg = {
  1316. 'min_size': 0.0,
  1317. 'nms_thresh': .7,
  1318. 'pre_nms_top_n': 12000,
  1319. 'post_nms_top_n': 2000,
  1320. 'topk_after_collect': False
  1321. }
  1322. test_proposal_cfg = {
  1323. 'min_size': 0.0,
  1324. 'nms_thresh': .7,
  1325. 'pre_nms_top_n': 6000
  1326. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1327. 'post_nms_top_n': test_post_nms_top_n
  1328. }
  1329. bb_head = ppdet.modeling.Res5Head()
  1330. bb_roi_extractor_cfg = {
  1331. 'resolution': 14,
  1332. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1333. 'sampling_ratio': 0,
  1334. 'aligned': True
  1335. }
  1336. with_pool = True
  1337. m_head = ppdet.modeling.MaskFeat(
  1338. in_channel=bb_head.out_shape[0].channels,
  1339. out_channel=256,
  1340. num_convs=0)
  1341. m_roi_extractor_cfg = {
  1342. 'resolution': 14,
  1343. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1344. 'sampling_ratio': 0,
  1345. 'aligned': True
  1346. }
  1347. mask_assigner = MaskAssigner(
  1348. num_classes=num_classes, mask_resolution=14)
  1349. share_bbox_feat = True
  1350. rpn_target_assign_cfg = {
  1351. 'batch_size_per_im': rpn_batch_size_per_im,
  1352. 'fg_fraction': rpn_fg_fraction,
  1353. 'negative_overlap': .3,
  1354. 'positive_overlap': .7,
  1355. 'use_random': True
  1356. }
  1357. rpn_head = ppdet.modeling.RPNHead(
  1358. anchor_generator=anchor_generator_cfg,
  1359. rpn_target_assign=rpn_target_assign_cfg,
  1360. train_proposal=train_proposal_cfg,
  1361. test_proposal=test_proposal_cfg,
  1362. in_channel=rpn_in_channel)
  1363. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1364. bbox_head = ppdet.modeling.BBoxHead(
  1365. head=bb_head,
  1366. in_channel=bb_head.out_shape[0].channels,
  1367. roi_extractor=bb_roi_extractor_cfg,
  1368. with_pool=with_pool,
  1369. bbox_assigner=bbox_assigner,
  1370. num_classes=num_classes)
  1371. mask_head = ppdet.modeling.MaskHead(
  1372. head=m_head,
  1373. roi_extractor=m_roi_extractor_cfg,
  1374. mask_assigner=mask_assigner,
  1375. share_bbox_feat=share_bbox_feat,
  1376. num_classes=num_classes)
  1377. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1378. num_classes=num_classes,
  1379. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1380. nms=ppdet.modeling.MultiClassNMS(
  1381. score_threshold=score_threshold,
  1382. keep_top_k=keep_top_k,
  1383. nms_threshold=nms_threshold))
  1384. mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
  1385. params = {
  1386. 'backbone': backbone,
  1387. 'neck': neck,
  1388. 'rpn_head': rpn_head,
  1389. 'bbox_head': bbox_head,
  1390. 'mask_head': mask_head,
  1391. 'bbox_post_process': bbox_post_process,
  1392. 'mask_post_process': mask_post_process
  1393. }
  1394. self.with_fpn = with_fpn
  1395. super(MaskRCNN, self).__init__(
  1396. model_name='MaskRCNN', num_classes=num_classes, **params)
  1397. def _compose_batch_transform(self, transforms, mode='train'):
  1398. if mode == 'train':
  1399. default_batch_transforms = [
  1400. _BatchPadding(
  1401. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  1402. ]
  1403. else:
  1404. default_batch_transforms = [
  1405. _BatchPadding(
  1406. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  1407. ]
  1408. custom_batch_transforms = []
  1409. for i, op in enumerate(transforms.transforms):
  1410. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1411. if mode != 'train':
  1412. raise Exception(
  1413. "{} cannot be present in the {} transforms. ".format(
  1414. op.__class__.__name__, mode) +
  1415. "Please check the {} transforms.".format(mode))
  1416. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1417. batch_transforms = BatchCompose(custom_batch_transforms +
  1418. default_batch_transforms)
  1419. return batch_transforms
  1420. def _fix_transforms_shape(self, image_shape):
  1421. if hasattr(self, 'test_transforms'):
  1422. if self.test_transforms is not None:
  1423. has_resize_op = False
  1424. resize_op_idx = -1
  1425. normalize_op_idx = len(self.test_transforms.transforms)
  1426. for idx, op in enumerate(self.test_transforms.transforms):
  1427. name = op.__class__.__name__
  1428. if name == 'ResizeByShort':
  1429. has_resize_op = True
  1430. resize_op_idx = idx
  1431. if name == 'Normalize':
  1432. normalize_op_idx = idx
  1433. if not has_resize_op:
  1434. self.test_transforms.transforms.insert(
  1435. normalize_op_idx,
  1436. Resize(
  1437. target_size=image_shape, keep_ratio=True))
  1438. else:
  1439. self.test_transforms.transforms[resize_op_idx] = Resize(
  1440. target_size=image_shape, keep_ratio=True)
  1441. self.test_transforms.transforms.append(
  1442. Padding(im_padding_value=[0., 0., 0.]))