detector.py 61 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import ppdet
  24. from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. del self.init_params['params']
  41. super(BaseDetector, self).__init__('detector')
  42. if not hasattr(ppdet.modeling, model_name):
  43. raise Exception("ERROR: There's no model named {}.".format(
  44. model_name))
  45. self.model_name = model_name
  46. self.num_classes = num_classes
  47. self.labels = None
  48. self.net = self.build_net(**params)
  49. def build_net(self, **params):
  50. with paddle.utils.unique_name.guard():
  51. net = ppdet.modeling.__dict__[self.model_name](**params)
  52. return net
  53. def get_test_inputs(self, image_shape):
  54. input_spec = [{
  55. "image": InputSpec(
  56. shape=[None, 3] + image_shape, name='image', dtype='float32'),
  57. "im_shape": InputSpec(
  58. shape=[None, 2], name='im_shape', dtype='float32'),
  59. "scale_factor": InputSpec(
  60. shape=[None, 2], name='scale_factor', dtype='float32')
  61. }]
  62. return input_spec
  63. def _get_backbone(self, backbone_name, **params):
  64. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  65. return backbone
  66. def run(self, net, inputs, mode):
  67. net_out = net(inputs)
  68. if mode in ['train', 'eval']:
  69. outputs = net_out
  70. else:
  71. for key in ['im_shape', 'scale_factor']:
  72. net_out[key] = inputs[key]
  73. outputs = dict()
  74. for key in net_out:
  75. outputs[key] = net_out[key].numpy()
  76. return outputs
  77. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  78. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  79. num_steps_each_epoch):
  80. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  81. values = [(lr_decay_gamma**i) * learning_rate
  82. for i in range(len(lr_decay_epochs) + 1)]
  83. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  84. boundaries=boundaries, values=values)
  85. if warmup_steps > 0:
  86. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  87. logging.error(
  88. "In function train(), parameters should satisfy: "
  89. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  90. exit=False)
  91. logging.error(
  92. "See this doc for more information: "
  93. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  94. exit=False)
  95. scheduler = paddle.optimizer.lr.LinearWarmup(
  96. learning_rate=scheduler,
  97. warmup_steps=warmup_steps,
  98. start_lr=warmup_start_lr,
  99. end_lr=learning_rate)
  100. optimizer = paddle.optimizer.Momentum(
  101. scheduler,
  102. momentum=.9,
  103. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  104. parameters=parameters)
  105. return optimizer
  106. def train(self,
  107. num_epochs,
  108. train_dataset,
  109. train_batch_size=64,
  110. eval_dataset=None,
  111. optimizer=None,
  112. save_interval_epochs=1,
  113. log_interval_steps=10,
  114. save_dir='output',
  115. pretrain_weights='IMAGENET',
  116. learning_rate=.001,
  117. warmup_steps=0,
  118. warmup_start_lr=0.0,
  119. lr_decay_epochs=(216, 243),
  120. lr_decay_gamma=0.1,
  121. metric=None,
  122. use_ema=False,
  123. early_stop=False,
  124. early_stop_patience=5,
  125. use_vdl=True):
  126. """
  127. Train the model.
  128. Args:
  129. num_epochs(int): The number of epochs.
  130. train_dataset(paddlex.dataset): Training dataset.
  131. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  132. eval_dataset(paddlex.dataset, optional):
  133. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  134. optimizer(paddle.optimizer.Optimizer or None, optional):
  135. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  136. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  137. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  138. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  139. pretrain_weights(str or None, optional):
  140. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  141. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  142. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  143. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  144. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  145. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  146. metric({'VOC', 'COCO', None}, optional):
  147. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  148. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  149. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  150. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  151. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  152. """
  153. if train_dataset.__class__.__name__ == 'VOCDetection':
  154. train_dataset.data_fields = {
  155. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  156. 'difficult'
  157. }
  158. elif train_dataset.__class__.__name__ == 'CocoDetection':
  159. if self.__class__.__name__ == 'MaskRCNN':
  160. train_dataset.data_fields = {
  161. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  162. 'gt_poly', 'is_crowd'
  163. }
  164. else:
  165. train_dataset.data_fields = {
  166. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  167. 'is_crowd'
  168. }
  169. if metric is None:
  170. if eval_dataset.__class__.__name__ == 'VOCDetection':
  171. self.metric = 'voc'
  172. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  173. self.metric = 'coco'
  174. else:
  175. assert metric.lower() in ['coco', 'voc'], \
  176. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  177. self.metric = metric.lower()
  178. self.labels = train_dataset.labels
  179. self.num_max_boxes = train_dataset.num_max_boxes
  180. train_dataset.batch_transforms = self._compose_batch_transform(
  181. train_dataset.transforms, mode='train')
  182. # build optimizer if not defined
  183. if optimizer is None:
  184. num_steps_each_epoch = len(train_dataset) // train_batch_size
  185. self.optimizer = self.default_optimizer(
  186. parameters=self.net.parameters(),
  187. learning_rate=learning_rate,
  188. warmup_steps=warmup_steps,
  189. warmup_start_lr=warmup_start_lr,
  190. lr_decay_epochs=lr_decay_epochs,
  191. lr_decay_gamma=lr_decay_gamma,
  192. num_steps_each_epoch=num_steps_each_epoch)
  193. else:
  194. self.optimizer = optimizer
  195. # initiate weights
  196. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  197. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  198. [self.model_name, self.backbone_name])]:
  199. logging.warning(
  200. "Path of pretrain_weights('{}') does not exist!".format(
  201. pretrain_weights))
  202. pretrain_weights = det_pretrain_weights_dict['_'.join(
  203. [self.model_name, self.backbone_name])][0]
  204. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  205. "If you don't want to use pretrain weights, "
  206. "set pretrain_weights to be None.".format(
  207. pretrain_weights))
  208. pretrained_dir = osp.join(save_dir, 'pretrain')
  209. self.net_initialize(
  210. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  211. if use_ema:
  212. ema = ExponentialMovingAverage(
  213. decay=.9998, model=self.net, use_thres_step=True)
  214. else:
  215. ema = None
  216. # start train loop
  217. self.train_loop(
  218. num_epochs=num_epochs,
  219. train_dataset=train_dataset,
  220. train_batch_size=train_batch_size,
  221. eval_dataset=eval_dataset,
  222. save_interval_epochs=save_interval_epochs,
  223. log_interval_steps=log_interval_steps,
  224. save_dir=save_dir,
  225. ema=ema,
  226. early_stop=early_stop,
  227. early_stop_patience=early_stop_patience,
  228. use_vdl=use_vdl)
  229. def quant_aware_train(self,
  230. num_epochs,
  231. train_dataset,
  232. train_batch_size=64,
  233. eval_dataset=None,
  234. optimizer=None,
  235. save_interval_epochs=1,
  236. log_interval_steps=10,
  237. save_dir='output',
  238. learning_rate=.00001,
  239. warmup_steps=0,
  240. warmup_start_lr=0.0,
  241. lr_decay_epochs=(216, 243),
  242. lr_decay_gamma=0.1,
  243. metric=None,
  244. use_ema=False,
  245. early_stop=False,
  246. early_stop_patience=5,
  247. use_vdl=True,
  248. quant_config=None):
  249. """
  250. Quantization-aware training.
  251. Args:
  252. num_epochs(int): The number of epochs.
  253. train_dataset(paddlex.dataset): Training dataset.
  254. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  255. eval_dataset(paddlex.dataset, optional):
  256. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  257. optimizer(paddle.optimizer.Optimizer or None, optional):
  258. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  259. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  260. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  261. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  262. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  263. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  264. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  265. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  266. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  267. metric({'VOC', 'COCO', None}, optional):
  268. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  269. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  270. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  271. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  272. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  273. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  274. configuration will be used. Defaults to None.
  275. """
  276. self._prepare_qat(quant_config)
  277. self.train(
  278. num_epochs=num_epochs,
  279. train_dataset=train_dataset,
  280. train_batch_size=train_batch_size,
  281. eval_dataset=eval_dataset,
  282. optimizer=optimizer,
  283. save_interval_epochs=save_interval_epochs,
  284. log_interval_steps=log_interval_steps,
  285. save_dir=save_dir,
  286. pretrain_weights=None,
  287. learning_rate=learning_rate,
  288. warmup_steps=warmup_steps,
  289. warmup_start_lr=warmup_start_lr,
  290. lr_decay_epochs=lr_decay_epochs,
  291. lr_decay_gamma=lr_decay_gamma,
  292. metric=metric,
  293. use_ema=use_ema,
  294. early_stop=early_stop,
  295. early_stop_patience=early_stop_patience,
  296. use_vdl=use_vdl)
  297. def evaluate(self,
  298. eval_dataset,
  299. batch_size=1,
  300. metric=None,
  301. return_details=False):
  302. """
  303. Evaluate the model.
  304. Args:
  305. eval_dataset(paddlex.dataset): Evaluation dataset.
  306. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  307. metric({'VOC', 'COCO', None}, optional):
  308. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  309. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  310. Returns:
  311. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  312. """
  313. if metric is None:
  314. if not hasattr(self, 'metric'):
  315. if eval_dataset.__class__.__name__ == 'VOCDetection':
  316. self.metric = 'voc'
  317. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  318. self.metric = 'coco'
  319. else:
  320. assert metric.lower() in ['coco', 'voc'], \
  321. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  322. self.metric = metric.lower()
  323. if self.metric == 'voc':
  324. eval_dataset.data_fields = {
  325. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  326. 'difficult'
  327. }
  328. elif self.metric == 'coco':
  329. if self.__class__.__name__ == 'MaskRCNN':
  330. eval_dataset.data_fields = {
  331. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  332. 'gt_poly', 'is_crowd'
  333. }
  334. else:
  335. eval_dataset.data_fields = {
  336. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  337. 'is_crowd'
  338. }
  339. eval_dataset.batch_transforms = self._compose_batch_transform(
  340. eval_dataset.transforms, mode='eval')
  341. arrange_transforms(
  342. model_type=self.model_type,
  343. transforms=eval_dataset.transforms,
  344. mode='eval')
  345. self.net.eval()
  346. nranks = paddle.distributed.get_world_size()
  347. local_rank = paddle.distributed.get_rank()
  348. if nranks > 1:
  349. # Initialize parallel environment if not done.
  350. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  351. ):
  352. paddle.distributed.init_parallel_env()
  353. if batch_size > 1:
  354. logging.warning(
  355. "Detector only supports single card evaluation with batch_size=1 "
  356. "during evaluation, so batch_size is forcibly set to 1.")
  357. batch_size = 1
  358. if nranks < 2 or local_rank == 0:
  359. self.eval_data_loader = self.build_data_loader(
  360. eval_dataset, batch_size=batch_size, mode='eval')
  361. is_bbox_normalized = False
  362. if eval_dataset.batch_transforms is not None:
  363. is_bbox_normalized = any(
  364. isinstance(t, _NormalizeBox)
  365. for t in eval_dataset.batch_transforms.batch_transforms)
  366. if self.metric == 'voc':
  367. eval_metric = VOCMetric(
  368. labels=eval_dataset.labels,
  369. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  370. is_bbox_normalized=is_bbox_normalized,
  371. classwise=False)
  372. else:
  373. eval_metric = COCOMetric(
  374. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  375. classwise=False)
  376. scores = collections.OrderedDict()
  377. logging.info(
  378. "Start to evaluate(total_samples={}, total_steps={})...".
  379. format(eval_dataset.num_samples, eval_dataset.num_samples))
  380. with paddle.no_grad():
  381. for step, data in enumerate(self.eval_data_loader):
  382. outputs = self.run(self.net, data, 'eval')
  383. eval_metric.update(data, outputs)
  384. eval_metric.accumulate()
  385. self.eval_details = eval_metric.details
  386. scores.update(eval_metric.get())
  387. eval_metric.reset()
  388. if return_details:
  389. return scores, self.eval_details
  390. return scores
  391. def predict(self, img_file, transforms=None):
  392. """
  393. Do inference.
  394. Args:
  395. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  396. Image path or decoded image data in a BGR format, which also could constitute a list,
  397. meaning all images to be predicted as a mini-batch.
  398. transforms(paddlex.transforms.Compose or None, optional):
  399. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  400. Returns:
  401. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  402. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  403. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  404. category_id(int): the predicted category ID
  405. category(str): category name
  406. bbox(list): bounding box in [x, y, w, h] format
  407. score(str): confidence
  408. """
  409. if transforms is None and not hasattr(self, 'test_transforms'):
  410. raise Exception("transforms need to be defined, now is None.")
  411. if transforms is None:
  412. transforms = self.test_transforms
  413. if isinstance(img_file, (str, np.ndarray)):
  414. images = [img_file]
  415. else:
  416. images = img_file
  417. batch_samples = self._preprocess(images, transforms)
  418. self.net.eval()
  419. outputs = self.run(self.net, batch_samples, 'test')
  420. prediction = self._postprocess(outputs)
  421. if isinstance(img_file, (str, np.ndarray)):
  422. prediction = prediction[0]
  423. return prediction
  424. def _preprocess(self, images, transforms):
  425. arrange_transforms(
  426. model_type=self.model_type, transforms=transforms, mode='test')
  427. batch_samples = list()
  428. for im in images:
  429. sample = {'image': im}
  430. batch_samples.append(transforms(sample))
  431. batch_transforms = self._compose_batch_transform(transforms, 'test')
  432. batch_samples = batch_transforms(batch_samples)
  433. for k, v in batch_samples.items():
  434. batch_samples[k] = paddle.to_tensor(v)
  435. return batch_samples
  436. def _postprocess(self, batch_pred):
  437. infer_result = {}
  438. if 'bbox' in batch_pred:
  439. bboxes = batch_pred['bbox']
  440. bbox_nums = batch_pred['bbox_num']
  441. det_res = []
  442. k = 0
  443. for i in range(len(bbox_nums)):
  444. det_nums = bbox_nums[i]
  445. for j in range(det_nums):
  446. dt = bboxes[k]
  447. k = k + 1
  448. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  449. if int(num_id) < 0:
  450. continue
  451. category = self.labels[int(num_id)]
  452. w = xmax - xmin
  453. h = ymax - ymin
  454. bbox = [xmin, ymin, w, h]
  455. dt_res = {
  456. 'category_id': int(num_id),
  457. 'category': category,
  458. 'bbox': bbox,
  459. 'score': score
  460. }
  461. det_res.append(dt_res)
  462. infer_result['bbox'] = det_res
  463. if 'mask' in batch_pred:
  464. masks = batch_pred['mask']
  465. bboxes = batch_pred['bbox']
  466. mask_nums = batch_pred['bbox_num']
  467. seg_res = []
  468. k = 0
  469. for i in range(len(mask_nums)):
  470. det_nums = mask_nums[i]
  471. for j in range(det_nums):
  472. mask = masks[k].astype(np.uint8)
  473. score = float(bboxes[k][1])
  474. label = int(bboxes[k][0])
  475. k = k + 1
  476. if label == -1:
  477. continue
  478. category = self.labels[int(label)]
  479. import pycocotools.mask as mask_util
  480. rle = mask_util.encode(
  481. np.array(
  482. mask[:, :, None], order="F", dtype="uint8"))[0]
  483. if six.PY3:
  484. if 'counts' in rle:
  485. rle['counts'] = rle['counts'].decode("utf8")
  486. sg_res = {
  487. 'category': category,
  488. 'segmentation': rle,
  489. 'score': score
  490. }
  491. seg_res.append(sg_res)
  492. infer_result['mask'] = seg_res
  493. bbox_num = batch_pred['bbox_num']
  494. results = []
  495. start = 0
  496. for num in bbox_num:
  497. end = start + num
  498. curr_res = infer_result['bbox'][start:end]
  499. if 'mask' in infer_result:
  500. mask_res = infer_result['mask'][start:end]
  501. for box, mask in zip(curr_res, mask_res):
  502. box.update(mask)
  503. results.append(curr_res)
  504. start = end
  505. return results
  506. class YOLOv3(BaseDetector):
  507. def __init__(self,
  508. num_classes=80,
  509. backbone='MobileNetV1',
  510. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  511. [59, 119], [116, 90], [156, 198], [373, 326]],
  512. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  513. ignore_threshold=0.7,
  514. nms_score_threshold=0.01,
  515. nms_topk=1000,
  516. nms_keep_topk=100,
  517. nms_iou_threshold=0.45,
  518. label_smooth=False):
  519. self.init_params = locals()
  520. if backbone not in [
  521. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  522. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  523. ]:
  524. raise ValueError(
  525. "backbone: {} is not supported. Please choose one of "
  526. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  527. format(backbone))
  528. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  529. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  530. norm_type = 'sync_bn'
  531. else:
  532. norm_type = 'bn'
  533. self.backbone_name = backbone
  534. if 'MobileNetV1' in backbone:
  535. norm_type = 'bn'
  536. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  537. elif 'MobileNetV3' in backbone:
  538. backbone = self._get_backbone(
  539. 'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
  540. elif backbone == 'ResNet50_vd_dcn':
  541. backbone = self._get_backbone(
  542. 'ResNet',
  543. norm_type=norm_type,
  544. variant='d',
  545. return_idx=[1, 2, 3],
  546. dcn_v2_stages=[3],
  547. freeze_at=-1,
  548. freeze_norm=False)
  549. elif backbone == 'ResNet34':
  550. backbone = self._get_backbone(
  551. 'ResNet',
  552. depth=34,
  553. norm_type=norm_type,
  554. return_idx=[1, 2, 3],
  555. freeze_at=-1,
  556. freeze_norm=False,
  557. norm_decay=0.)
  558. else:
  559. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  560. neck = ppdet.modeling.YOLOv3FPN(
  561. norm_type=norm_type,
  562. in_channels=[i.channels for i in backbone.out_shape])
  563. loss = ppdet.modeling.YOLOv3Loss(
  564. num_classes=num_classes,
  565. ignore_thresh=ignore_threshold,
  566. label_smooth=label_smooth)
  567. yolo_head = ppdet.modeling.YOLOv3Head(
  568. in_channels=[i.channels for i in neck.out_shape],
  569. anchors=anchors,
  570. anchor_masks=anchor_masks,
  571. num_classes=num_classes,
  572. loss=loss)
  573. post_process = ppdet.modeling.BBoxPostProcess(
  574. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  575. nms=ppdet.modeling.MultiClassNMS(
  576. score_threshold=nms_score_threshold,
  577. nms_top_k=nms_topk,
  578. keep_top_k=nms_keep_topk,
  579. nms_threshold=nms_iou_threshold))
  580. params = {
  581. 'backbone': backbone,
  582. 'neck': neck,
  583. 'yolo_head': yolo_head,
  584. 'post_process': post_process
  585. }
  586. super(YOLOv3, self).__init__(
  587. model_name='YOLOv3', num_classes=num_classes, **params)
  588. self.anchors = anchors
  589. self.anchor_masks = anchor_masks
  590. def _compose_batch_transform(self, transforms, mode='train'):
  591. if mode == 'train':
  592. default_batch_transforms = [
  593. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  594. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  595. _Gt2YoloTarget(
  596. anchor_masks=self.anchor_masks,
  597. anchors=self.anchors,
  598. downsample_ratios=getattr(self, 'downsample_ratios',
  599. [32, 16, 8]),
  600. num_classes=self.num_classes)
  601. ]
  602. else:
  603. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  604. if mode == 'eval' and self.metric == 'voc':
  605. collate_batch = False
  606. else:
  607. collate_batch = True
  608. custom_batch_transforms = []
  609. for i, op in enumerate(transforms.transforms):
  610. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  611. if mode != 'train':
  612. raise Exception(
  613. "{} cannot be present in the {} transforms. ".format(
  614. op.__class__.__name__, mode) +
  615. "Please check the {} transforms.".format(mode))
  616. custom_batch_transforms.insert(0, copy.deepcopy(op))
  617. batch_transforms = BatchCompose(
  618. custom_batch_transforms + default_batch_transforms,
  619. collate_batch=collate_batch)
  620. return batch_transforms
  621. class FasterRCNN(BaseDetector):
  622. def __init__(self,
  623. num_classes=80,
  624. backbone='ResNet50',
  625. with_fpn=True,
  626. aspect_ratios=[0.5, 1.0, 2.0],
  627. anchor_sizes=[[32], [64], [128], [256], [512]],
  628. keep_top_k=100,
  629. nms_threshold=0.5,
  630. score_threshold=0.05,
  631. fpn_num_channels=256,
  632. rpn_batch_size_per_im=256,
  633. rpn_fg_fraction=0.5,
  634. test_pre_nms_top_n=None,
  635. test_post_nms_top_n=1000):
  636. self.init_params = locals()
  637. if backbone not in [
  638. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  639. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  640. ]:
  641. raise ValueError(
  642. "backbone: {} is not supported. Please choose one of "
  643. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  644. "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
  645. self.backbone_name = backbone
  646. if backbone == 'HRNet_W18':
  647. if not with_fpn:
  648. logging.warning(
  649. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  650. format(backbone))
  651. with_fpn = True
  652. backbone = self._get_backbone(
  653. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  654. elif backbone == 'ResNet50_vd_ssld':
  655. if not with_fpn:
  656. logging.warning(
  657. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  658. format(backbone))
  659. with_fpn = True
  660. backbone = self._get_backbone(
  661. 'ResNet',
  662. variant='d',
  663. norm_type='bn',
  664. freeze_at=0,
  665. return_idx=[0, 1, 2, 3],
  666. num_stages=4,
  667. lr_mult_list=[0.05, 0.05, 0.1, 0.15])
  668. elif 'ResNet50' in backbone:
  669. if with_fpn:
  670. backbone = self._get_backbone(
  671. 'ResNet',
  672. variant='d' if '_vd' in backbone else 'b',
  673. norm_type='bn',
  674. freeze_at=0,
  675. return_idx=[0, 1, 2, 3],
  676. num_stages=4)
  677. else:
  678. backbone = self._get_backbone(
  679. 'ResNet',
  680. variant='d' if '_vd' in backbone else 'b',
  681. norm_type='bn',
  682. freeze_at=0,
  683. return_idx=[2],
  684. num_stages=3)
  685. elif 'ResNet34' in backbone:
  686. if not with_fpn:
  687. logging.warning(
  688. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  689. format(backbone))
  690. with_fpn = True
  691. backbone = self._get_backbone(
  692. 'ResNet',
  693. depth=34,
  694. variant='d' if 'vd' in backbone else 'b',
  695. norm_type='bn',
  696. freeze_at=0,
  697. return_idx=[0, 1, 2, 3],
  698. num_stages=4)
  699. else:
  700. if not with_fpn:
  701. logging.warning(
  702. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  703. format(backbone))
  704. with_fpn = True
  705. backbone = self._get_backbone(
  706. 'ResNet',
  707. depth=101,
  708. variant='d' if 'vd' in backbone else 'b',
  709. norm_type='bn',
  710. freeze_at=0,
  711. return_idx=[0, 1, 2, 3],
  712. num_stages=4)
  713. rpn_in_channel = backbone.out_shape[0].channels
  714. if with_fpn:
  715. self.backbone_name = self.backbone_name + '_fpn'
  716. if 'HRNet' in self.backbone_name:
  717. neck = ppdet.modeling.HRFPN(
  718. in_channels=[i.channels for i in backbone.out_shape],
  719. out_channel=fpn_num_channels,
  720. spatial_scales=[
  721. 1.0 / i.stride for i in backbone.out_shape
  722. ],
  723. share_conv=False)
  724. else:
  725. neck = ppdet.modeling.FPN(
  726. in_channels=[i.channels for i in backbone.out_shape],
  727. out_channel=fpn_num_channels,
  728. spatial_scales=[
  729. 1.0 / i.stride for i in backbone.out_shape
  730. ])
  731. rpn_in_channel = neck.out_shape[0].channels
  732. anchor_generator_cfg = {
  733. 'aspect_ratios': aspect_ratios,
  734. 'anchor_sizes': anchor_sizes,
  735. 'strides': [4, 8, 16, 32, 64]
  736. }
  737. train_proposal_cfg = {
  738. 'min_size': 0.0,
  739. 'nms_thresh': .7,
  740. 'pre_nms_top_n': 2000,
  741. 'post_nms_top_n': 1000,
  742. 'topk_after_collect': True
  743. }
  744. test_proposal_cfg = {
  745. 'min_size': 0.0,
  746. 'nms_thresh': .7,
  747. 'pre_nms_top_n': 1000
  748. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  749. 'post_nms_top_n': test_post_nms_top_n
  750. }
  751. head = ppdet.modeling.TwoFCHead(out_channel=1024)
  752. roi_extractor_cfg = {
  753. 'resolution': 7,
  754. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  755. 'sampling_ratio': 0,
  756. 'aligned': True
  757. }
  758. with_pool = False
  759. else:
  760. neck = None
  761. anchor_generator_cfg = {
  762. 'aspect_ratios': aspect_ratios,
  763. 'anchor_sizes': anchor_sizes,
  764. 'strides': [16]
  765. }
  766. train_proposal_cfg = {
  767. 'min_size': 0.0,
  768. 'nms_thresh': .7,
  769. 'pre_nms_top_n': 12000,
  770. 'post_nms_top_n': 2000,
  771. 'topk_after_collect': False
  772. }
  773. test_proposal_cfg = {
  774. 'min_size': 0.0,
  775. 'nms_thresh': .7,
  776. 'pre_nms_top_n': 6000
  777. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  778. 'post_nms_top_n': test_post_nms_top_n
  779. }
  780. head = ppdet.modeling.Res5Head()
  781. roi_extractor_cfg = {
  782. 'resolution': 14,
  783. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  784. 'sampling_ratio': 0,
  785. 'aligned': True
  786. }
  787. with_pool = True
  788. rpn_target_assign_cfg = {
  789. 'batch_size_per_im': rpn_batch_size_per_im,
  790. 'fg_fraction': rpn_fg_fraction,
  791. 'negative_overlap': .3,
  792. 'positive_overlap': .7,
  793. 'use_random': True
  794. }
  795. rpn_head = ppdet.modeling.RPNHead(
  796. anchor_generator=anchor_generator_cfg,
  797. rpn_target_assign=rpn_target_assign_cfg,
  798. train_proposal=train_proposal_cfg,
  799. test_proposal=test_proposal_cfg,
  800. in_channel=rpn_in_channel)
  801. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  802. bbox_head = ppdet.modeling.BBoxHead(
  803. head=head,
  804. in_channel=head.out_shape[0].channels,
  805. roi_extractor=roi_extractor_cfg,
  806. with_pool=with_pool,
  807. bbox_assigner=bbox_assigner,
  808. num_classes=num_classes)
  809. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  810. num_classes=num_classes,
  811. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  812. nms=ppdet.modeling.MultiClassNMS(
  813. score_threshold=score_threshold,
  814. keep_top_k=keep_top_k,
  815. nms_threshold=nms_threshold))
  816. params = {
  817. 'backbone': backbone,
  818. 'neck': neck,
  819. 'rpn_head': rpn_head,
  820. 'bbox_head': bbox_head,
  821. 'bbox_post_process': bbox_post_process
  822. }
  823. self.with_fpn = with_fpn
  824. super(FasterRCNN, self).__init__(
  825. model_name='FasterRCNN', num_classes=num_classes, **params)
  826. def _compose_batch_transform(self, transforms, mode='train'):
  827. if mode == 'train':
  828. default_batch_transforms = [
  829. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  830. ]
  831. collate_batch = False
  832. else:
  833. default_batch_transforms = [
  834. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  835. ]
  836. collate_batch = True
  837. custom_batch_transforms = []
  838. for i, op in enumerate(transforms.transforms):
  839. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  840. if mode != 'train':
  841. raise Exception(
  842. "{} cannot be present in the {} transforms. ".format(
  843. op.__class__.__name__, mode) +
  844. "Please check the {} transforms.".format(mode))
  845. custom_batch_transforms.insert(0, copy.deepcopy(op))
  846. batch_transforms = BatchCompose(
  847. custom_batch_transforms + default_batch_transforms,
  848. collate_batch=collate_batch)
  849. return batch_transforms
  850. class PPYOLO(YOLOv3):
  851. def __init__(self,
  852. num_classes=80,
  853. backbone='ResNet50_vd_dcn',
  854. anchors=None,
  855. anchor_masks=None,
  856. use_coord_conv=True,
  857. use_iou_aware=True,
  858. use_spp=True,
  859. use_drop_block=True,
  860. scale_x_y=1.05,
  861. ignore_threshold=0.7,
  862. label_smooth=False,
  863. use_iou_loss=True,
  864. use_matrix_nms=True,
  865. nms_score_threshold=0.01,
  866. nms_topk=-1,
  867. nms_keep_topk=100,
  868. nms_iou_threshold=0.45):
  869. self.init_params = locals()
  870. if backbone not in [
  871. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  872. 'MobileNetV3_small'
  873. ]:
  874. raise ValueError(
  875. "backbone: {} is not supported. Please choose one of "
  876. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  877. format(backbone))
  878. self.backbone_name = backbone
  879. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  880. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  881. norm_type = 'sync_bn'
  882. else:
  883. norm_type = 'bn'
  884. if anchors is None and anchor_masks is None:
  885. if 'MobileNetV3' in backbone:
  886. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  887. [120, 195], [254, 235]]
  888. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  889. elif backbone == 'ResNet50_vd_dcn':
  890. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  891. [59, 119], [116, 90], [156, 198], [373, 326]]
  892. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  893. else:
  894. anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
  895. [344, 319]]
  896. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  897. elif anchors is None or anchor_masks is None:
  898. raise ValueError("Please define both anchors and anchor_masks.")
  899. if backbone == 'ResNet50_vd_dcn':
  900. backbone = self._get_backbone(
  901. 'ResNet',
  902. variant='d',
  903. norm_type=norm_type,
  904. return_idx=[1, 2, 3],
  905. dcn_v2_stages=[3],
  906. freeze_at=-1,
  907. freeze_norm=False,
  908. norm_decay=0.)
  909. downsample_ratios = [32, 16, 8]
  910. elif backbone == 'ResNet18_vd':
  911. backbone = self._get_backbone(
  912. 'ResNet',
  913. depth=18,
  914. variant='d',
  915. norm_type=norm_type,
  916. return_idx=[2, 3],
  917. freeze_at=-1,
  918. freeze_norm=False,
  919. norm_decay=0.)
  920. downsample_ratios = [32, 16, 8]
  921. elif backbone == 'MobileNetV3_large':
  922. backbone = self._get_backbone(
  923. 'MobileNetV3',
  924. model_name='large',
  925. norm_type=norm_type,
  926. scale=1,
  927. with_extra_blocks=False,
  928. extra_block_filters=[],
  929. feature_maps=[13, 16])
  930. downsample_ratios = [32, 16]
  931. elif backbone == 'MobileNetV3_small':
  932. backbone = self._get_backbone(
  933. 'MobileNetV3',
  934. model_name='small',
  935. norm_type=norm_type,
  936. scale=1,
  937. with_extra_blocks=False,
  938. extra_block_filters=[],
  939. feature_maps=[9, 12])
  940. downsample_ratios = [32, 16]
  941. neck = ppdet.modeling.PPYOLOFPN(
  942. norm_type=norm_type,
  943. in_channels=[i.channels for i in backbone.out_shape],
  944. coord_conv=use_coord_conv,
  945. drop_block=use_drop_block,
  946. spp=use_spp,
  947. conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
  948. self.backbone_name == 'ResNet18_vd') else 2)
  949. loss = ppdet.modeling.YOLOv3Loss(
  950. num_classes=num_classes,
  951. ignore_thresh=ignore_threshold,
  952. downsample=downsample_ratios,
  953. label_smooth=label_smooth,
  954. scale_x_y=scale_x_y,
  955. iou_loss=ppdet.modeling.IouLoss(
  956. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  957. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  958. if use_iou_aware else None)
  959. yolo_head = ppdet.modeling.YOLOv3Head(
  960. in_channels=[i.channels for i in neck.out_shape],
  961. anchors=anchors,
  962. anchor_masks=anchor_masks,
  963. num_classes=num_classes,
  964. loss=loss,
  965. iou_aware=use_iou_aware)
  966. if use_matrix_nms:
  967. nms = ppdet.modeling.MatrixNMS(
  968. keep_top_k=nms_keep_topk,
  969. score_threshold=nms_score_threshold,
  970. post_threshold=.05
  971. if 'MobileNetV3' in self.backbone_name else .01,
  972. nms_top_k=nms_topk,
  973. background_label=-1)
  974. else:
  975. nms = ppdet.modeling.MultiClassNMS(
  976. score_threshold=nms_score_threshold,
  977. nms_top_k=nms_topk,
  978. keep_top_k=nms_keep_topk,
  979. nms_threshold=nms_iou_threshold)
  980. post_process = ppdet.modeling.BBoxPostProcess(
  981. decode=ppdet.modeling.YOLOBox(
  982. num_classes=num_classes,
  983. conf_thresh=.005
  984. if 'MobileNetV3' in self.backbone_name else .01,
  985. scale_x_y=scale_x_y),
  986. nms=nms)
  987. params = {
  988. 'backbone': backbone,
  989. 'neck': neck,
  990. 'yolo_head': yolo_head,
  991. 'post_process': post_process
  992. }
  993. super(YOLOv3, self).__init__(
  994. model_name='YOLOv3', num_classes=num_classes, **params)
  995. self.anchors = anchors
  996. self.anchor_masks = anchor_masks
  997. self.downsample_ratios = downsample_ratios
  998. self.model_name = 'PPYOLO'
  999. class PPYOLOTiny(YOLOv3):
  1000. def __init__(self,
  1001. num_classes=80,
  1002. backbone='MobileNetV3',
  1003. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1004. [60, 170], [220, 125], [128, 222], [264, 266]],
  1005. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1006. use_iou_aware=False,
  1007. use_spp=True,
  1008. use_drop_block=True,
  1009. scale_x_y=1.05,
  1010. ignore_threshold=0.5,
  1011. label_smooth=False,
  1012. use_iou_loss=True,
  1013. use_matrix_nms=False,
  1014. nms_score_threshold=0.005,
  1015. nms_topk=1000,
  1016. nms_keep_topk=100,
  1017. nms_iou_threshold=0.45):
  1018. self.init_params = locals()
  1019. if backbone != 'MobileNetV3':
  1020. logging.warning(
  1021. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1022. "Backbone is forcibly set to MobileNetV3.")
  1023. self.backbone_name = 'MobileNetV3'
  1024. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1025. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1026. norm_type = 'sync_bn'
  1027. else:
  1028. norm_type = 'bn'
  1029. backbone = self._get_backbone(
  1030. 'MobileNetV3',
  1031. model_name='large',
  1032. norm_type=norm_type,
  1033. scale=.5,
  1034. with_extra_blocks=False,
  1035. extra_block_filters=[],
  1036. feature_maps=[7, 13, 16])
  1037. downsample_ratios = [32, 16, 8]
  1038. neck = ppdet.modeling.PPYOLOTinyFPN(
  1039. detection_block_channels=[160, 128, 96],
  1040. in_channels=[i.channels for i in backbone.out_shape],
  1041. spp=use_spp,
  1042. drop_block=use_drop_block)
  1043. loss = ppdet.modeling.YOLOv3Loss(
  1044. num_classes=num_classes,
  1045. ignore_thresh=ignore_threshold,
  1046. downsample=downsample_ratios,
  1047. label_smooth=label_smooth,
  1048. scale_x_y=scale_x_y,
  1049. iou_loss=ppdet.modeling.IouLoss(
  1050. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1051. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1052. if use_iou_aware else None)
  1053. yolo_head = ppdet.modeling.YOLOv3Head(
  1054. in_channels=[i.channels for i in neck.out_shape],
  1055. anchors=anchors,
  1056. anchor_masks=anchor_masks,
  1057. num_classes=num_classes,
  1058. loss=loss,
  1059. iou_aware=use_iou_aware)
  1060. if use_matrix_nms:
  1061. nms = ppdet.modeling.MatrixNMS(
  1062. keep_top_k=nms_keep_topk,
  1063. score_threshold=nms_score_threshold,
  1064. post_threshold=.05,
  1065. nms_top_k=nms_topk,
  1066. background_label=-1)
  1067. else:
  1068. nms = ppdet.modeling.MultiClassNMS(
  1069. score_threshold=nms_score_threshold,
  1070. nms_top_k=nms_topk,
  1071. keep_top_k=nms_keep_topk,
  1072. nms_threshold=nms_iou_threshold)
  1073. post_process = ppdet.modeling.BBoxPostProcess(
  1074. decode=ppdet.modeling.YOLOBox(
  1075. num_classes=num_classes,
  1076. conf_thresh=.005,
  1077. downsample_ratio=32,
  1078. clip_bbox=True,
  1079. scale_x_y=scale_x_y),
  1080. nms=nms)
  1081. params = {
  1082. 'backbone': backbone,
  1083. 'neck': neck,
  1084. 'yolo_head': yolo_head,
  1085. 'post_process': post_process
  1086. }
  1087. super(YOLOv3, self).__init__(
  1088. model_name='YOLOv3', num_classes=num_classes, **params)
  1089. self.anchors = anchors
  1090. self.anchor_masks = anchor_masks
  1091. self.downsample_ratios = downsample_ratios
  1092. self.model_name = 'PPYOLOTiny'
  1093. class PPYOLOv2(YOLOv3):
  1094. def __init__(self,
  1095. num_classes=80,
  1096. backbone='ResNet50_vd_dcn',
  1097. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1098. [59, 119], [116, 90], [156, 198], [373, 326]],
  1099. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1100. use_iou_aware=True,
  1101. use_spp=True,
  1102. use_drop_block=True,
  1103. scale_x_y=1.05,
  1104. ignore_threshold=0.7,
  1105. label_smooth=False,
  1106. use_iou_loss=True,
  1107. use_matrix_nms=True,
  1108. nms_score_threshold=0.01,
  1109. nms_topk=-1,
  1110. nms_keep_topk=100,
  1111. nms_iou_threshold=0.45):
  1112. self.init_params = locals()
  1113. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1114. raise ValueError(
  1115. "backbone: {} is not supported. Please choose one of "
  1116. "('ResNet50_vd_dcn', 'ResNet18_vd')".format(backbone))
  1117. self.backbone_name = backbone
  1118. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1119. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1120. norm_type = 'sync_bn'
  1121. else:
  1122. norm_type = 'bn'
  1123. if backbone == 'ResNet50_vd_dcn':
  1124. backbone = self._get_backbone(
  1125. 'ResNet',
  1126. variant='d',
  1127. norm_type=norm_type,
  1128. return_idx=[1, 2, 3],
  1129. dcn_v2_stages=[3],
  1130. freeze_at=-1,
  1131. freeze_norm=False,
  1132. norm_decay=0.)
  1133. downsample_ratios = [32, 16, 8]
  1134. elif backbone == 'ResNet101_vd_dcn':
  1135. backbone = self._get_backbone(
  1136. 'ResNet',
  1137. depth=101,
  1138. variant='d',
  1139. norm_type=norm_type,
  1140. return_idx=[1, 2, 3],
  1141. dcn_v2_stages=[3],
  1142. freeze_at=-1,
  1143. freeze_norm=False,
  1144. norm_decay=0.)
  1145. downsample_ratios = [32, 16, 8]
  1146. neck = ppdet.modeling.PPYOLOPAN(
  1147. norm_type=norm_type,
  1148. in_channels=[i.channels for i in backbone.out_shape],
  1149. drop_block=use_drop_block,
  1150. block_size=3,
  1151. keep_prob=.9,
  1152. spp=use_spp)
  1153. loss = ppdet.modeling.YOLOv3Loss(
  1154. num_classes=num_classes,
  1155. ignore_thresh=ignore_threshold,
  1156. downsample=downsample_ratios,
  1157. label_smooth=label_smooth,
  1158. scale_x_y=scale_x_y,
  1159. iou_loss=ppdet.modeling.IouLoss(
  1160. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1161. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1162. if use_iou_aware else None)
  1163. yolo_head = ppdet.modeling.YOLOv3Head(
  1164. in_channels=[i.channels for i in neck.out_shape],
  1165. anchors=anchors,
  1166. anchor_masks=anchor_masks,
  1167. num_classes=num_classes,
  1168. loss=loss,
  1169. iou_aware=use_iou_aware,
  1170. iou_aware_factor=.5)
  1171. if use_matrix_nms:
  1172. nms = ppdet.modeling.MatrixNMS(
  1173. keep_top_k=nms_keep_topk,
  1174. score_threshold=nms_score_threshold,
  1175. post_threshold=.01,
  1176. nms_top_k=nms_topk,
  1177. background_label=-1)
  1178. else:
  1179. nms = ppdet.modeling.MultiClassNMS(
  1180. score_threshold=nms_score_threshold,
  1181. nms_top_k=nms_topk,
  1182. keep_top_k=nms_keep_topk,
  1183. nms_threshold=nms_iou_threshold)
  1184. post_process = ppdet.modeling.BBoxPostProcess(
  1185. decode=ppdet.modeling.YOLOBox(
  1186. num_classes=num_classes,
  1187. conf_thresh=.01,
  1188. downsample_ratio=32,
  1189. clip_bbox=True,
  1190. scale_x_y=scale_x_y),
  1191. nms=nms)
  1192. params = {
  1193. 'backbone': backbone,
  1194. 'neck': neck,
  1195. 'yolo_head': yolo_head,
  1196. 'post_process': post_process
  1197. }
  1198. super(YOLOv3, self).__init__(
  1199. model_name='YOLOv3', num_classes=num_classes, **params)
  1200. self.anchors = anchors
  1201. self.anchor_masks = anchor_masks
  1202. self.downsample_ratios = downsample_ratios
  1203. self.model_name = 'PPYOLOv2'
  1204. class MaskRCNN(BaseDetector):
  1205. def __init__(self,
  1206. num_classes=80,
  1207. backbone='ResNet50_vd',
  1208. with_fpn=True,
  1209. aspect_ratios=[0.5, 1.0, 2.0],
  1210. anchor_sizes=[[32], [64], [128], [256], [512]],
  1211. keep_top_k=100,
  1212. nms_threshold=0.5,
  1213. score_threshold=0.05,
  1214. fpn_num_channels=256,
  1215. rpn_batch_size_per_im=256,
  1216. rpn_fg_fraction=0.5,
  1217. test_pre_nms_top_n=None,
  1218. test_post_nms_top_n=1000):
  1219. self.init_params = locals()
  1220. if backbone not in [
  1221. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1222. 'ResNet101_vd'
  1223. ]:
  1224. raise ValueError(
  1225. "backbone: {} is not supported. Please choose one of "
  1226. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1227. format(backbone))
  1228. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1229. if backbone == 'ResNet50':
  1230. if with_fpn:
  1231. backbone = self._get_backbone(
  1232. 'ResNet',
  1233. norm_type='bn',
  1234. freeze_at=0,
  1235. return_idx=[0, 1, 2, 3],
  1236. num_stages=4)
  1237. else:
  1238. backbone = self._get_backbone(
  1239. 'ResNet',
  1240. norm_type='bn',
  1241. freeze_at=0,
  1242. return_idx=[2],
  1243. num_stages=3)
  1244. elif 'ResNet50_vd' in backbone:
  1245. if not with_fpn:
  1246. logging.warning(
  1247. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1248. format(backbone))
  1249. with_fpn = True
  1250. backbone = self._get_backbone(
  1251. 'ResNet',
  1252. variant='d',
  1253. norm_type='bn',
  1254. freeze_at=0,
  1255. return_idx=[0, 1, 2, 3],
  1256. num_stages=4,
  1257. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1258. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
  1259. else:
  1260. if not with_fpn:
  1261. logging.warning(
  1262. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1263. format(backbone))
  1264. with_fpn = True
  1265. backbone = self._get_backbone(
  1266. 'ResNet',
  1267. variant='d' if '_vd' in backbone else 'b',
  1268. depth=101,
  1269. norm_type='bn',
  1270. freeze_at=0,
  1271. return_idx=[0, 1, 2, 3],
  1272. num_stages=4)
  1273. rpn_in_channel = backbone.out_shape[0].channels
  1274. if with_fpn:
  1275. neck = ppdet.modeling.FPN(
  1276. in_channels=[i.channels for i in backbone.out_shape],
  1277. out_channel=fpn_num_channels,
  1278. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  1279. rpn_in_channel = neck.out_shape[0].channels
  1280. anchor_generator_cfg = {
  1281. 'aspect_ratios': aspect_ratios,
  1282. 'anchor_sizes': anchor_sizes,
  1283. 'strides': [4, 8, 16, 32, 64]
  1284. }
  1285. train_proposal_cfg = {
  1286. 'min_size': 0.0,
  1287. 'nms_thresh': .7,
  1288. 'pre_nms_top_n': 2000,
  1289. 'post_nms_top_n': 1000,
  1290. 'topk_after_collect': True
  1291. }
  1292. test_proposal_cfg = {
  1293. 'min_size': 0.0,
  1294. 'nms_thresh': .7,
  1295. 'pre_nms_top_n': 1000
  1296. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1297. 'post_nms_top_n': test_post_nms_top_n
  1298. }
  1299. bb_head = ppdet.modeling.TwoFCHead(
  1300. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1301. bb_roi_extractor_cfg = {
  1302. 'resolution': 7,
  1303. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1304. 'sampling_ratio': 0,
  1305. 'aligned': True
  1306. }
  1307. with_pool = False
  1308. m_head = ppdet.modeling.MaskFeat(
  1309. in_channel=neck.out_shape[0].channels,
  1310. out_channel=256,
  1311. num_convs=4)
  1312. m_roi_extractor_cfg = {
  1313. 'resolution': 14,
  1314. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1315. 'sampling_ratio': 0,
  1316. 'aligned': True
  1317. }
  1318. mask_assigner = MaskAssigner(
  1319. num_classes=num_classes, mask_resolution=28)
  1320. share_bbox_feat = False
  1321. else:
  1322. neck = None
  1323. anchor_generator_cfg = {
  1324. 'aspect_ratios': aspect_ratios,
  1325. 'anchor_sizes': anchor_sizes,
  1326. 'strides': [16]
  1327. }
  1328. train_proposal_cfg = {
  1329. 'min_size': 0.0,
  1330. 'nms_thresh': .7,
  1331. 'pre_nms_top_n': 12000,
  1332. 'post_nms_top_n': 2000,
  1333. 'topk_after_collect': False
  1334. }
  1335. test_proposal_cfg = {
  1336. 'min_size': 0.0,
  1337. 'nms_thresh': .7,
  1338. 'pre_nms_top_n': 6000
  1339. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1340. 'post_nms_top_n': test_post_nms_top_n
  1341. }
  1342. bb_head = ppdet.modeling.Res5Head()
  1343. bb_roi_extractor_cfg = {
  1344. 'resolution': 14,
  1345. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1346. 'sampling_ratio': 0,
  1347. 'aligned': True
  1348. }
  1349. with_pool = True
  1350. m_head = ppdet.modeling.MaskFeat(
  1351. in_channel=bb_head.out_shape[0].channels,
  1352. out_channel=256,
  1353. num_convs=0)
  1354. m_roi_extractor_cfg = {
  1355. 'resolution': 14,
  1356. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1357. 'sampling_ratio': 0,
  1358. 'aligned': True
  1359. }
  1360. mask_assigner = MaskAssigner(
  1361. num_classes=num_classes, mask_resolution=14)
  1362. share_bbox_feat = True
  1363. rpn_target_assign_cfg = {
  1364. 'batch_size_per_im': rpn_batch_size_per_im,
  1365. 'fg_fraction': rpn_fg_fraction,
  1366. 'negative_overlap': .3,
  1367. 'positive_overlap': .7,
  1368. 'use_random': True
  1369. }
  1370. rpn_head = ppdet.modeling.RPNHead(
  1371. anchor_generator=anchor_generator_cfg,
  1372. rpn_target_assign=rpn_target_assign_cfg,
  1373. train_proposal=train_proposal_cfg,
  1374. test_proposal=test_proposal_cfg,
  1375. in_channel=rpn_in_channel)
  1376. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1377. bbox_head = ppdet.modeling.BBoxHead(
  1378. head=bb_head,
  1379. in_channel=bb_head.out_shape[0].channels,
  1380. roi_extractor=bb_roi_extractor_cfg,
  1381. with_pool=with_pool,
  1382. bbox_assigner=bbox_assigner,
  1383. num_classes=num_classes)
  1384. mask_head = ppdet.modeling.MaskHead(
  1385. head=m_head,
  1386. roi_extractor=m_roi_extractor_cfg,
  1387. mask_assigner=mask_assigner,
  1388. share_bbox_feat=share_bbox_feat,
  1389. num_classes=num_classes)
  1390. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1391. num_classes=num_classes,
  1392. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1393. nms=ppdet.modeling.MultiClassNMS(
  1394. score_threshold=score_threshold,
  1395. keep_top_k=keep_top_k,
  1396. nms_threshold=nms_threshold))
  1397. mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
  1398. params = {
  1399. 'backbone': backbone,
  1400. 'neck': neck,
  1401. 'rpn_head': rpn_head,
  1402. 'bbox_head': bbox_head,
  1403. 'mask_head': mask_head,
  1404. 'bbox_post_process': bbox_post_process,
  1405. 'mask_post_process': mask_post_process
  1406. }
  1407. self.with_fpn = with_fpn
  1408. super(MaskRCNN, self).__init__(
  1409. model_name='MaskRCNN', num_classes=num_classes, **params)
  1410. def _compose_batch_transform(self, transforms, mode='train'):
  1411. if mode == 'train':
  1412. default_batch_transforms = [
  1413. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1414. ]
  1415. collate_batch = False
  1416. else:
  1417. default_batch_transforms = [
  1418. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1419. ]
  1420. collate_batch = True
  1421. custom_batch_transforms = []
  1422. for i, op in enumerate(transforms.transforms):
  1423. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1424. if mode != 'train':
  1425. raise Exception(
  1426. "{} cannot be present in the {} transforms. ".format(
  1427. op.__class__.__name__, mode) +
  1428. "Please check the {} transforms.".format(mode))
  1429. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1430. batch_transforms = BatchCompose(
  1431. custom_batch_transforms + default_batch_transforms,
  1432. collate_batch=collate_batch)
  1433. return batch_transforms