detector.py 62 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import six
  20. import numpy as np
  21. import paddle
  22. from paddle.static import InputSpec
  23. import ppdet
  24. from ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  25. import paddlex
  26. import paddlex.utils.logging as logging
  27. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
  28. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  29. from paddlex.cv.transforms import arrange_transforms
  30. from .base import BaseModel
  31. from .utils.det_metrics import VOCMetric, COCOMetric
  32. from .utils.ema import ExponentialMovingAverage
  33. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  34. __all__ = [
  35. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. del self.init_params['params']
  41. super(BaseDetector, self).__init__('detector')
  42. if not hasattr(ppdet.modeling, model_name):
  43. raise Exception("ERROR: There's no model named {}.".format(
  44. model_name))
  45. self.model_name = model_name
  46. self.num_classes = num_classes
  47. self.labels = None
  48. self.net = self.build_net(**params)
  49. def build_net(self, **params):
  50. with paddle.utils.unique_name.guard():
  51. net = ppdet.modeling.__dict__[self.model_name](**params)
  52. return net
  53. def get_test_inputs(self, image_shape):
  54. input_spec = [{
  55. "image": InputSpec(
  56. shape=[None, 3] + image_shape, name='image', dtype='float32'),
  57. "im_shape": InputSpec(
  58. shape=[None, 2], name='im_shape', dtype='float32'),
  59. "scale_factor": InputSpec(
  60. shape=[None, 2], name='scale_factor', dtype='float32')
  61. }]
  62. return input_spec
  63. def _get_backbone(self, backbone_name, **params):
  64. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  65. return backbone
  66. def run(self, net, inputs, mode):
  67. net_out = net(inputs)
  68. if mode in ['train', 'eval']:
  69. outputs = net_out
  70. else:
  71. for key in ['im_shape', 'scale_factor']:
  72. net_out[key] = inputs[key]
  73. outputs = dict()
  74. for key in net_out:
  75. outputs[key] = net_out[key].numpy()
  76. return outputs
  77. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  78. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  79. num_steps_each_epoch):
  80. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  81. values = [(lr_decay_gamma**i) * learning_rate
  82. for i in range(len(lr_decay_epochs) + 1)]
  83. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  84. boundaries=boundaries, values=values)
  85. if warmup_steps > 0:
  86. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  87. logging.error(
  88. "In function train(), parameters should satisfy: "
  89. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  90. exit=False)
  91. logging.error(
  92. "See this doc for more information: "
  93. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  94. exit=False)
  95. scheduler = paddle.optimizer.lr.LinearWarmup(
  96. learning_rate=scheduler,
  97. warmup_steps=warmup_steps,
  98. start_lr=warmup_start_lr,
  99. end_lr=learning_rate)
  100. optimizer = paddle.optimizer.Momentum(
  101. scheduler,
  102. momentum=.9,
  103. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  104. parameters=parameters)
  105. return optimizer
  106. def train(self,
  107. num_epochs,
  108. train_dataset,
  109. train_batch_size=64,
  110. eval_dataset=None,
  111. optimizer=None,
  112. save_interval_epochs=1,
  113. log_interval_steps=10,
  114. save_dir='output',
  115. pretrain_weights='IMAGENET',
  116. learning_rate=.001,
  117. warmup_steps=0,
  118. warmup_start_lr=0.0,
  119. lr_decay_epochs=(216, 243),
  120. lr_decay_gamma=0.1,
  121. metric=None,
  122. use_ema=False,
  123. early_stop=False,
  124. early_stop_patience=5,
  125. use_vdl=True):
  126. """
  127. Train the model.
  128. Args:
  129. num_epochs(int): The number of epochs.
  130. train_dataset(paddlex.dataset): Training dataset.
  131. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  132. eval_dataset(paddlex.dataset, optional):
  133. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  134. optimizer(paddle.optimizer.Optimizer or None, optional):
  135. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  136. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  137. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  138. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  139. pretrain_weights(str or None, optional):
  140. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  141. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  142. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  143. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  144. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  145. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  146. metric({'VOC', 'COCO', None}, optional):
  147. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  148. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  149. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  150. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  151. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  152. """
  153. if train_dataset.__class__.__name__ == 'VOCDetection':
  154. train_dataset.data_fields = {
  155. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  156. 'difficult'
  157. }
  158. elif train_dataset.__class__.__name__ == 'CocoDetection':
  159. if self.__class__.__name__ == 'MaskRCNN':
  160. train_dataset.data_fields = {
  161. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  162. 'gt_poly', 'is_crowd'
  163. }
  164. else:
  165. train_dataset.data_fields = {
  166. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  167. 'is_crowd'
  168. }
  169. if metric is None:
  170. if eval_dataset.__class__.__name__ == 'VOCDetection':
  171. self.metric = 'voc'
  172. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  173. self.metric = 'coco'
  174. else:
  175. assert metric.lower() in ['coco', 'voc'], \
  176. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  177. self.metric = metric.lower()
  178. train_dataset.batch_transforms = self._compose_batch_transform(
  179. train_dataset.transforms, mode='train')
  180. self.labels = train_dataset.labels
  181. # build optimizer if not defined
  182. if optimizer is None:
  183. num_steps_each_epoch = len(train_dataset) // train_batch_size
  184. self.optimizer = self.default_optimizer(
  185. parameters=self.net.parameters(),
  186. learning_rate=learning_rate,
  187. warmup_steps=warmup_steps,
  188. warmup_start_lr=warmup_start_lr,
  189. lr_decay_epochs=lr_decay_epochs,
  190. lr_decay_gamma=lr_decay_gamma,
  191. num_steps_each_epoch=num_steps_each_epoch)
  192. else:
  193. self.optimizer = optimizer
  194. # initiate weights
  195. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  196. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  197. [self.model_name, self.backbone_name])]:
  198. logging.warning(
  199. "Path of pretrain_weights('{}') does not exist!".format(
  200. pretrain_weights))
  201. pretrain_weights = det_pretrain_weights_dict['_'.join(
  202. [self.model_name, self.backbone_name])][0]
  203. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  204. "If don't want to use pretrain weights, "
  205. "set pretrain_weights to be None.".format(
  206. pretrain_weights))
  207. pretrained_dir = osp.join(save_dir, 'pretrain')
  208. self.net_initialize(
  209. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  210. if use_ema:
  211. ema = ExponentialMovingAverage(
  212. decay=.9998, model=self.net, use_thres_step=True)
  213. else:
  214. ema = None
  215. # start train loop
  216. self.train_loop(
  217. num_epochs=num_epochs,
  218. train_dataset=train_dataset,
  219. train_batch_size=train_batch_size,
  220. eval_dataset=eval_dataset,
  221. save_interval_epochs=save_interval_epochs,
  222. log_interval_steps=log_interval_steps,
  223. save_dir=save_dir,
  224. ema=ema,
  225. early_stop=early_stop,
  226. early_stop_patience=early_stop_patience,
  227. use_vdl=use_vdl)
  228. def quant_aware_train(self,
  229. num_epochs,
  230. train_dataset,
  231. train_batch_size=64,
  232. eval_dataset=None,
  233. optimizer=None,
  234. save_interval_epochs=1,
  235. log_interval_steps=10,
  236. save_dir='output',
  237. learning_rate=.00001,
  238. warmup_steps=0,
  239. warmup_start_lr=0.0,
  240. lr_decay_epochs=(216, 243),
  241. lr_decay_gamma=0.1,
  242. metric=None,
  243. use_ema=False,
  244. early_stop=False,
  245. early_stop_patience=5,
  246. use_vdl=True,
  247. quant_config=None):
  248. """
  249. Quantization-aware training.
  250. Args:
  251. num_epochs(int): The number of epochs.
  252. train_dataset(paddlex.dataset): Training dataset.
  253. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  254. eval_dataset(paddlex.dataset, optional):
  255. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  256. optimizer(paddle.optimizer.Optimizer or None, optional):
  257. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  258. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  259. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  260. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  261. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  262. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  263. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  264. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  265. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  266. metric({'VOC', 'COCO', None}, optional):
  267. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  268. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  269. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  270. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  271. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  272. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  273. configuration will be used. Defaults to None.
  274. """
  275. self._prepare_qat(quant_config)
  276. self.train(
  277. num_epochs=num_epochs,
  278. train_dataset=train_dataset,
  279. train_batch_size=train_batch_size,
  280. eval_dataset=eval_dataset,
  281. optimizer=optimizer,
  282. save_interval_epochs=save_interval_epochs,
  283. log_interval_steps=log_interval_steps,
  284. save_dir=save_dir,
  285. pretrain_weights=None,
  286. learning_rate=learning_rate,
  287. warmup_steps=warmup_steps,
  288. warmup_start_lr=warmup_start_lr,
  289. lr_decay_epochs=lr_decay_epochs,
  290. lr_decay_gamma=lr_decay_gamma,
  291. metric=metric,
  292. use_ema=use_ema,
  293. early_stop=early_stop,
  294. early_stop_patience=early_stop_patience,
  295. use_vdl=use_vdl)
  296. def evaluate(self,
  297. eval_dataset,
  298. batch_size=1,
  299. metric=None,
  300. return_details=False):
  301. """
  302. Evaluate the model.
  303. Args:
  304. eval_dataset(paddlex.dataset): Evaluation dataset.
  305. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  306. metric({'VOC', 'COCO', None}, optional):
  307. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  308. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  309. Returns:
  310. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  311. """
  312. if eval_dataset.__class__.__name__ == 'VOCDetection':
  313. eval_dataset.data_fields = {
  314. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  315. 'difficult'
  316. }
  317. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  318. if self.__class__.__name__ == 'MaskRCNN':
  319. eval_dataset.data_fields = {
  320. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  321. 'gt_poly', 'is_crowd'
  322. }
  323. else:
  324. eval_dataset.data_fields = {
  325. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  326. 'is_crowd'
  327. }
  328. eval_dataset.batch_transforms = self._compose_batch_transform(
  329. eval_dataset.transforms, mode='eval')
  330. arrange_transforms(
  331. model_type=self.model_type,
  332. transforms=eval_dataset.transforms,
  333. mode='eval')
  334. self.net.eval()
  335. nranks = paddle.distributed.get_world_size()
  336. local_rank = paddle.distributed.get_rank()
  337. if nranks > 1:
  338. # Initialize parallel environment if not done.
  339. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  340. ):
  341. paddle.distributed.init_parallel_env()
  342. if batch_size > 1:
  343. logging.warning(
  344. "Detector only supports single card evaluation with batch_size=1 "
  345. "during evaluation, so batch_size is forcibly set to 1.")
  346. batch_size = 1
  347. if nranks < 2 or local_rank == 0:
  348. self.eval_data_loader = self.build_data_loader(
  349. eval_dataset, batch_size=batch_size, mode='eval')
  350. is_bbox_normalized = False
  351. if eval_dataset.batch_transforms is not None:
  352. is_bbox_normalized = any(
  353. isinstance(t, _NormalizeBox)
  354. for t in eval_dataset.batch_transforms.batch_transforms)
  355. if metric is None:
  356. if getattr(self, 'metric', None) is not None:
  357. if self.metric == 'voc':
  358. eval_metric = VOCMetric(
  359. labels=eval_dataset.labels,
  360. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  361. is_bbox_normalized=is_bbox_normalized,
  362. classwise=False)
  363. else:
  364. eval_metric = COCOMetric(
  365. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  366. classwise=False)
  367. else:
  368. if eval_dataset.__class__.__name__ == 'VOCDetection':
  369. eval_metric = VOCMetric(
  370. labels=eval_dataset.labels,
  371. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  372. is_bbox_normalized=is_bbox_normalized,
  373. classwise=False)
  374. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  375. eval_metric = COCOMetric(
  376. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  377. classwise=False)
  378. else:
  379. assert metric.lower() in ['coco', 'voc'], \
  380. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  381. if metric.lower() == 'coco':
  382. eval_metric = COCOMetric(
  383. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  384. classwise=False)
  385. else:
  386. eval_metric = VOCMetric(
  387. labels=eval_dataset.labels,
  388. is_bbox_normalized=is_bbox_normalized,
  389. classwise=False)
  390. scores = collections.OrderedDict()
  391. logging.info(
  392. "Start to evaluate(total_samples={}, total_steps={})...".
  393. format(eval_dataset.num_samples, eval_dataset.num_samples))
  394. with paddle.no_grad():
  395. for step, data in enumerate(self.eval_data_loader):
  396. outputs = self.run(self.net, data, 'eval')
  397. eval_metric.update(data, outputs)
  398. eval_metric.accumulate()
  399. self.eval_details = eval_metric.details
  400. scores.update(eval_metric.get())
  401. eval_metric.reset()
  402. if return_details:
  403. return scores, self.eval_details
  404. return scores
  405. def predict(self, img_file, transforms=None):
  406. """
  407. Do inference.
  408. Args:
  409. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  410. Image path or decoded image data in a BGR format, which also could constitute a list,
  411. meaning all images to be predicted as a mini-batch.
  412. transforms(paddlex.transforms.Compose or None, optional):
  413. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  414. Returns:
  415. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  416. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  417. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  418. category_id(int): the predicted category ID
  419. category(str): category name
  420. bbox(list): bounding box in [x, y, w, h] format
  421. score(str): confidence
  422. """
  423. if transforms is None and not hasattr(self, 'test_transforms'):
  424. raise Exception("transforms need to be defined, now is None.")
  425. if transforms is None:
  426. transforms = self.test_transforms
  427. if isinstance(img_file, (str, np.ndarray)):
  428. images = [img_file]
  429. else:
  430. images = img_file
  431. batch_samples = self._preprocess(images, transforms)
  432. self.net.eval()
  433. outputs = self.run(self.net, batch_samples, 'test')
  434. prediction = self._postprocess(outputs)
  435. if isinstance(img_file, (str, np.ndarray)):
  436. prediction = prediction[0]
  437. return prediction
  438. def _preprocess(self, images, transforms):
  439. arrange_transforms(
  440. model_type=self.model_type, transforms=transforms, mode='test')
  441. batch_samples = list()
  442. for im in images:
  443. sample = {'image': im}
  444. batch_samples.append(transforms(sample))
  445. batch_transforms = self._compose_batch_transform(transforms, 'test')
  446. batch_samples = batch_transforms(batch_samples)
  447. for k, v in batch_samples.items():
  448. batch_samples[k] = paddle.to_tensor(v)
  449. return batch_samples
  450. def _postprocess(self, batch_pred):
  451. infer_result = {}
  452. if 'bbox' in batch_pred:
  453. bboxes = batch_pred['bbox']
  454. bbox_nums = batch_pred['bbox_num']
  455. det_res = []
  456. k = 0
  457. for i in range(len(bbox_nums)):
  458. det_nums = bbox_nums[i]
  459. for j in range(det_nums):
  460. dt = bboxes[k]
  461. k = k + 1
  462. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  463. if int(num_id) < 0:
  464. continue
  465. category = self.labels[int(num_id)]
  466. w = xmax - xmin
  467. h = ymax - ymin
  468. bbox = [xmin, ymin, w, h]
  469. dt_res = {
  470. 'category_id': int(num_id),
  471. 'category': category,
  472. 'bbox': bbox,
  473. 'score': score
  474. }
  475. det_res.append(dt_res)
  476. infer_result['bbox'] = det_res
  477. if 'mask' in batch_pred:
  478. masks = batch_pred['mask']
  479. bboxes = batch_pred['bbox']
  480. mask_nums = batch_pred['bbox_num']
  481. seg_res = []
  482. k = 0
  483. for i in range(len(mask_nums)):
  484. det_nums = mask_nums[i]
  485. for j in range(det_nums):
  486. mask = masks[k].astype(np.uint8)
  487. score = float(bboxes[k][1])
  488. label = int(bboxes[k][0])
  489. k = k + 1
  490. if label == -1:
  491. continue
  492. category = self.labels[int(label)]
  493. import pycocotools.mask as mask_util
  494. rle = mask_util.encode(
  495. np.array(
  496. mask[:, :, None], order="F", dtype="uint8"))[0]
  497. if six.PY3:
  498. if 'counts' in rle:
  499. rle['counts'] = rle['counts'].decode("utf8")
  500. sg_res = {
  501. 'category': category,
  502. 'segmentation': rle,
  503. 'score': score
  504. }
  505. seg_res.append(sg_res)
  506. infer_result['mask'] = seg_res
  507. bbox_num = batch_pred['bbox_num']
  508. results = []
  509. start = 0
  510. for num in bbox_num:
  511. end = start + num
  512. curr_res = infer_result['bbox'][start:end]
  513. if 'mask' in infer_result:
  514. mask_res = infer_result['mask'][start:end]
  515. for box, mask in zip(curr_res, mask_res):
  516. box.update(mask)
  517. results.append(curr_res)
  518. start = end
  519. return results
  520. class YOLOv3(BaseDetector):
  521. def __init__(self,
  522. num_classes=80,
  523. backbone='MobileNetV1',
  524. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  525. [59, 119], [116, 90], [156, 198], [373, 326]],
  526. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  527. ignore_threshold=0.7,
  528. nms_score_threshold=0.01,
  529. nms_topk=1000,
  530. nms_keep_topk=100,
  531. nms_iou_threshold=0.45,
  532. label_smooth=False):
  533. self.init_params = locals()
  534. if backbone not in [
  535. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  536. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  537. ]:
  538. raise ValueError(
  539. "backbone: {} is not supported. Please choose one of "
  540. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34')".
  541. format(backbone))
  542. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  543. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  544. norm_type = 'sync_bn'
  545. else:
  546. norm_type = 'bn'
  547. self.backbone_name = backbone
  548. if 'MobileNetV1' in backbone:
  549. norm_type = 'bn'
  550. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  551. elif 'MobileNetV3' in backbone:
  552. backbone = self._get_backbone(
  553. 'MobileNetV3', norm_type=norm_type, feature_maps=[7, 13, 16])
  554. elif backbone == 'ResNet50_vd_dcn':
  555. backbone = self._get_backbone(
  556. 'ResNet',
  557. norm_type=norm_type,
  558. variant='d',
  559. return_idx=[1, 2, 3],
  560. dcn_v2_stages=[3],
  561. freeze_at=-1,
  562. freeze_norm=False)
  563. elif backbone == 'ResNet34':
  564. backbone = self._get_backbone(
  565. 'ResNet',
  566. depth=34,
  567. norm_type=norm_type,
  568. return_idx=[1, 2, 3],
  569. freeze_at=-1,
  570. freeze_norm=False,
  571. norm_decay=0.)
  572. else:
  573. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  574. neck = ppdet.modeling.YOLOv3FPN(
  575. norm_type=norm_type,
  576. in_channels=[i.channels for i in backbone.out_shape])
  577. loss = ppdet.modeling.YOLOv3Loss(
  578. num_classes=num_classes,
  579. ignore_thresh=ignore_threshold,
  580. label_smooth=label_smooth)
  581. yolo_head = ppdet.modeling.YOLOv3Head(
  582. in_channels=[i.channels for i in neck.out_shape],
  583. anchors=anchors,
  584. anchor_masks=anchor_masks,
  585. num_classes=num_classes,
  586. loss=loss)
  587. post_process = ppdet.modeling.BBoxPostProcess(
  588. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  589. nms=ppdet.modeling.MultiClassNMS(
  590. score_threshold=nms_score_threshold,
  591. nms_top_k=nms_topk,
  592. keep_top_k=nms_keep_topk,
  593. nms_threshold=nms_iou_threshold))
  594. params = {
  595. 'backbone': backbone,
  596. 'neck': neck,
  597. 'yolo_head': yolo_head,
  598. 'post_process': post_process
  599. }
  600. super(YOLOv3, self).__init__(
  601. model_name='YOLOv3', num_classes=num_classes, **params)
  602. self.anchors = anchors
  603. self.anchor_masks = anchor_masks
  604. def _compose_batch_transform(self, transforms, mode='train'):
  605. if mode == 'train':
  606. default_batch_transforms = [
  607. _BatchPadding(
  608. pad_to_stride=-1, pad_gt=False), _NormalizeBox(),
  609. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  610. _Gt2YoloTarget(
  611. anchor_masks=self.anchor_masks,
  612. anchors=self.anchors,
  613. downsample_ratios=getattr(self, 'downsample_ratios',
  614. [32, 16, 8]),
  615. num_classes=self.num_classes)
  616. ]
  617. else:
  618. default_batch_transforms = [
  619. _BatchPadding(
  620. pad_to_stride=-1, pad_gt=False)
  621. ]
  622. custom_batch_transforms = []
  623. for i, op in enumerate(transforms.transforms):
  624. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  625. if mode != 'train':
  626. raise Exception(
  627. "{} cannot be present in the {} transforms. ".format(
  628. op.__class__.__name__, mode) +
  629. "Please check the {} transforms.".format(mode))
  630. custom_batch_transforms.insert(0, copy.deepcopy(op))
  631. batch_transforms = BatchCompose(custom_batch_transforms +
  632. default_batch_transforms)
  633. return batch_transforms
  634. class FasterRCNN(BaseDetector):
  635. def __init__(self,
  636. num_classes=80,
  637. backbone='ResNet50',
  638. with_fpn=True,
  639. aspect_ratios=[0.5, 1.0, 2.0],
  640. anchor_sizes=[[32], [64], [128], [256], [512]],
  641. keep_top_k=100,
  642. nms_threshold=0.5,
  643. score_threshold=0.05,
  644. fpn_num_channels=256,
  645. rpn_batch_size_per_im=256,
  646. rpn_fg_fraction=0.5,
  647. test_pre_nms_top_n=None,
  648. test_post_nms_top_n=1000):
  649. self.init_params = locals()
  650. if backbone not in [
  651. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  652. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet'
  653. ]:
  654. raise ValueError(
  655. "backbone: {} is not supported. Please choose one of "
  656. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  657. "'ResNet101', 'ResNet101_vd', 'HRNet')".format(backbone))
  658. self.backbone_name = backbone
  659. if backbone == 'HRNet':
  660. if not with_fpn:
  661. logging.warning(
  662. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  663. format(backbone))
  664. with_fpn = True
  665. backbone = self._get_backbone(
  666. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  667. elif backbone == 'ResNet50_vd_ssld':
  668. if not with_fpn:
  669. logging.warning(
  670. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  671. format(backbone))
  672. with_fpn = True
  673. backbone = self._get_backbone(
  674. 'ResNet',
  675. variant='d',
  676. norm_type='bn',
  677. freeze_at=0,
  678. return_idx=[0, 1, 2, 3],
  679. num_stages=4,
  680. lr_mult_list=[0.05, 0.05, 0.1, 0.15])
  681. elif 'ResNet50' in backbone:
  682. if with_fpn:
  683. backbone = self._get_backbone(
  684. 'ResNet',
  685. variant='d' if '_vd' in backbone else 'b',
  686. norm_type='bn',
  687. freeze_at=0,
  688. return_idx=[0, 1, 2, 3],
  689. num_stages=4)
  690. else:
  691. backbone = self._get_backbone(
  692. 'ResNet',
  693. variant='d' if '_vd' in backbone else 'b',
  694. norm_type='bn',
  695. freeze_at=0,
  696. return_idx=[2],
  697. num_stages=3)
  698. elif 'ResNet34' in backbone:
  699. if not with_fpn:
  700. logging.warning(
  701. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  702. format(backbone))
  703. with_fpn = True
  704. backbone = self._get_backbone(
  705. 'ResNet',
  706. depth=34,
  707. variant='d' if 'vd' in backbone else 'b',
  708. norm_type='bn',
  709. freeze_at=0,
  710. return_idx=[0, 1, 2, 3],
  711. num_stages=4)
  712. else:
  713. if not with_fpn:
  714. logging.warning(
  715. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  716. format(backbone))
  717. with_fpn = True
  718. backbone = self._get_backbone(
  719. 'ResNet',
  720. depth=101,
  721. variant='d' if 'vd' in backbone else 'b',
  722. norm_type='bn',
  723. freeze_at=0,
  724. return_idx=[0, 1, 2, 3],
  725. num_stages=4)
  726. rpn_in_channel = backbone.out_shape[0].channels
  727. if with_fpn:
  728. self.backbone_name = self.backbone_name + '_fpn'
  729. if 'HRNet' in self.backbone_name:
  730. neck = ppdet.modeling.HRFPN(
  731. in_channels=[i.channels for i in backbone.out_shape],
  732. out_channel=fpn_num_channels,
  733. spatial_scales=[
  734. 1.0 / i.stride for i in backbone.out_shape
  735. ],
  736. share_conv=False)
  737. else:
  738. neck = ppdet.modeling.FPN(
  739. in_channels=[i.channels for i in backbone.out_shape],
  740. out_channel=fpn_num_channels,
  741. spatial_scales=[
  742. 1.0 / i.stride for i in backbone.out_shape
  743. ])
  744. rpn_in_channel = neck.out_shape[0].channels
  745. anchor_generator_cfg = {
  746. 'aspect_ratios': aspect_ratios,
  747. 'anchor_sizes': anchor_sizes,
  748. 'strides': [4, 8, 16, 32, 64]
  749. }
  750. train_proposal_cfg = {
  751. 'min_size': 0.0,
  752. 'nms_thresh': .7,
  753. 'pre_nms_top_n': 2000,
  754. 'post_nms_top_n': 1000,
  755. 'topk_after_collect': True
  756. }
  757. test_proposal_cfg = {
  758. 'min_size': 0.0,
  759. 'nms_thresh': .7,
  760. 'pre_nms_top_n': 1000
  761. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  762. 'post_nms_top_n': test_post_nms_top_n
  763. }
  764. head = ppdet.modeling.TwoFCHead(out_channel=1024)
  765. roi_extractor_cfg = {
  766. 'resolution': 7,
  767. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  768. 'sampling_ratio': 0,
  769. 'aligned': True
  770. }
  771. with_pool = False
  772. else:
  773. neck = None
  774. anchor_generator_cfg = {
  775. 'aspect_ratios': aspect_ratios,
  776. 'anchor_sizes': anchor_sizes,
  777. 'strides': [16]
  778. }
  779. train_proposal_cfg = {
  780. 'min_size': 0.0,
  781. 'nms_thresh': .7,
  782. 'pre_nms_top_n': 12000,
  783. 'post_nms_top_n': 2000,
  784. 'topk_after_collect': False
  785. }
  786. test_proposal_cfg = {
  787. 'min_size': 0.0,
  788. 'nms_thresh': .7,
  789. 'pre_nms_top_n': 6000
  790. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  791. 'post_nms_top_n': test_post_nms_top_n
  792. }
  793. head = ppdet.modeling.Res5Head()
  794. roi_extractor_cfg = {
  795. 'resolution': 14,
  796. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  797. 'sampling_ratio': 0,
  798. 'aligned': True
  799. }
  800. with_pool = True
  801. rpn_target_assign_cfg = {
  802. 'batch_size_per_im': rpn_batch_size_per_im,
  803. 'fg_fraction': rpn_fg_fraction,
  804. 'negative_overlap': .3,
  805. 'positive_overlap': .7,
  806. 'use_random': True
  807. }
  808. rpn_head = ppdet.modeling.RPNHead(
  809. anchor_generator=anchor_generator_cfg,
  810. rpn_target_assign=rpn_target_assign_cfg,
  811. train_proposal=train_proposal_cfg,
  812. test_proposal=test_proposal_cfg,
  813. in_channel=rpn_in_channel)
  814. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  815. bbox_head = ppdet.modeling.BBoxHead(
  816. head=head,
  817. in_channel=head.out_shape[0].channels,
  818. roi_extractor=roi_extractor_cfg,
  819. with_pool=with_pool,
  820. bbox_assigner=bbox_assigner,
  821. num_classes=num_classes)
  822. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  823. num_classes=num_classes,
  824. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  825. nms=ppdet.modeling.MultiClassNMS(
  826. score_threshold=score_threshold,
  827. keep_top_k=keep_top_k,
  828. nms_threshold=nms_threshold))
  829. params = {
  830. 'backbone': backbone,
  831. 'neck': neck,
  832. 'rpn_head': rpn_head,
  833. 'bbox_head': bbox_head,
  834. 'bbox_post_process': bbox_post_process
  835. }
  836. self.with_fpn = with_fpn
  837. super(FasterRCNN, self).__init__(
  838. model_name='FasterRCNN', num_classes=num_classes, **params)
  839. def _compose_batch_transform(self, transforms, mode='train'):
  840. if mode == 'train':
  841. default_batch_transforms = [
  842. _BatchPadding(
  843. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  844. ]
  845. else:
  846. default_batch_transforms = [
  847. _BatchPadding(
  848. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  849. ]
  850. custom_batch_transforms = []
  851. for i, op in enumerate(transforms.transforms):
  852. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  853. if mode != 'train':
  854. raise Exception(
  855. "{} cannot be present in the {} transforms. ".format(
  856. op.__class__.__name__, mode) +
  857. "Please check the {} transforms.".format(mode))
  858. custom_batch_transforms.insert(0, copy.deepcopy(op))
  859. batch_transforms = BatchCompose(custom_batch_transforms +
  860. default_batch_transforms)
  861. return batch_transforms
  862. class PPYOLO(YOLOv3):
  863. def __init__(self,
  864. num_classes=80,
  865. backbone='ResNet50_vd_dcn',
  866. anchors=None,
  867. anchor_masks=None,
  868. use_coord_conv=True,
  869. use_iou_aware=True,
  870. use_spp=True,
  871. use_drop_block=True,
  872. scale_x_y=1.05,
  873. ignore_threshold=0.7,
  874. label_smooth=False,
  875. use_iou_loss=True,
  876. use_matrix_nms=True,
  877. nms_score_threshold=0.01,
  878. nms_topk=-1,
  879. nms_keep_topk=100,
  880. nms_iou_threshold=0.45):
  881. self.init_params = locals()
  882. if backbone not in [
  883. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  884. 'MobileNetV3_small'
  885. ]:
  886. raise ValueError(
  887. "backbone: {} is not supported. Please choose one of "
  888. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  889. format(backbone))
  890. self.backbone_name = backbone
  891. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  892. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  893. norm_type = 'sync_bn'
  894. else:
  895. norm_type = 'bn'
  896. if anchors is None and anchor_masks is None:
  897. if 'MobileNetV3' in backbone:
  898. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  899. [120, 195], [254, 235]]
  900. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  901. elif backbone == 'ResNet50_vd_dcn':
  902. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  903. [59, 119], [116, 90], [156, 198], [373, 326]]
  904. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  905. else:
  906. anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169],
  907. [344, 319]]
  908. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  909. elif anchors is None or anchor_masks is None:
  910. raise ValueError("Please define both anchors and anchor_masks.")
  911. if backbone == 'ResNet50_vd_dcn':
  912. backbone = self._get_backbone(
  913. 'ResNet',
  914. variant='d',
  915. norm_type=norm_type,
  916. return_idx=[1, 2, 3],
  917. dcn_v2_stages=[3],
  918. freeze_at=-1,
  919. freeze_norm=False,
  920. norm_decay=0.)
  921. downsample_ratios = [32, 16, 8]
  922. elif backbone == 'ResNet18_vd':
  923. backbone = self._get_backbone(
  924. 'ResNet',
  925. depth=18,
  926. variant='d',
  927. norm_type=norm_type,
  928. return_idx=[2, 3],
  929. freeze_at=-1,
  930. freeze_norm=False,
  931. norm_decay=0.)
  932. downsample_ratios = [32, 16, 8]
  933. elif backbone == 'MobileNetV3_large':
  934. backbone = self._get_backbone(
  935. 'MobileNetV3',
  936. model_name='large',
  937. norm_type=norm_type,
  938. scale=1,
  939. with_extra_blocks=False,
  940. extra_block_filters=[],
  941. feature_maps=[13, 16])
  942. downsample_ratios = [32, 16]
  943. elif backbone == 'MobileNetV3_small':
  944. backbone = self._get_backbone(
  945. 'MobileNetV3',
  946. model_name='small',
  947. norm_type=norm_type,
  948. scale=1,
  949. with_extra_blocks=False,
  950. extra_block_filters=[],
  951. feature_maps=[9, 12])
  952. downsample_ratios = [32, 16]
  953. neck = ppdet.modeling.PPYOLOFPN(
  954. norm_type=norm_type,
  955. in_channels=[i.channels for i in backbone.out_shape],
  956. coord_conv=use_coord_conv,
  957. drop_block=use_drop_block,
  958. spp=use_spp,
  959. conv_block_num=0 if ('MobileNetV3' in self.backbone_name or
  960. self.backbone_name == 'ResNet18_vd') else 2)
  961. loss = ppdet.modeling.YOLOv3Loss(
  962. num_classes=num_classes,
  963. ignore_thresh=ignore_threshold,
  964. downsample=downsample_ratios,
  965. label_smooth=label_smooth,
  966. scale_x_y=scale_x_y,
  967. iou_loss=ppdet.modeling.IouLoss(
  968. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  969. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  970. if use_iou_aware else None)
  971. yolo_head = ppdet.modeling.YOLOv3Head(
  972. in_channels=[i.channels for i in neck.out_shape],
  973. anchors=anchors,
  974. anchor_masks=anchor_masks,
  975. num_classes=num_classes,
  976. loss=loss,
  977. iou_aware=use_iou_aware)
  978. if use_matrix_nms:
  979. nms = ppdet.modeling.MatrixNMS(
  980. keep_top_k=nms_keep_topk,
  981. score_threshold=nms_score_threshold,
  982. post_threshold=.05
  983. if 'MobileNetV3' in self.backbone_name else .01,
  984. nms_top_k=nms_topk,
  985. background_label=-1)
  986. else:
  987. nms = ppdet.modeling.MultiClassNMS(
  988. score_threshold=nms_score_threshold,
  989. nms_top_k=nms_topk,
  990. keep_top_k=nms_keep_topk,
  991. nms_threshold=nms_iou_threshold)
  992. post_process = ppdet.modeling.BBoxPostProcess(
  993. decode=ppdet.modeling.YOLOBox(
  994. num_classes=num_classes,
  995. conf_thresh=.005
  996. if 'MobileNetV3' in self.backbone_name else .01,
  997. scale_x_y=scale_x_y),
  998. nms=nms)
  999. params = {
  1000. 'backbone': backbone,
  1001. 'neck': neck,
  1002. 'yolo_head': yolo_head,
  1003. 'post_process': post_process
  1004. }
  1005. super(YOLOv3, self).__init__(
  1006. model_name='YOLOv3', num_classes=num_classes, **params)
  1007. self.anchors = anchors
  1008. self.anchor_masks = anchor_masks
  1009. self.downsample_ratios = downsample_ratios
  1010. self.model_name = 'PPYOLO'
  1011. class PPYOLOTiny(YOLOv3):
  1012. def __init__(self,
  1013. num_classes=80,
  1014. backbone='MobileNetV3',
  1015. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1016. [60, 170], [220, 125], [128, 222], [264, 266]],
  1017. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1018. use_iou_aware=False,
  1019. use_spp=True,
  1020. use_drop_block=True,
  1021. scale_x_y=1.05,
  1022. ignore_threshold=0.5,
  1023. label_smooth=False,
  1024. use_iou_loss=True,
  1025. use_matrix_nms=False,
  1026. nms_score_threshold=0.005,
  1027. nms_topk=1000,
  1028. nms_keep_topk=100,
  1029. nms_iou_threshold=0.45):
  1030. self.init_params = locals()
  1031. if backbone != 'MobileNetV3':
  1032. logging.warning(
  1033. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1034. "Backbone is forcibly set to MobileNetV3.")
  1035. self.backbone_name = 'MobileNetV3'
  1036. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1037. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1038. norm_type = 'sync_bn'
  1039. else:
  1040. norm_type = 'bn'
  1041. backbone = self._get_backbone(
  1042. 'MobileNetV3',
  1043. model_name='large',
  1044. norm_type=norm_type,
  1045. scale=.5,
  1046. with_extra_blocks=False,
  1047. extra_block_filters=[],
  1048. feature_maps=[7, 13, 16])
  1049. downsample_ratios = [32, 16, 8]
  1050. neck = ppdet.modeling.PPYOLOTinyFPN(
  1051. detection_block_channels=[160, 128, 96],
  1052. in_channels=[i.channels for i in backbone.out_shape],
  1053. spp=use_spp,
  1054. drop_block=use_drop_block)
  1055. loss = ppdet.modeling.YOLOv3Loss(
  1056. num_classes=num_classes,
  1057. ignore_thresh=ignore_threshold,
  1058. downsample=downsample_ratios,
  1059. label_smooth=label_smooth,
  1060. scale_x_y=scale_x_y,
  1061. iou_loss=ppdet.modeling.IouLoss(
  1062. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1063. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1064. if use_iou_aware else None)
  1065. yolo_head = ppdet.modeling.YOLOv3Head(
  1066. in_channels=[i.channels for i in neck.out_shape],
  1067. anchors=anchors,
  1068. anchor_masks=anchor_masks,
  1069. num_classes=num_classes,
  1070. loss=loss,
  1071. iou_aware=use_iou_aware)
  1072. if use_matrix_nms:
  1073. nms = ppdet.modeling.MatrixNMS(
  1074. keep_top_k=nms_keep_topk,
  1075. score_threshold=nms_score_threshold,
  1076. post_threshold=.05,
  1077. nms_top_k=nms_topk,
  1078. background_label=-1)
  1079. else:
  1080. nms = ppdet.modeling.MultiClassNMS(
  1081. score_threshold=nms_score_threshold,
  1082. nms_top_k=nms_topk,
  1083. keep_top_k=nms_keep_topk,
  1084. nms_threshold=nms_iou_threshold)
  1085. post_process = ppdet.modeling.BBoxPostProcess(
  1086. decode=ppdet.modeling.YOLOBox(
  1087. num_classes=num_classes,
  1088. conf_thresh=.005,
  1089. downsample_ratio=32,
  1090. clip_bbox=True,
  1091. scale_x_y=scale_x_y),
  1092. nms=nms)
  1093. params = {
  1094. 'backbone': backbone,
  1095. 'neck': neck,
  1096. 'yolo_head': yolo_head,
  1097. 'post_process': post_process
  1098. }
  1099. super(YOLOv3, self).__init__(
  1100. model_name='YOLOv3', num_classes=num_classes, **params)
  1101. self.anchors = anchors
  1102. self.anchor_masks = anchor_masks
  1103. self.downsample_ratios = downsample_ratios
  1104. self.num_max_boxes = 100
  1105. self.model_name = 'PPYOLOTiny'
  1106. class PPYOLOv2(YOLOv3):
  1107. def __init__(self,
  1108. num_classes=80,
  1109. backbone='ResNet50_vd_dcn',
  1110. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1111. [59, 119], [116, 90], [156, 198], [373, 326]],
  1112. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1113. use_iou_aware=True,
  1114. use_spp=True,
  1115. use_drop_block=True,
  1116. scale_x_y=1.05,
  1117. ignore_threshold=0.7,
  1118. label_smooth=False,
  1119. use_iou_loss=True,
  1120. use_matrix_nms=True,
  1121. nms_score_threshold=0.01,
  1122. nms_topk=-1,
  1123. nms_keep_topk=100,
  1124. nms_iou_threshold=0.45):
  1125. self.init_params = locals()
  1126. if backbone not in ['ResNet50_vd_dcn', 'ResNet101_vd_dcn']:
  1127. raise ValueError(
  1128. "backbone: {} is not supported. Please choose one of "
  1129. "('ResNet50_vd_dcn', 'ResNet18_vd')".format(backbone))
  1130. self.backbone_name = backbone
  1131. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1132. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1133. norm_type = 'sync_bn'
  1134. else:
  1135. norm_type = 'bn'
  1136. if backbone == 'ResNet50_vd_dcn':
  1137. backbone = self._get_backbone(
  1138. 'ResNet',
  1139. variant='d',
  1140. norm_type=norm_type,
  1141. return_idx=[1, 2, 3],
  1142. dcn_v2_stages=[3],
  1143. freeze_at=-1,
  1144. freeze_norm=False,
  1145. norm_decay=0.)
  1146. downsample_ratios = [32, 16, 8]
  1147. elif backbone == 'ResNet101_vd_dcn':
  1148. backbone = self._get_backbone(
  1149. 'ResNet',
  1150. depth=101,
  1151. variant='d',
  1152. norm_type=norm_type,
  1153. return_idx=[1, 2, 3],
  1154. dcn_v2_stages=[3],
  1155. freeze_at=-1,
  1156. freeze_norm=False,
  1157. norm_decay=0.)
  1158. downsample_ratios = [32, 16, 8]
  1159. neck = ppdet.modeling.PPYOLOPAN(
  1160. norm_type=norm_type,
  1161. in_channels=[i.channels for i in backbone.out_shape],
  1162. drop_block=use_drop_block,
  1163. block_size=3,
  1164. keep_prob=.9,
  1165. spp=use_spp)
  1166. loss = ppdet.modeling.YOLOv3Loss(
  1167. num_classes=num_classes,
  1168. ignore_thresh=ignore_threshold,
  1169. downsample=downsample_ratios,
  1170. label_smooth=label_smooth,
  1171. scale_x_y=scale_x_y,
  1172. iou_loss=ppdet.modeling.IouLoss(
  1173. loss_weight=2.5, loss_square=True) if use_iou_loss else None,
  1174. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1175. if use_iou_aware else None)
  1176. yolo_head = ppdet.modeling.YOLOv3Head(
  1177. in_channels=[i.channels for i in neck.out_shape],
  1178. anchors=anchors,
  1179. anchor_masks=anchor_masks,
  1180. num_classes=num_classes,
  1181. loss=loss,
  1182. iou_aware=use_iou_aware,
  1183. iou_aware_factor=.5)
  1184. if use_matrix_nms:
  1185. nms = ppdet.modeling.MatrixNMS(
  1186. keep_top_k=nms_keep_topk,
  1187. score_threshold=nms_score_threshold,
  1188. post_threshold=.01,
  1189. nms_top_k=nms_topk,
  1190. background_label=-1)
  1191. else:
  1192. nms = ppdet.modeling.MultiClassNMS(
  1193. score_threshold=nms_score_threshold,
  1194. nms_top_k=nms_topk,
  1195. keep_top_k=nms_keep_topk,
  1196. nms_threshold=nms_iou_threshold)
  1197. post_process = ppdet.modeling.BBoxPostProcess(
  1198. decode=ppdet.modeling.YOLOBox(
  1199. num_classes=num_classes,
  1200. conf_thresh=.01,
  1201. downsample_ratio=32,
  1202. clip_bbox=True,
  1203. scale_x_y=scale_x_y),
  1204. nms=nms)
  1205. params = {
  1206. 'backbone': backbone,
  1207. 'neck': neck,
  1208. 'yolo_head': yolo_head,
  1209. 'post_process': post_process
  1210. }
  1211. super(YOLOv3, self).__init__(
  1212. model_name='YOLOv3', num_classes=num_classes, **params)
  1213. self.anchors = anchors
  1214. self.anchor_masks = anchor_masks
  1215. self.downsample_ratios = downsample_ratios
  1216. self.num_max_boxes = 100
  1217. self.model_name = 'PPYOLOv2'
  1218. class MaskRCNN(BaseDetector):
  1219. def __init__(self,
  1220. num_classes=80,
  1221. backbone='ResNet50_vd',
  1222. with_fpn=True,
  1223. aspect_ratios=[0.5, 1.0, 2.0],
  1224. anchor_sizes=[[32], [64], [128], [256], [512]],
  1225. keep_top_k=100,
  1226. nms_threshold=0.5,
  1227. score_threshold=0.05,
  1228. fpn_num_channels=256,
  1229. rpn_batch_size_per_im=256,
  1230. rpn_fg_fraction=0.5,
  1231. test_pre_nms_top_n=None,
  1232. test_post_nms_top_n=1000):
  1233. self.init_params = locals()
  1234. if backbone not in [
  1235. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1236. 'ResNet101_vd'
  1237. ]:
  1238. raise ValueError(
  1239. "backbone: {} is not supported. Please choose one of "
  1240. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1241. format(backbone))
  1242. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1243. if backbone == 'ResNet50':
  1244. if with_fpn:
  1245. backbone = self._get_backbone(
  1246. 'ResNet',
  1247. norm_type='bn',
  1248. freeze_at=0,
  1249. return_idx=[0, 1, 2, 3],
  1250. num_stages=4)
  1251. else:
  1252. backbone = self._get_backbone(
  1253. 'ResNet',
  1254. norm_type='bn',
  1255. freeze_at=0,
  1256. return_idx=[2],
  1257. num_stages=3)
  1258. elif 'ResNet50_vd' in backbone:
  1259. if not with_fpn:
  1260. logging.warning(
  1261. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1262. format(backbone))
  1263. with_fpn = True
  1264. backbone = self._get_backbone(
  1265. 'ResNet',
  1266. variant='d',
  1267. norm_type='bn',
  1268. freeze_at=0,
  1269. return_idx=[0, 1, 2, 3],
  1270. num_stages=4,
  1271. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1272. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
  1273. else:
  1274. if not with_fpn:
  1275. logging.warning(
  1276. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1277. format(backbone))
  1278. with_fpn = True
  1279. backbone = self._get_backbone(
  1280. 'ResNet',
  1281. variant='d' if '_vd' in backbone else 'b',
  1282. depth=101,
  1283. norm_type='bn',
  1284. freeze_at=0,
  1285. return_idx=[0, 1, 2, 3],
  1286. num_stages=4)
  1287. rpn_in_channel = backbone.out_shape[0].channels
  1288. if with_fpn:
  1289. neck = ppdet.modeling.FPN(
  1290. in_channels=[i.channels for i in backbone.out_shape],
  1291. out_channel=fpn_num_channels,
  1292. spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
  1293. rpn_in_channel = neck.out_shape[0].channels
  1294. anchor_generator_cfg = {
  1295. 'aspect_ratios': aspect_ratios,
  1296. 'anchor_sizes': anchor_sizes,
  1297. 'strides': [4, 8, 16, 32, 64]
  1298. }
  1299. train_proposal_cfg = {
  1300. 'min_size': 0.0,
  1301. 'nms_thresh': .7,
  1302. 'pre_nms_top_n': 2000,
  1303. 'post_nms_top_n': 1000,
  1304. 'topk_after_collect': True
  1305. }
  1306. test_proposal_cfg = {
  1307. 'min_size': 0.0,
  1308. 'nms_thresh': .7,
  1309. 'pre_nms_top_n': 1000
  1310. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1311. 'post_nms_top_n': test_post_nms_top_n
  1312. }
  1313. bb_head = ppdet.modeling.TwoFCHead(
  1314. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1315. bb_roi_extractor_cfg = {
  1316. 'resolution': 7,
  1317. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1318. 'sampling_ratio': 0,
  1319. 'aligned': True
  1320. }
  1321. with_pool = False
  1322. m_head = ppdet.modeling.MaskFeat(
  1323. in_channel=neck.out_shape[0].channels,
  1324. out_channel=256,
  1325. num_convs=4)
  1326. m_roi_extractor_cfg = {
  1327. 'resolution': 14,
  1328. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1329. 'sampling_ratio': 0,
  1330. 'aligned': True
  1331. }
  1332. mask_assigner = MaskAssigner(
  1333. num_classes=num_classes, mask_resolution=28)
  1334. share_bbox_feat = False
  1335. else:
  1336. neck = None
  1337. anchor_generator_cfg = {
  1338. 'aspect_ratios': aspect_ratios,
  1339. 'anchor_sizes': anchor_sizes,
  1340. 'strides': [16]
  1341. }
  1342. train_proposal_cfg = {
  1343. 'min_size': 0.0,
  1344. 'nms_thresh': .7,
  1345. 'pre_nms_top_n': 12000,
  1346. 'post_nms_top_n': 2000,
  1347. 'topk_after_collect': False
  1348. }
  1349. test_proposal_cfg = {
  1350. 'min_size': 0.0,
  1351. 'nms_thresh': .7,
  1352. 'pre_nms_top_n': 6000
  1353. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1354. 'post_nms_top_n': test_post_nms_top_n
  1355. }
  1356. bb_head = ppdet.modeling.Res5Head()
  1357. bb_roi_extractor_cfg = {
  1358. 'resolution': 14,
  1359. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1360. 'sampling_ratio': 0,
  1361. 'aligned': True
  1362. }
  1363. with_pool = True
  1364. m_head = ppdet.modeling.MaskFeat(
  1365. in_channel=bb_head.out_shape[0].channels,
  1366. out_channel=256,
  1367. num_convs=0)
  1368. m_roi_extractor_cfg = {
  1369. 'resolution': 14,
  1370. 'spatial_scale': [1. / i.stride for i in backbone.out_shape],
  1371. 'sampling_ratio': 0,
  1372. 'aligned': True
  1373. }
  1374. mask_assigner = MaskAssigner(
  1375. num_classes=num_classes, mask_resolution=14)
  1376. share_bbox_feat = True
  1377. rpn_target_assign_cfg = {
  1378. 'batch_size_per_im': rpn_batch_size_per_im,
  1379. 'fg_fraction': rpn_fg_fraction,
  1380. 'negative_overlap': .3,
  1381. 'positive_overlap': .7,
  1382. 'use_random': True
  1383. }
  1384. rpn_head = ppdet.modeling.RPNHead(
  1385. anchor_generator=anchor_generator_cfg,
  1386. rpn_target_assign=rpn_target_assign_cfg,
  1387. train_proposal=train_proposal_cfg,
  1388. test_proposal=test_proposal_cfg,
  1389. in_channel=rpn_in_channel)
  1390. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1391. bbox_head = ppdet.modeling.BBoxHead(
  1392. head=bb_head,
  1393. in_channel=bb_head.out_shape[0].channels,
  1394. roi_extractor=bb_roi_extractor_cfg,
  1395. with_pool=with_pool,
  1396. bbox_assigner=bbox_assigner,
  1397. num_classes=num_classes)
  1398. mask_head = ppdet.modeling.MaskHead(
  1399. head=m_head,
  1400. roi_extractor=m_roi_extractor_cfg,
  1401. mask_assigner=mask_assigner,
  1402. share_bbox_feat=share_bbox_feat,
  1403. num_classes=num_classes)
  1404. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1405. num_classes=num_classes,
  1406. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1407. nms=ppdet.modeling.MultiClassNMS(
  1408. score_threshold=score_threshold,
  1409. keep_top_k=keep_top_k,
  1410. nms_threshold=nms_threshold))
  1411. mask_post_process = ppdet.modeling.MaskPostProcess(binary_thresh=.5)
  1412. params = {
  1413. 'backbone': backbone,
  1414. 'neck': neck,
  1415. 'rpn_head': rpn_head,
  1416. 'bbox_head': bbox_head,
  1417. 'mask_head': mask_head,
  1418. 'bbox_post_process': bbox_post_process,
  1419. 'mask_post_process': mask_post_process
  1420. }
  1421. self.with_fpn = with_fpn
  1422. super(MaskRCNN, self).__init__(
  1423. model_name='MaskRCNN', num_classes=num_classes, **params)
  1424. def _compose_batch_transform(self, transforms, mode='train'):
  1425. if mode == 'train':
  1426. default_batch_transforms = [
  1427. _BatchPadding(
  1428. pad_to_stride=32 if self.with_fpn else -1, pad_gt=True)
  1429. ]
  1430. else:
  1431. default_batch_transforms = [
  1432. _BatchPadding(
  1433. pad_to_stride=32 if self.with_fpn else -1, pad_gt=False)
  1434. ]
  1435. custom_batch_transforms = []
  1436. for i, op in enumerate(transforms.transforms):
  1437. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1438. if mode != 'train':
  1439. raise Exception(
  1440. "{} cannot be present in the {} transforms. ".format(
  1441. op.__class__.__name__, mode) +
  1442. "Please check the {} transforms.".format(mode))
  1443. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1444. batch_transforms = BatchCompose(custom_batch_transforms +
  1445. default_batch_transforms)
  1446. return batch_transforms