detector.py 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import collections
  16. import copy
  17. import os
  18. import os.path as osp
  19. import numpy as np
  20. import paddle
  21. from paddle.static import InputSpec
  22. import paddlex.ppdet as ppdet
  23. from paddlex.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
  24. import paddlex
  25. import paddlex.utils.logging as logging
  26. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
  27. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
  28. from paddlex.cv.transforms import arrange_transforms
  29. from .base import BaseModel
  30. from .utils.det_metrics import VOCMetric, COCOMetric
  31. from .utils.ema import ExponentialMovingAverage
  32. from paddlex.utils.checkpoint import det_pretrain_weights_dict
  33. __all__ = [
  34. "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN",
  35. "PicoDet"
  36. ]
  37. class BaseDetector(BaseModel):
  38. def __init__(self, model_name, num_classes=80, **params):
  39. self.init_params.update(locals())
  40. if 'with_net' in self.init_params:
  41. del self.init_params['with_net']
  42. super(BaseDetector, self).__init__('detector')
  43. if not hasattr(ppdet.modeling, model_name):
  44. raise Exception("ERROR: There's no model named {}.".format(
  45. model_name))
  46. self.model_name = model_name
  47. self.num_classes = num_classes
  48. self.labels = None
  49. if params.get('with_net', True):
  50. params.pop('with_net', None)
  51. self.net = self.build_net(**params)
  52. def build_net(self, **params):
  53. with paddle.utils.unique_name.guard():
  54. net = ppdet.modeling.__dict__[self.model_name](**params)
  55. return net
  56. def _fix_transforms_shape(self, image_shape):
  57. raise NotImplementedError("_fix_transforms_shape: not implemented!")
  58. def _define_input_spec(self, image_shape):
  59. input_spec = [{
  60. "image": InputSpec(
  61. shape=image_shape, name='image', dtype='float32'),
  62. "im_shape": InputSpec(
  63. shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
  64. "scale_factor": InputSpec(
  65. shape=[image_shape[0], 2],
  66. name='scale_factor',
  67. dtype='float32')
  68. }]
  69. return input_spec
  70. def _check_image_shape(self, image_shape):
  71. if len(image_shape) == 2:
  72. image_shape = [1, 3] + image_shape
  73. if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
  74. raise Exception(
  75. "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
  76. format(image_shape[-2:]))
  77. return image_shape
  78. def _get_test_inputs(self, image_shape):
  79. if image_shape is not None:
  80. image_shape = self._check_image_shape(image_shape)
  81. self._fix_transforms_shape(image_shape[-2:])
  82. else:
  83. image_shape = [None, 3, -1, -1]
  84. self.fixed_input_shape = image_shape
  85. return self._define_input_spec(image_shape)
  86. def _get_backbone(self, backbone_name, **params):
  87. backbone = getattr(ppdet.modeling, backbone_name)(**params)
  88. return backbone
  89. def run(self, net, inputs, mode):
  90. net_out = net(inputs)
  91. if mode in ['train', 'eval']:
  92. outputs = net_out
  93. else:
  94. outputs = dict()
  95. for key in net_out:
  96. outputs[key] = net_out[key].numpy()
  97. return outputs
  98. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  99. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  100. num_steps_each_epoch):
  101. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  102. values = [(lr_decay_gamma**i) * learning_rate
  103. for i in range(len(lr_decay_epochs) + 1)]
  104. scheduler = paddle.optimizer.lr.PiecewiseDecay(
  105. boundaries=boundaries, values=values)
  106. if warmup_steps > 0:
  107. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  108. logging.error(
  109. "In function train(), parameters should satisfy: "
  110. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  111. exit=False)
  112. logging.error(
  113. "See this doc for more information: "
  114. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
  115. exit=False)
  116. scheduler = paddle.optimizer.lr.LinearWarmup(
  117. learning_rate=scheduler,
  118. warmup_steps=warmup_steps,
  119. start_lr=warmup_start_lr,
  120. end_lr=learning_rate)
  121. optimizer = paddle.optimizer.Momentum(
  122. scheduler,
  123. momentum=.9,
  124. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  125. parameters=parameters)
  126. return optimizer
  127. def train(self,
  128. num_epochs,
  129. train_dataset,
  130. train_batch_size=64,
  131. eval_dataset=None,
  132. optimizer=None,
  133. save_interval_epochs=1,
  134. log_interval_steps=10,
  135. save_dir='output',
  136. pretrain_weights='IMAGENET',
  137. learning_rate=.001,
  138. warmup_steps=0,
  139. warmup_start_lr=0.0,
  140. lr_decay_epochs=(216, 243),
  141. lr_decay_gamma=0.1,
  142. metric=None,
  143. use_ema=False,
  144. early_stop=False,
  145. early_stop_patience=5,
  146. use_vdl=True,
  147. resume_checkpoint=None):
  148. """
  149. Train the model.
  150. Args:
  151. num_epochs(int): The number of epochs.
  152. train_dataset(paddlex.dataset): Training dataset.
  153. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  154. eval_dataset(paddlex.dataset, optional):
  155. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  156. optimizer(paddle.optimizer.Optimizer or None, optional):
  157. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  158. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  159. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  160. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  161. pretrain_weights(str or None, optional):
  162. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  163. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  164. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  165. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  166. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  167. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  168. metric({'VOC', 'COCO', None}, optional):
  169. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  170. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  171. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  172. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  173. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  174. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  175. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  176. `pretrain_weights` can be set simultaneously. Defaults to None.
  177. """
  178. if self.status == 'Infer':
  179. logging.error(
  180. "Exported inference model does not support training.",
  181. exit=True)
  182. if pretrain_weights is not None and resume_checkpoint is not None:
  183. logging.error(
  184. "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
  185. exit=True)
  186. if train_dataset.__class__.__name__ == 'VOCDetection':
  187. train_dataset.data_fields = {
  188. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  189. 'difficult'
  190. }
  191. elif train_dataset.__class__.__name__ == 'CocoDetection':
  192. if self.__class__.__name__ == 'MaskRCNN':
  193. train_dataset.data_fields = {
  194. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  195. 'gt_poly', 'is_crowd'
  196. }
  197. else:
  198. train_dataset.data_fields = {
  199. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  200. 'is_crowd'
  201. }
  202. if metric is None:
  203. if eval_dataset.__class__.__name__ == 'VOCDetection':
  204. self.metric = 'voc'
  205. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  206. self.metric = 'coco'
  207. else:
  208. assert metric.lower() in ['coco', 'voc'], \
  209. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  210. self.metric = metric.lower()
  211. self.labels = train_dataset.labels
  212. self.num_max_boxes = train_dataset.num_max_boxes
  213. train_dataset.batch_transforms = self._compose_batch_transform(
  214. train_dataset.transforms, mode='train')
  215. # build optimizer if not defined
  216. if optimizer is None:
  217. num_steps_each_epoch = len(train_dataset) // train_batch_size
  218. self.optimizer = self.default_optimizer(
  219. parameters=self.net.parameters(),
  220. learning_rate=learning_rate,
  221. warmup_steps=warmup_steps,
  222. warmup_start_lr=warmup_start_lr,
  223. lr_decay_epochs=lr_decay_epochs,
  224. lr_decay_gamma=lr_decay_gamma,
  225. num_steps_each_epoch=num_steps_each_epoch)
  226. else:
  227. self.optimizer = optimizer
  228. # initiate weights
  229. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  230. if pretrain_weights not in det_pretrain_weights_dict['_'.join(
  231. [self.model_name, self.backbone_name])]:
  232. logging.warning(
  233. "Path of pretrain_weights('{}') does not exist!".format(
  234. pretrain_weights))
  235. pretrain_weights = det_pretrain_weights_dict['_'.join(
  236. [self.model_name, self.backbone_name])][0]
  237. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  238. "If you don't want to use pretrain weights, "
  239. "set pretrain_weights to be None.".format(
  240. pretrain_weights))
  241. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  242. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  243. logging.error(
  244. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  245. exit=True)
  246. pretrained_dir = osp.join(save_dir, 'pretrain')
  247. self.net_initialize(
  248. pretrain_weights=pretrain_weights,
  249. save_dir=pretrained_dir,
  250. resume_checkpoint=resume_checkpoint)
  251. if use_ema:
  252. ema = ExponentialMovingAverage(
  253. decay=.9998, model=self.net, use_thres_step=True)
  254. else:
  255. ema = None
  256. # start train loop
  257. self.train_loop(
  258. num_epochs=num_epochs,
  259. train_dataset=train_dataset,
  260. train_batch_size=train_batch_size,
  261. eval_dataset=eval_dataset,
  262. save_interval_epochs=save_interval_epochs,
  263. log_interval_steps=log_interval_steps,
  264. save_dir=save_dir,
  265. ema=ema,
  266. early_stop=early_stop,
  267. early_stop_patience=early_stop_patience,
  268. use_vdl=use_vdl)
  269. def quant_aware_train(self,
  270. num_epochs,
  271. train_dataset,
  272. train_batch_size=64,
  273. eval_dataset=None,
  274. optimizer=None,
  275. save_interval_epochs=1,
  276. log_interval_steps=10,
  277. save_dir='output',
  278. learning_rate=.00001,
  279. warmup_steps=0,
  280. warmup_start_lr=0.0,
  281. lr_decay_epochs=(216, 243),
  282. lr_decay_gamma=0.1,
  283. metric=None,
  284. use_ema=False,
  285. early_stop=False,
  286. early_stop_patience=5,
  287. use_vdl=True,
  288. resume_checkpoint=None,
  289. quant_config=None):
  290. """
  291. Quantization-aware training.
  292. Args:
  293. num_epochs(int): The number of epochs.
  294. train_dataset(paddlex.dataset): Training dataset.
  295. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  296. eval_dataset(paddlex.dataset, optional):
  297. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  298. optimizer(paddle.optimizer.Optimizer or None, optional):
  299. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  300. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  301. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  302. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  303. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  304. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  305. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  306. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  307. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  308. metric({'VOC', 'COCO', None}, optional):
  309. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  310. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  311. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  312. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  313. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  314. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  315. configuration will be used. Defaults to None.
  316. resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
  317. from. If None, no training checkpoint will be resumed. Defaults to None.
  318. """
  319. self._prepare_qat(quant_config)
  320. self.train(
  321. num_epochs=num_epochs,
  322. train_dataset=train_dataset,
  323. train_batch_size=train_batch_size,
  324. eval_dataset=eval_dataset,
  325. optimizer=optimizer,
  326. save_interval_epochs=save_interval_epochs,
  327. log_interval_steps=log_interval_steps,
  328. save_dir=save_dir,
  329. pretrain_weights=None,
  330. learning_rate=learning_rate,
  331. warmup_steps=warmup_steps,
  332. warmup_start_lr=warmup_start_lr,
  333. lr_decay_epochs=lr_decay_epochs,
  334. lr_decay_gamma=lr_decay_gamma,
  335. metric=metric,
  336. use_ema=use_ema,
  337. early_stop=early_stop,
  338. early_stop_patience=early_stop_patience,
  339. use_vdl=use_vdl,
  340. resume_checkpoint=resume_checkpoint)
  341. def evaluate(self,
  342. eval_dataset,
  343. batch_size=1,
  344. metric=None,
  345. return_details=False):
  346. """
  347. Evaluate the model.
  348. Args:
  349. eval_dataset(paddlex.dataset): Evaluation dataset.
  350. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  351. metric({'VOC', 'COCO', None}, optional):
  352. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  353. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  354. Returns:
  355. collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)":`mean average precision`}.
  356. """
  357. if metric is None:
  358. if not hasattr(self, 'metric'):
  359. if eval_dataset.__class__.__name__ == 'VOCDetection':
  360. self.metric = 'voc'
  361. elif eval_dataset.__class__.__name__ == 'CocoDetection':
  362. self.metric = 'coco'
  363. else:
  364. assert metric.lower() in ['coco', 'voc'], \
  365. "Evaluation metric {} is not supported, please choose form 'COCO' and 'VOC'"
  366. self.metric = metric.lower()
  367. if self.metric == 'voc':
  368. eval_dataset.data_fields = {
  369. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  370. 'difficult'
  371. }
  372. elif self.metric == 'coco':
  373. if self.__class__.__name__ == 'MaskRCNN':
  374. eval_dataset.data_fields = {
  375. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  376. 'gt_poly', 'is_crowd'
  377. }
  378. else:
  379. eval_dataset.data_fields = {
  380. 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
  381. 'is_crowd'
  382. }
  383. eval_dataset.batch_transforms = self._compose_batch_transform(
  384. eval_dataset.transforms, mode='eval')
  385. arrange_transforms(
  386. model_type=self.model_type,
  387. transforms=eval_dataset.transforms,
  388. mode='eval')
  389. self.net.eval()
  390. nranks = paddle.distributed.get_world_size()
  391. local_rank = paddle.distributed.get_rank()
  392. if nranks > 1:
  393. # Initialize parallel environment if not done.
  394. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  395. ):
  396. paddle.distributed.init_parallel_env()
  397. if batch_size > 1:
  398. logging.warning(
  399. "Detector only supports single card evaluation with batch_size=1 "
  400. "during evaluation, so batch_size is forcibly set to 1.")
  401. batch_size = 1
  402. if nranks < 2 or local_rank == 0:
  403. self.eval_data_loader = self.build_data_loader(
  404. eval_dataset, batch_size=batch_size, mode='eval')
  405. is_bbox_normalized = False
  406. if eval_dataset.batch_transforms is not None:
  407. is_bbox_normalized = any(
  408. isinstance(t, _NormalizeBox)
  409. for t in eval_dataset.batch_transforms.batch_transforms)
  410. if self.metric == 'voc':
  411. eval_metric = VOCMetric(
  412. labels=eval_dataset.labels,
  413. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  414. is_bbox_normalized=is_bbox_normalized,
  415. classwise=False)
  416. else:
  417. eval_metric = COCOMetric(
  418. coco_gt=copy.deepcopy(eval_dataset.coco_gt),
  419. classwise=False)
  420. scores = collections.OrderedDict()
  421. logging.info(
  422. "Start to evaluate(total_samples={}, total_steps={})...".
  423. format(eval_dataset.num_samples, eval_dataset.num_samples))
  424. with paddle.no_grad():
  425. for step, data in enumerate(self.eval_data_loader):
  426. outputs = self.run(self.net, data, 'eval')
  427. eval_metric.update(data, outputs)
  428. eval_metric.accumulate()
  429. self.eval_details = eval_metric.details
  430. scores.update(eval_metric.get())
  431. eval_metric.reset()
  432. if return_details:
  433. return scores, self.eval_details
  434. return scores
  435. def predict(self, img_file, transforms=None):
  436. """
  437. Do inference.
  438. Args:
  439. img_file(List[np.ndarray or str], str or np.ndarray):
  440. Image path or decoded image data in a BGR format, which also could constitute a list,
  441. meaning all images to be predicted as a mini-batch.
  442. transforms(paddlex.transforms.Compose or None, optional):
  443. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  444. Returns:
  445. If img_file is a string or np.array, the result is a list of dict with key-value pairs:
  446. {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
  447. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  448. category_id(int): the predicted category ID. 0 represents the first category in the dataset, and so on.
  449. category(str): category name
  450. bbox(list): bounding box in [x, y, w, h] format
  451. score(str): confidence
  452. mask(dict): Only for instance segmentation task. Mask of the object in RLE format
  453. """
  454. if transforms is None and not hasattr(self, 'test_transforms'):
  455. raise Exception("transforms need to be defined, now is None.")
  456. if transforms is None:
  457. transforms = self.test_transforms
  458. if isinstance(img_file, (str, np.ndarray)):
  459. images = [img_file]
  460. else:
  461. images = img_file
  462. batch_samples = self._preprocess(images, transforms)
  463. self.net.eval()
  464. outputs = self.run(self.net, batch_samples, 'test')
  465. prediction = self._postprocess(outputs)
  466. if isinstance(img_file, (str, np.ndarray)):
  467. prediction = prediction[0]
  468. return prediction
  469. def _preprocess(self, images, transforms, to_tensor=True):
  470. arrange_transforms(
  471. model_type=self.model_type, transforms=transforms, mode='test')
  472. batch_samples = list()
  473. for im in images:
  474. sample = {'image': im}
  475. batch_samples.append(transforms(sample))
  476. batch_transforms = self._compose_batch_transform(transforms, 'test')
  477. batch_samples = batch_transforms(batch_samples)
  478. if to_tensor:
  479. for k in batch_samples:
  480. batch_samples[k] = paddle.to_tensor(batch_samples[k])
  481. return batch_samples
  482. def _postprocess(self, batch_pred):
  483. infer_result = {}
  484. if 'bbox' in batch_pred:
  485. bboxes = batch_pred['bbox']
  486. bbox_nums = batch_pred['bbox_num']
  487. det_res = []
  488. k = 0
  489. for i in range(len(bbox_nums)):
  490. det_nums = bbox_nums[i]
  491. for j in range(det_nums):
  492. dt = bboxes[k]
  493. k = k + 1
  494. num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
  495. if int(num_id) < 0:
  496. continue
  497. category = self.labels[int(num_id)]
  498. w = xmax - xmin
  499. h = ymax - ymin
  500. bbox = [xmin, ymin, w, h]
  501. dt_res = {
  502. 'category_id': int(num_id),
  503. 'category': category,
  504. 'bbox': bbox,
  505. 'score': score
  506. }
  507. det_res.append(dt_res)
  508. infer_result['bbox'] = det_res
  509. if 'mask' in batch_pred:
  510. masks = batch_pred['mask']
  511. bboxes = batch_pred['bbox']
  512. mask_nums = batch_pred['bbox_num']
  513. seg_res = []
  514. k = 0
  515. for i in range(len(mask_nums)):
  516. det_nums = mask_nums[i]
  517. for j in range(det_nums):
  518. mask = masks[k].astype(np.uint8)
  519. score = float(bboxes[k][1])
  520. label = int(bboxes[k][0])
  521. k = k + 1
  522. if label == -1:
  523. continue
  524. category = self.labels[int(label)]
  525. sg_res = {
  526. 'category_id': int(label),
  527. 'category': category,
  528. 'mask': mask.astype('uint8'),
  529. 'score': score
  530. }
  531. seg_res.append(sg_res)
  532. infer_result['mask'] = seg_res
  533. bbox_num = batch_pred['bbox_num']
  534. results = []
  535. start = 0
  536. for num in bbox_num:
  537. end = start + num
  538. curr_res = infer_result['bbox'][start:end]
  539. if 'mask' in infer_result:
  540. mask_res = infer_result['mask'][start:end]
  541. for box, mask in zip(curr_res, mask_res):
  542. box.update(mask)
  543. results.append(curr_res)
  544. start = end
  545. return results
  546. class PicoDet(BaseDetector):
  547. def __init__(self,
  548. num_classes=80,
  549. backbone='ESNet_m',
  550. nms_score_threshold=.025,
  551. nms_top_k=1000,
  552. nms_keep_top_k=100,
  553. nms_iou_threshold=.6,
  554. **params):
  555. self.init_params = locals()
  556. if backbone not in {
  557. 'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3',
  558. 'ResNet18'
  559. }:
  560. raise ValueError(
  561. "backbone: {} is not supported. Please choose one of "
  562. "('ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18')".
  563. format(backbone))
  564. self.backbone_name = backbone
  565. if params.get('with_net', True):
  566. if backbone == 'ESNet_s':
  567. backbone = self._get_backbone(
  568. 'ESNet',
  569. scale=.75,
  570. feature_maps=[4, 11, 14],
  571. act="hard_swish",
  572. channel_ratio=[
  573. 0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5,
  574. 0.5, 0.5, 0.5
  575. ])
  576. neck_out_channels = 96
  577. head_num_convs = 2
  578. elif backbone == 'ESNet_m':
  579. backbone = self._get_backbone(
  580. 'ESNet',
  581. scale=1.0,
  582. feature_maps=[4, 11, 14],
  583. act="hard_swish",
  584. channel_ratio=[
  585. 0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
  586. 0.625, 1.0, 0.625, 0.75
  587. ])
  588. neck_out_channels = 128
  589. head_num_convs = 4
  590. elif backbone == 'ESNet_l':
  591. backbone = self._get_backbone(
  592. 'ESNet',
  593. scale=1.25,
  594. feature_maps=[4, 11, 14],
  595. act="hard_swish",
  596. channel_ratio=[
  597. 0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
  598. 0.625, 1.0, 0.625, 0.75
  599. ])
  600. neck_out_channels = 160
  601. head_num_convs = 4
  602. elif backbone == 'LCNet':
  603. backbone = self._get_backbone(
  604. 'LCNet', scale=1.5, feature_maps=[3, 4, 5])
  605. neck_out_channels = 128
  606. head_num_convs = 4
  607. elif backbone == 'MobileNetV3':
  608. backbone = self._get_backbone(
  609. 'MobileNetV3',
  610. scale=1.0,
  611. with_extra_blocks=False,
  612. extra_block_filters=[],
  613. feature_maps=[7, 13, 16])
  614. neck_out_channels = 128
  615. head_num_convs = 4
  616. else:
  617. backbone = self._get_backbone(
  618. 'ResNet',
  619. depth=18,
  620. variant='d',
  621. return_idx=[1, 2, 3],
  622. freeze_at=-1,
  623. freeze_norm=False,
  624. norm_decay=0.)
  625. neck_out_channels = 128
  626. head_num_convs = 4
  627. neck = ppdet.modeling.CSPPAN(
  628. in_channels=[i.channels for i in backbone.out_shape],
  629. out_channels=neck_out_channels,
  630. num_features=4,
  631. num_csp_blocks=1,
  632. use_depthwise=True)
  633. head_conv_feat = ppdet.modeling.PicoFeat(
  634. feat_in=neck_out_channels,
  635. feat_out=neck_out_channels,
  636. num_fpn_stride=4,
  637. num_convs=head_num_convs,
  638. norm_type='bn',
  639. share_cls_reg=True, )
  640. loss_class = ppdet.modeling.VarifocalLoss(
  641. use_sigmoid=True, iou_weighted=True, loss_weight=1.0)
  642. loss_dfl = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
  643. loss_bbox = ppdet.modeling.GIoULoss(loss_weight=2.0)
  644. assigner = ppdet.modeling.SimOTAAssigner(
  645. candidate_topk=10, iou_weight=6)
  646. nms = ppdet.modeling.MultiClassNMS(
  647. nms_top_k=nms_top_k,
  648. keep_top_k=nms_keep_top_k,
  649. score_threshold=nms_score_threshold,
  650. nms_threshold=nms_iou_threshold)
  651. head = ppdet.modeling.PicoHead(
  652. conv_feat=head_conv_feat,
  653. num_classes=num_classes,
  654. fpn_stride=[8, 16, 32, 64],
  655. prior_prob=0.01,
  656. loss_class=loss_class,
  657. loss_dfl=loss_dfl,
  658. loss_bbox=loss_bbox,
  659. assigner=assigner,
  660. reg_max=7,
  661. feat_in_chan=neck_out_channels,
  662. nms=nms)
  663. params.update({
  664. 'backbone': backbone,
  665. 'neck': neck,
  666. 'head': head,
  667. })
  668. super(PicoDet, self).__init__(
  669. model_name='PicoDet', num_classes=num_classes, **params)
  670. def _compose_batch_transform(self, transforms, mode='train'):
  671. default_batch_transforms = [_BatchPadding(pad_to_stride=32)]
  672. if mode == 'eval':
  673. collate_batch = True
  674. else:
  675. collate_batch = False
  676. custom_batch_transforms = []
  677. for i, op in enumerate(transforms.transforms):
  678. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  679. if mode != 'train':
  680. raise Exception(
  681. "{} cannot be present in the {} transforms. ".format(
  682. op.__class__.__name__, mode) +
  683. "Please check the {} transforms.".format(mode))
  684. custom_batch_transforms.insert(0, copy.deepcopy(op))
  685. batch_transforms = BatchCompose(
  686. custom_batch_transforms + default_batch_transforms,
  687. collate_batch=collate_batch)
  688. return batch_transforms
  689. class YOLOv3(BaseDetector):
  690. def __init__(self,
  691. num_classes=80,
  692. backbone='MobileNetV1',
  693. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  694. [59, 119], [116, 90], [156, 198], [373, 326]],
  695. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  696. ignore_threshold=0.7,
  697. nms_score_threshold=0.01,
  698. nms_topk=1000,
  699. nms_keep_topk=100,
  700. nms_iou_threshold=0.45,
  701. label_smooth=False,
  702. **params):
  703. self.init_params = locals()
  704. if backbone not in {
  705. 'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
  706. 'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
  707. }:
  708. raise ValueError(
  709. "backbone: {} is not supported. Please choose one of "
  710. "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
  711. "'ResNet50_vd_dcn', 'ResNet34')".format(backbone))
  712. self.backbone_name = backbone
  713. if params.get('with_net', True):
  714. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  715. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  716. norm_type = 'sync_bn'
  717. else:
  718. norm_type = 'bn'
  719. if 'MobileNetV1' in backbone:
  720. norm_type = 'bn'
  721. backbone = self._get_backbone('MobileNet', norm_type=norm_type)
  722. elif 'MobileNetV3' in backbone:
  723. backbone = self._get_backbone(
  724. 'MobileNetV3',
  725. norm_type=norm_type,
  726. feature_maps=[7, 13, 16])
  727. elif backbone == 'ResNet50_vd_dcn':
  728. backbone = self._get_backbone(
  729. 'ResNet',
  730. norm_type=norm_type,
  731. variant='d',
  732. return_idx=[1, 2, 3],
  733. dcn_v2_stages=[3],
  734. freeze_at=-1,
  735. freeze_norm=False)
  736. elif backbone == 'ResNet34':
  737. backbone = self._get_backbone(
  738. 'ResNet',
  739. depth=34,
  740. norm_type=norm_type,
  741. return_idx=[1, 2, 3],
  742. freeze_at=-1,
  743. freeze_norm=False,
  744. norm_decay=0.)
  745. else:
  746. backbone = self._get_backbone('DarkNet', norm_type=norm_type)
  747. neck = ppdet.modeling.YOLOv3FPN(
  748. norm_type=norm_type,
  749. in_channels=[i.channels for i in backbone.out_shape])
  750. loss = ppdet.modeling.YOLOv3Loss(
  751. num_classes=num_classes,
  752. ignore_thresh=ignore_threshold,
  753. label_smooth=label_smooth)
  754. yolo_head = ppdet.modeling.YOLOv3Head(
  755. in_channels=[i.channels for i in neck.out_shape],
  756. anchors=anchors,
  757. anchor_masks=anchor_masks,
  758. num_classes=num_classes,
  759. loss=loss)
  760. post_process = ppdet.modeling.BBoxPostProcess(
  761. decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
  762. nms=ppdet.modeling.MultiClassNMS(
  763. score_threshold=nms_score_threshold,
  764. nms_top_k=nms_topk,
  765. keep_top_k=nms_keep_topk,
  766. nms_threshold=nms_iou_threshold))
  767. params.update({
  768. 'backbone': backbone,
  769. 'neck': neck,
  770. 'yolo_head': yolo_head,
  771. 'post_process': post_process
  772. })
  773. super(YOLOv3, self).__init__(
  774. model_name='YOLOv3', num_classes=num_classes, **params)
  775. self.anchors = anchors
  776. self.anchor_masks = anchor_masks
  777. def _compose_batch_transform(self, transforms, mode='train'):
  778. if mode == 'train':
  779. default_batch_transforms = [
  780. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  781. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  782. _Gt2YoloTarget(
  783. anchor_masks=self.anchor_masks,
  784. anchors=self.anchors,
  785. downsample_ratios=getattr(self, 'downsample_ratios',
  786. [32, 16, 8]),
  787. num_classes=self.num_classes)
  788. ]
  789. else:
  790. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  791. if mode == 'eval' and self.metric == 'voc':
  792. collate_batch = False
  793. else:
  794. collate_batch = True
  795. custom_batch_transforms = []
  796. for i, op in enumerate(transforms.transforms):
  797. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  798. if mode != 'train':
  799. raise Exception(
  800. "{} cannot be present in the {} transforms. ".format(
  801. op.__class__.__name__, mode) +
  802. "Please check the {} transforms.".format(mode))
  803. custom_batch_transforms.insert(0, copy.deepcopy(op))
  804. batch_transforms = BatchCompose(
  805. custom_batch_transforms + default_batch_transforms,
  806. collate_batch=collate_batch)
  807. return batch_transforms
  808. def _fix_transforms_shape(self, image_shape):
  809. if hasattr(self, 'test_transforms'):
  810. if self.test_transforms is not None:
  811. has_resize_op = False
  812. resize_op_idx = -1
  813. normalize_op_idx = len(self.test_transforms.transforms)
  814. for idx, op in enumerate(self.test_transforms.transforms):
  815. name = op.__class__.__name__
  816. if name == 'Resize':
  817. has_resize_op = True
  818. resize_op_idx = idx
  819. if name == 'Normalize':
  820. normalize_op_idx = idx
  821. if not has_resize_op:
  822. self.test_transforms.transforms.insert(
  823. normalize_op_idx,
  824. Resize(
  825. target_size=image_shape, interp='CUBIC'))
  826. else:
  827. self.test_transforms.transforms[
  828. resize_op_idx].target_size = image_shape
  829. class FasterRCNN(BaseDetector):
  830. def __init__(self,
  831. num_classes=80,
  832. backbone='ResNet50',
  833. with_fpn=True,
  834. with_dcn=False,
  835. aspect_ratios=[0.5, 1.0, 2.0],
  836. anchor_sizes=[[32], [64], [128], [256], [512]],
  837. keep_top_k=100,
  838. nms_threshold=0.5,
  839. score_threshold=0.05,
  840. fpn_num_channels=256,
  841. rpn_batch_size_per_im=256,
  842. rpn_fg_fraction=0.5,
  843. test_pre_nms_top_n=None,
  844. test_post_nms_top_n=1000,
  845. **params):
  846. self.init_params = locals()
  847. if backbone not in {
  848. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
  849. 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
  850. }:
  851. raise ValueError(
  852. "backbone: {} is not supported. Please choose one of "
  853. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
  854. "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
  855. self.backbone_name = backbone
  856. if params.get('with_net', True):
  857. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  858. if backbone == 'HRNet_W18':
  859. if not with_fpn:
  860. logging.warning(
  861. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  862. format(backbone))
  863. with_fpn = True
  864. if with_dcn:
  865. logging.warning(
  866. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  867. format(backbone))
  868. backbone = self._get_backbone(
  869. 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
  870. elif backbone == 'ResNet50_vd_ssld':
  871. if not with_fpn:
  872. logging.warning(
  873. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  874. format(backbone))
  875. with_fpn = True
  876. backbone = self._get_backbone(
  877. 'ResNet',
  878. variant='d',
  879. norm_type='bn',
  880. freeze_at=0,
  881. return_idx=[0, 1, 2, 3],
  882. num_stages=4,
  883. lr_mult_list=[0.05, 0.05, 0.1, 0.15],
  884. dcn_v2_stages=dcn_v2_stages)
  885. elif 'ResNet50' in backbone:
  886. if with_fpn:
  887. backbone = self._get_backbone(
  888. 'ResNet',
  889. variant='d' if '_vd' in backbone else 'b',
  890. norm_type='bn',
  891. freeze_at=0,
  892. return_idx=[0, 1, 2, 3],
  893. num_stages=4,
  894. dcn_v2_stages=dcn_v2_stages)
  895. else:
  896. if with_dcn:
  897. logging.warning(
  898. "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  899. format(backbone))
  900. backbone = self._get_backbone(
  901. 'ResNet',
  902. variant='d' if '_vd' in backbone else 'b',
  903. norm_type='bn',
  904. freeze_at=0,
  905. return_idx=[2],
  906. num_stages=3)
  907. elif 'ResNet34' in backbone:
  908. if not with_fpn:
  909. logging.warning(
  910. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  911. format(backbone))
  912. with_fpn = True
  913. backbone = self._get_backbone(
  914. 'ResNet',
  915. depth=34,
  916. variant='d' if 'vd' in backbone else 'b',
  917. norm_type='bn',
  918. freeze_at=0,
  919. return_idx=[0, 1, 2, 3],
  920. num_stages=4,
  921. dcn_v2_stages=dcn_v2_stages)
  922. else:
  923. if not with_fpn:
  924. logging.warning(
  925. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  926. format(backbone))
  927. with_fpn = True
  928. backbone = self._get_backbone(
  929. 'ResNet',
  930. depth=101,
  931. variant='d' if 'vd' in backbone else 'b',
  932. norm_type='bn',
  933. freeze_at=0,
  934. return_idx=[0, 1, 2, 3],
  935. num_stages=4,
  936. dcn_v2_stages=dcn_v2_stages)
  937. rpn_in_channel = backbone.out_shape[0].channels
  938. if with_fpn:
  939. self.backbone_name = self.backbone_name + '_fpn'
  940. if 'HRNet' in self.backbone_name:
  941. neck = ppdet.modeling.HRFPN(
  942. in_channels=[i.channels for i in backbone.out_shape],
  943. out_channel=fpn_num_channels,
  944. spatial_scales=[
  945. 1.0 / i.stride for i in backbone.out_shape
  946. ],
  947. share_conv=False)
  948. else:
  949. neck = ppdet.modeling.FPN(
  950. in_channels=[i.channels for i in backbone.out_shape],
  951. out_channel=fpn_num_channels,
  952. spatial_scales=[
  953. 1.0 / i.stride for i in backbone.out_shape
  954. ])
  955. rpn_in_channel = neck.out_shape[0].channels
  956. anchor_generator_cfg = {
  957. 'aspect_ratios': aspect_ratios,
  958. 'anchor_sizes': anchor_sizes,
  959. 'strides': [4, 8, 16, 32, 64]
  960. }
  961. train_proposal_cfg = {
  962. 'min_size': 0.0,
  963. 'nms_thresh': .7,
  964. 'pre_nms_top_n': 2000,
  965. 'post_nms_top_n': 1000,
  966. 'topk_after_collect': True
  967. }
  968. test_proposal_cfg = {
  969. 'min_size': 0.0,
  970. 'nms_thresh': .7,
  971. 'pre_nms_top_n': 1000
  972. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  973. 'post_nms_top_n': test_post_nms_top_n
  974. }
  975. head = ppdet.modeling.TwoFCHead(
  976. in_channel=neck.out_shape[0].channels, out_channel=1024)
  977. roi_extractor_cfg = {
  978. 'resolution': 7,
  979. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  980. 'sampling_ratio': 0,
  981. 'aligned': True
  982. }
  983. with_pool = False
  984. else:
  985. neck = None
  986. anchor_generator_cfg = {
  987. 'aspect_ratios': aspect_ratios,
  988. 'anchor_sizes': anchor_sizes,
  989. 'strides': [16]
  990. }
  991. train_proposal_cfg = {
  992. 'min_size': 0.0,
  993. 'nms_thresh': .7,
  994. 'pre_nms_top_n': 12000,
  995. 'post_nms_top_n': 2000,
  996. 'topk_after_collect': False
  997. }
  998. test_proposal_cfg = {
  999. 'min_size': 0.0,
  1000. 'nms_thresh': .7,
  1001. 'pre_nms_top_n': 6000
  1002. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1003. 'post_nms_top_n': test_post_nms_top_n
  1004. }
  1005. head = ppdet.modeling.Res5Head()
  1006. roi_extractor_cfg = {
  1007. 'resolution': 14,
  1008. 'spatial_scale':
  1009. [1. / i.stride for i in backbone.out_shape],
  1010. 'sampling_ratio': 0,
  1011. 'aligned': True
  1012. }
  1013. with_pool = True
  1014. rpn_target_assign_cfg = {
  1015. 'batch_size_per_im': rpn_batch_size_per_im,
  1016. 'fg_fraction': rpn_fg_fraction,
  1017. 'negative_overlap': .3,
  1018. 'positive_overlap': .7,
  1019. 'use_random': True
  1020. }
  1021. rpn_head = ppdet.modeling.RPNHead(
  1022. anchor_generator=anchor_generator_cfg,
  1023. rpn_target_assign=rpn_target_assign_cfg,
  1024. train_proposal=train_proposal_cfg,
  1025. test_proposal=test_proposal_cfg,
  1026. in_channel=rpn_in_channel)
  1027. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1028. bbox_head = ppdet.modeling.BBoxHead(
  1029. head=head,
  1030. in_channel=head.out_shape[0].channels,
  1031. roi_extractor=roi_extractor_cfg,
  1032. with_pool=with_pool,
  1033. bbox_assigner=bbox_assigner,
  1034. num_classes=num_classes)
  1035. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1036. num_classes=num_classes,
  1037. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1038. nms=ppdet.modeling.MultiClassNMS(
  1039. score_threshold=score_threshold,
  1040. keep_top_k=keep_top_k,
  1041. nms_threshold=nms_threshold))
  1042. params.update({
  1043. 'backbone': backbone,
  1044. 'neck': neck,
  1045. 'rpn_head': rpn_head,
  1046. 'bbox_head': bbox_head,
  1047. 'bbox_post_process': bbox_post_process
  1048. })
  1049. else:
  1050. if backbone not in {'ResNet50', 'ResNet50_vd'}:
  1051. with_fpn = True
  1052. self.with_fpn = with_fpn
  1053. super(FasterRCNN, self).__init__(
  1054. model_name='FasterRCNN', num_classes=num_classes, **params)
  1055. def train(self,
  1056. num_epochs,
  1057. train_dataset,
  1058. train_batch_size=64,
  1059. eval_dataset=None,
  1060. optimizer=None,
  1061. save_interval_epochs=1,
  1062. log_interval_steps=10,
  1063. save_dir='output',
  1064. pretrain_weights='IMAGENET',
  1065. learning_rate=.001,
  1066. warmup_steps=0,
  1067. warmup_start_lr=0.0,
  1068. lr_decay_epochs=(216, 243),
  1069. lr_decay_gamma=0.1,
  1070. metric=None,
  1071. use_ema=False,
  1072. early_stop=False,
  1073. early_stop_patience=5,
  1074. use_vdl=True,
  1075. resume_checkpoint=None):
  1076. """
  1077. Train the model.
  1078. Args:
  1079. num_epochs(int): The number of epochs.
  1080. train_dataset(paddlex.dataset): Training dataset.
  1081. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  1082. eval_dataset(paddlex.dataset, optional):
  1083. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  1084. optimizer(paddle.optimizer.Optimizer or None, optional):
  1085. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  1086. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  1087. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  1088. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  1089. pretrain_weights(str or None, optional):
  1090. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  1091. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  1092. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  1093. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  1094. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  1095. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  1096. metric({'VOC', 'COCO', None}, optional):
  1097. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  1098. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  1099. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  1100. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  1101. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  1102. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  1103. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  1104. `pretrain_weights` can be set simultaneously. Defaults to None.
  1105. """
  1106. if train_dataset.pos_num < len(train_dataset.file_list):
  1107. train_dataset.num_workers = 0
  1108. if train_batch_size != 1:
  1109. train_batch_size = 1
  1110. logging.warning(
  1111. "Training RCNN models with negative samples only support batch size equals to 1 "
  1112. "on a single gpu/cpu card, `train_batch_size` is forcibly set to 1."
  1113. )
  1114. nranks = paddle.distributed.get_world_size()
  1115. local_rank = paddle.distributed.get_rank()
  1116. # single card training
  1117. if nranks < 2 or local_rank == 0:
  1118. super(FasterRCNN, self).train(
  1119. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1120. optimizer, save_interval_epochs, log_interval_steps,
  1121. save_dir, pretrain_weights, learning_rate, warmup_steps,
  1122. warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric,
  1123. use_ema, early_stop, early_stop_patience, use_vdl,
  1124. resume_checkpoint)
  1125. else:
  1126. super(FasterRCNN, self).train(
  1127. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1128. optimizer, save_interval_epochs, log_interval_steps, save_dir,
  1129. pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
  1130. lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
  1131. early_stop_patience, use_vdl, resume_checkpoint)
  1132. def _compose_batch_transform(self, transforms, mode='train'):
  1133. if mode == 'train':
  1134. default_batch_transforms = [
  1135. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1136. ]
  1137. collate_batch = False
  1138. else:
  1139. default_batch_transforms = [
  1140. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1141. ]
  1142. collate_batch = True
  1143. custom_batch_transforms = []
  1144. for i, op in enumerate(transforms.transforms):
  1145. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1146. if mode != 'train':
  1147. raise Exception(
  1148. "{} cannot be present in the {} transforms. ".format(
  1149. op.__class__.__name__, mode) +
  1150. "Please check the {} transforms.".format(mode))
  1151. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1152. batch_transforms = BatchCompose(
  1153. custom_batch_transforms + default_batch_transforms,
  1154. collate_batch=collate_batch)
  1155. return batch_transforms
  1156. def _fix_transforms_shape(self, image_shape):
  1157. if hasattr(self, 'test_transforms'):
  1158. if self.test_transforms is not None:
  1159. has_resize_op = False
  1160. resize_op_idx = -1
  1161. normalize_op_idx = len(self.test_transforms.transforms)
  1162. for idx, op in enumerate(self.test_transforms.transforms):
  1163. name = op.__class__.__name__
  1164. if name == 'ResizeByShort':
  1165. has_resize_op = True
  1166. resize_op_idx = idx
  1167. if name == 'Normalize':
  1168. normalize_op_idx = idx
  1169. if not has_resize_op:
  1170. self.test_transforms.transforms.insert(
  1171. normalize_op_idx,
  1172. Resize(
  1173. target_size=image_shape,
  1174. keep_ratio=True,
  1175. interp='CUBIC'))
  1176. else:
  1177. self.test_transforms.transforms[resize_op_idx] = Resize(
  1178. target_size=image_shape,
  1179. keep_ratio=True,
  1180. interp='CUBIC')
  1181. self.test_transforms.transforms.append(
  1182. Padding(im_padding_value=[0., 0., 0.]))
  1183. def _get_test_inputs(self, image_shape):
  1184. if image_shape is not None:
  1185. image_shape = self._check_image_shape(image_shape)
  1186. self._fix_transforms_shape(image_shape[-2:])
  1187. else:
  1188. image_shape = [None, 3, -1, -1]
  1189. if self.with_fpn:
  1190. self.test_transforms.transforms.append(
  1191. Padding(im_padding_value=[0., 0., 0.]))
  1192. self.fixed_input_shape = image_shape
  1193. return self._define_input_spec(image_shape)
  1194. class PPYOLO(YOLOv3):
  1195. def __init__(self,
  1196. num_classes=80,
  1197. backbone='ResNet50_vd_dcn',
  1198. anchors=None,
  1199. anchor_masks=None,
  1200. use_coord_conv=True,
  1201. use_iou_aware=True,
  1202. use_spp=True,
  1203. use_drop_block=True,
  1204. scale_x_y=1.05,
  1205. ignore_threshold=0.7,
  1206. label_smooth=False,
  1207. use_iou_loss=True,
  1208. use_matrix_nms=True,
  1209. nms_score_threshold=0.01,
  1210. nms_topk=-1,
  1211. nms_keep_topk=100,
  1212. nms_iou_threshold=0.45,
  1213. **params):
  1214. self.init_params = locals()
  1215. if backbone not in {
  1216. 'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
  1217. 'MobileNetV3_small'
  1218. }:
  1219. raise ValueError(
  1220. "backbone: {} is not supported. Please choose one of "
  1221. "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', 'MobileNetV3_small')".
  1222. format(backbone))
  1223. self.backbone_name = backbone
  1224. self.downsample_ratios = [
  1225. 32, 16, 8
  1226. ] if backbone == 'ResNet50_vd_dcn' else [32, 16]
  1227. if params.get('with_net', True):
  1228. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1229. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1230. norm_type = 'sync_bn'
  1231. else:
  1232. norm_type = 'bn'
  1233. if anchors is None and anchor_masks is None:
  1234. if 'MobileNetV3' in backbone:
  1235. anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
  1236. [120, 195], [254, 235]]
  1237. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1238. elif backbone == 'ResNet50_vd_dcn':
  1239. anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
  1240. [62, 45], [59, 119], [116, 90], [156, 198],
  1241. [373, 326]]
  1242. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  1243. else:
  1244. anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
  1245. [135, 169], [344, 319]]
  1246. anchor_masks = [[3, 4, 5], [0, 1, 2]]
  1247. elif anchors is None or anchor_masks is None:
  1248. raise ValueError(
  1249. "Please define both anchors and anchor_masks.")
  1250. if backbone == 'ResNet50_vd_dcn':
  1251. backbone = self._get_backbone(
  1252. 'ResNet',
  1253. variant='d',
  1254. norm_type=norm_type,
  1255. return_idx=[1, 2, 3],
  1256. dcn_v2_stages=[3],
  1257. freeze_at=-1,
  1258. freeze_norm=False,
  1259. norm_decay=0.)
  1260. elif backbone == 'ResNet18_vd':
  1261. backbone = self._get_backbone(
  1262. 'ResNet',
  1263. depth=18,
  1264. variant='d',
  1265. norm_type=norm_type,
  1266. return_idx=[2, 3],
  1267. freeze_at=-1,
  1268. freeze_norm=False,
  1269. norm_decay=0.)
  1270. elif backbone == 'MobileNetV3_large':
  1271. backbone = self._get_backbone(
  1272. 'MobileNetV3',
  1273. model_name='large',
  1274. norm_type=norm_type,
  1275. scale=1,
  1276. with_extra_blocks=False,
  1277. extra_block_filters=[],
  1278. feature_maps=[13, 16])
  1279. elif backbone == 'MobileNetV3_small':
  1280. backbone = self._get_backbone(
  1281. 'MobileNetV3',
  1282. model_name='small',
  1283. norm_type=norm_type,
  1284. scale=1,
  1285. with_extra_blocks=False,
  1286. extra_block_filters=[],
  1287. feature_maps=[9, 12])
  1288. neck = ppdet.modeling.PPYOLOFPN(
  1289. norm_type=norm_type,
  1290. in_channels=[i.channels for i in backbone.out_shape],
  1291. coord_conv=use_coord_conv,
  1292. drop_block=use_drop_block,
  1293. spp=use_spp,
  1294. conv_block_num=0
  1295. if ('MobileNetV3' in self.backbone_name or
  1296. self.backbone_name == 'ResNet18_vd') else 2)
  1297. loss = ppdet.modeling.YOLOv3Loss(
  1298. num_classes=num_classes,
  1299. ignore_thresh=ignore_threshold,
  1300. downsample=self.downsample_ratios,
  1301. label_smooth=label_smooth,
  1302. scale_x_y=scale_x_y,
  1303. iou_loss=ppdet.modeling.IouLoss(
  1304. loss_weight=2.5, loss_square=True)
  1305. if use_iou_loss else None,
  1306. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1307. if use_iou_aware else None)
  1308. yolo_head = ppdet.modeling.YOLOv3Head(
  1309. in_channels=[i.channels for i in neck.out_shape],
  1310. anchors=anchors,
  1311. anchor_masks=anchor_masks,
  1312. num_classes=num_classes,
  1313. loss=loss,
  1314. iou_aware=use_iou_aware)
  1315. if use_matrix_nms:
  1316. nms = ppdet.modeling.MatrixNMS(
  1317. keep_top_k=nms_keep_topk,
  1318. score_threshold=nms_score_threshold,
  1319. post_threshold=.05
  1320. if 'MobileNetV3' in self.backbone_name else .01,
  1321. nms_top_k=nms_topk,
  1322. background_label=-1)
  1323. else:
  1324. nms = ppdet.modeling.MultiClassNMS(
  1325. score_threshold=nms_score_threshold,
  1326. nms_top_k=nms_topk,
  1327. keep_top_k=nms_keep_topk,
  1328. nms_threshold=nms_iou_threshold)
  1329. post_process = ppdet.modeling.BBoxPostProcess(
  1330. decode=ppdet.modeling.YOLOBox(
  1331. num_classes=num_classes,
  1332. conf_thresh=.005
  1333. if 'MobileNetV3' in self.backbone_name else .01,
  1334. scale_x_y=scale_x_y),
  1335. nms=nms)
  1336. params.update({
  1337. 'backbone': backbone,
  1338. 'neck': neck,
  1339. 'yolo_head': yolo_head,
  1340. 'post_process': post_process
  1341. })
  1342. super(YOLOv3, self).__init__(
  1343. model_name='YOLOv3', num_classes=num_classes, **params)
  1344. self.anchors = anchors
  1345. self.anchor_masks = anchor_masks
  1346. self.model_name = 'PPYOLO'
  1347. def _get_test_inputs(self, image_shape):
  1348. if image_shape is not None:
  1349. image_shape = self._check_image_shape(image_shape)
  1350. self._fix_transforms_shape(image_shape[-2:])
  1351. else:
  1352. image_shape = [None, 3, 608, 608]
  1353. if hasattr(self, 'test_transforms'):
  1354. if self.test_transforms is not None:
  1355. for idx, op in enumerate(self.test_transforms.transforms):
  1356. name = op.__class__.__name__
  1357. if name == 'Resize':
  1358. image_shape = [None, 3] + list(
  1359. self.test_transforms.transforms[
  1360. idx].target_size)
  1361. logging.warning(
  1362. '[Important!!!] When exporting inference model for {},'.format(
  1363. self.__class__.__name__) +
  1364. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1365. format(image_shape) +
  1366. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1367. format(image_shape[1:]) + 'should be specified manually.')
  1368. self.fixed_input_shape = image_shape
  1369. return self._define_input_spec(image_shape)
  1370. class PPYOLOTiny(YOLOv3):
  1371. def __init__(self,
  1372. num_classes=80,
  1373. backbone='MobileNetV3',
  1374. anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
  1375. [60, 170], [220, 125], [128, 222], [264, 266]],
  1376. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1377. use_iou_aware=False,
  1378. use_spp=True,
  1379. use_drop_block=True,
  1380. scale_x_y=1.05,
  1381. ignore_threshold=0.5,
  1382. label_smooth=False,
  1383. use_iou_loss=True,
  1384. use_matrix_nms=False,
  1385. nms_score_threshold=0.005,
  1386. nms_topk=1000,
  1387. nms_keep_topk=100,
  1388. nms_iou_threshold=0.45,
  1389. **params):
  1390. self.init_params = locals()
  1391. if backbone != 'MobileNetV3':
  1392. logging.warning(
  1393. "PPYOLOTiny only supports MobileNetV3 as backbone. "
  1394. "Backbone is forcibly set to MobileNetV3.")
  1395. self.backbone_name = 'MobileNetV3'
  1396. self.downsample_ratios = [32, 16, 8]
  1397. if params.get('with_net', True):
  1398. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1399. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1400. norm_type = 'sync_bn'
  1401. else:
  1402. norm_type = 'bn'
  1403. backbone = self._get_backbone(
  1404. 'MobileNetV3',
  1405. model_name='large',
  1406. norm_type=norm_type,
  1407. scale=.5,
  1408. with_extra_blocks=False,
  1409. extra_block_filters=[],
  1410. feature_maps=[7, 13, 16])
  1411. neck = ppdet.modeling.PPYOLOTinyFPN(
  1412. detection_block_channels=[160, 128, 96],
  1413. in_channels=[i.channels for i in backbone.out_shape],
  1414. spp=use_spp,
  1415. drop_block=use_drop_block)
  1416. loss = ppdet.modeling.YOLOv3Loss(
  1417. num_classes=num_classes,
  1418. ignore_thresh=ignore_threshold,
  1419. downsample=self.downsample_ratios,
  1420. label_smooth=label_smooth,
  1421. scale_x_y=scale_x_y,
  1422. iou_loss=ppdet.modeling.IouLoss(
  1423. loss_weight=2.5, loss_square=True)
  1424. if use_iou_loss else None,
  1425. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1426. if use_iou_aware else None)
  1427. yolo_head = ppdet.modeling.YOLOv3Head(
  1428. in_channels=[i.channels for i in neck.out_shape],
  1429. anchors=anchors,
  1430. anchor_masks=anchor_masks,
  1431. num_classes=num_classes,
  1432. loss=loss,
  1433. iou_aware=use_iou_aware)
  1434. if use_matrix_nms:
  1435. nms = ppdet.modeling.MatrixNMS(
  1436. keep_top_k=nms_keep_topk,
  1437. score_threshold=nms_score_threshold,
  1438. post_threshold=.05,
  1439. nms_top_k=nms_topk,
  1440. background_label=-1)
  1441. else:
  1442. nms = ppdet.modeling.MultiClassNMS(
  1443. score_threshold=nms_score_threshold,
  1444. nms_top_k=nms_topk,
  1445. keep_top_k=nms_keep_topk,
  1446. nms_threshold=nms_iou_threshold)
  1447. post_process = ppdet.modeling.BBoxPostProcess(
  1448. decode=ppdet.modeling.YOLOBox(
  1449. num_classes=num_classes,
  1450. conf_thresh=.005,
  1451. downsample_ratio=32,
  1452. clip_bbox=True,
  1453. scale_x_y=scale_x_y),
  1454. nms=nms)
  1455. params.update({
  1456. 'backbone': backbone,
  1457. 'neck': neck,
  1458. 'yolo_head': yolo_head,
  1459. 'post_process': post_process
  1460. })
  1461. super(YOLOv3, self).__init__(
  1462. model_name='YOLOv3', num_classes=num_classes, **params)
  1463. self.anchors = anchors
  1464. self.anchor_masks = anchor_masks
  1465. self.model_name = 'PPYOLOTiny'
  1466. def _get_test_inputs(self, image_shape):
  1467. if image_shape is not None:
  1468. image_shape = self._check_image_shape(image_shape)
  1469. self._fix_transforms_shape(image_shape[-2:])
  1470. else:
  1471. image_shape = [None, 3, 320, 320]
  1472. if hasattr(self, 'test_transforms'):
  1473. if self.test_transforms is not None:
  1474. for idx, op in enumerate(self.test_transforms.transforms):
  1475. name = op.__class__.__name__
  1476. if name == 'Resize':
  1477. image_shape = [None, 3] + list(
  1478. self.test_transforms.transforms[
  1479. idx].target_size)
  1480. logging.warning(
  1481. '[Important!!!] When exporting inference model for {},'.format(
  1482. self.__class__.__name__) +
  1483. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1484. format(image_shape) +
  1485. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1486. format(image_shape[1:]) + 'should be specified manually.')
  1487. self.fixed_input_shape = image_shape
  1488. return self._define_input_spec(image_shape)
  1489. class PPYOLOv2(YOLOv3):
  1490. def __init__(self,
  1491. num_classes=80,
  1492. backbone='ResNet50_vd_dcn',
  1493. anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  1494. [59, 119], [116, 90], [156, 198], [373, 326]],
  1495. anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  1496. use_iou_aware=True,
  1497. use_spp=True,
  1498. use_drop_block=True,
  1499. scale_x_y=1.05,
  1500. ignore_threshold=0.7,
  1501. label_smooth=False,
  1502. use_iou_loss=True,
  1503. use_matrix_nms=True,
  1504. nms_score_threshold=0.01,
  1505. nms_topk=-1,
  1506. nms_keep_topk=100,
  1507. nms_iou_threshold=0.45,
  1508. **params):
  1509. self.init_params = locals()
  1510. if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
  1511. raise ValueError(
  1512. "backbone: {} is not supported. Please choose one of "
  1513. "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
  1514. self.backbone_name = backbone
  1515. self.downsample_ratios = [32, 16, 8]
  1516. if params.get('with_net', True):
  1517. if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
  1518. 'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
  1519. norm_type = 'sync_bn'
  1520. else:
  1521. norm_type = 'bn'
  1522. if backbone == 'ResNet50_vd_dcn':
  1523. backbone = self._get_backbone(
  1524. 'ResNet',
  1525. variant='d',
  1526. norm_type=norm_type,
  1527. return_idx=[1, 2, 3],
  1528. dcn_v2_stages=[3],
  1529. freeze_at=-1,
  1530. freeze_norm=False,
  1531. norm_decay=0.)
  1532. elif backbone == 'ResNet101_vd_dcn':
  1533. backbone = self._get_backbone(
  1534. 'ResNet',
  1535. depth=101,
  1536. variant='d',
  1537. norm_type=norm_type,
  1538. return_idx=[1, 2, 3],
  1539. dcn_v2_stages=[3],
  1540. freeze_at=-1,
  1541. freeze_norm=False,
  1542. norm_decay=0.)
  1543. neck = ppdet.modeling.PPYOLOPAN(
  1544. norm_type=norm_type,
  1545. in_channels=[i.channels for i in backbone.out_shape],
  1546. drop_block=use_drop_block,
  1547. block_size=3,
  1548. keep_prob=.9,
  1549. spp=use_spp)
  1550. loss = ppdet.modeling.YOLOv3Loss(
  1551. num_classes=num_classes,
  1552. ignore_thresh=ignore_threshold,
  1553. downsample=self.downsample_ratios,
  1554. label_smooth=label_smooth,
  1555. scale_x_y=scale_x_y,
  1556. iou_loss=ppdet.modeling.IouLoss(
  1557. loss_weight=2.5, loss_square=True)
  1558. if use_iou_loss else None,
  1559. iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
  1560. if use_iou_aware else None)
  1561. yolo_head = ppdet.modeling.YOLOv3Head(
  1562. in_channels=[i.channels for i in neck.out_shape],
  1563. anchors=anchors,
  1564. anchor_masks=anchor_masks,
  1565. num_classes=num_classes,
  1566. loss=loss,
  1567. iou_aware=use_iou_aware,
  1568. iou_aware_factor=.5)
  1569. if use_matrix_nms:
  1570. nms = ppdet.modeling.MatrixNMS(
  1571. keep_top_k=nms_keep_topk,
  1572. score_threshold=nms_score_threshold,
  1573. post_threshold=.01,
  1574. nms_top_k=nms_topk,
  1575. background_label=-1)
  1576. else:
  1577. nms = ppdet.modeling.MultiClassNMS(
  1578. score_threshold=nms_score_threshold,
  1579. nms_top_k=nms_topk,
  1580. keep_top_k=nms_keep_topk,
  1581. nms_threshold=nms_iou_threshold)
  1582. post_process = ppdet.modeling.BBoxPostProcess(
  1583. decode=ppdet.modeling.YOLOBox(
  1584. num_classes=num_classes,
  1585. conf_thresh=.01,
  1586. downsample_ratio=32,
  1587. clip_bbox=True,
  1588. scale_x_y=scale_x_y),
  1589. nms=nms)
  1590. params.update({
  1591. 'backbone': backbone,
  1592. 'neck': neck,
  1593. 'yolo_head': yolo_head,
  1594. 'post_process': post_process
  1595. })
  1596. super(YOLOv3, self).__init__(
  1597. model_name='YOLOv3', num_classes=num_classes, **params)
  1598. self.anchors = anchors
  1599. self.anchor_masks = anchor_masks
  1600. self.model_name = 'PPYOLOv2'
  1601. def _get_test_inputs(self, image_shape):
  1602. if image_shape is not None:
  1603. image_shape = self._check_image_shape(image_shape)
  1604. self._fix_transforms_shape(image_shape[-2:])
  1605. else:
  1606. image_shape = [None, 3, 640, 640]
  1607. if hasattr(self, 'test_transforms'):
  1608. if self.test_transforms is not None:
  1609. for idx, op in enumerate(self.test_transforms.transforms):
  1610. name = op.__class__.__name__
  1611. if name == 'Resize':
  1612. image_shape = [None, 3] + list(
  1613. self.test_transforms.transforms[
  1614. idx].target_size)
  1615. logging.warning(
  1616. '[Important!!!] When exporting inference model for {},'.format(
  1617. self.__class__.__name__) +
  1618. ' if fixed_input_shape is not set, it will be forcibly set to {}. '.
  1619. format(image_shape) +
  1620. 'Please check image shape after transforms is {}, if not, fixed_input_shape '.
  1621. format(image_shape[1:]) + 'should be specified manually.')
  1622. self.fixed_input_shape = image_shape
  1623. return self._define_input_spec(image_shape)
  1624. class MaskRCNN(BaseDetector):
  1625. def __init__(self,
  1626. num_classes=80,
  1627. backbone='ResNet50_vd',
  1628. with_fpn=True,
  1629. with_dcn=False,
  1630. aspect_ratios=[0.5, 1.0, 2.0],
  1631. anchor_sizes=[[32], [64], [128], [256], [512]],
  1632. keep_top_k=100,
  1633. nms_threshold=0.5,
  1634. score_threshold=0.05,
  1635. fpn_num_channels=256,
  1636. rpn_batch_size_per_im=256,
  1637. rpn_fg_fraction=0.5,
  1638. test_pre_nms_top_n=None,
  1639. test_post_nms_top_n=1000,
  1640. **params):
  1641. self.init_params = locals()
  1642. if backbone not in {
  1643. 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
  1644. 'ResNet101_vd'
  1645. }:
  1646. raise ValueError(
  1647. "backbone: {} is not supported. Please choose one of "
  1648. "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', 'ResNet101_vd')".
  1649. format(backbone))
  1650. self.backbone_name = backbone + '_fpn' if with_fpn else backbone
  1651. dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
  1652. if params.get('with_net', True):
  1653. if backbone == 'ResNet50':
  1654. if with_fpn:
  1655. backbone = self._get_backbone(
  1656. 'ResNet',
  1657. norm_type='bn',
  1658. freeze_at=0,
  1659. return_idx=[0, 1, 2, 3],
  1660. num_stages=4,
  1661. dcn_v2_stages=dcn_v2_stages)
  1662. else:
  1663. if with_dcn:
  1664. logging.warning(
  1665. "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
  1666. format(backbone))
  1667. backbone = self._get_backbone(
  1668. 'ResNet',
  1669. norm_type='bn',
  1670. freeze_at=0,
  1671. return_idx=[2],
  1672. num_stages=3)
  1673. elif 'ResNet50_vd' in backbone:
  1674. if not with_fpn:
  1675. logging.warning(
  1676. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1677. format(backbone))
  1678. with_fpn = True
  1679. backbone = self._get_backbone(
  1680. 'ResNet',
  1681. variant='d',
  1682. norm_type='bn',
  1683. freeze_at=0,
  1684. return_idx=[0, 1, 2, 3],
  1685. num_stages=4,
  1686. lr_mult_list=[0.05, 0.05, 0.1, 0.15]
  1687. if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
  1688. dcn_v2_stages=dcn_v2_stages)
  1689. else:
  1690. if not with_fpn:
  1691. logging.warning(
  1692. "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
  1693. format(backbone))
  1694. with_fpn = True
  1695. backbone = self._get_backbone(
  1696. 'ResNet',
  1697. variant='d' if '_vd' in backbone else 'b',
  1698. depth=101,
  1699. norm_type='bn',
  1700. freeze_at=0,
  1701. return_idx=[0, 1, 2, 3],
  1702. num_stages=4,
  1703. dcn_v2_stages=dcn_v2_stages)
  1704. rpn_in_channel = backbone.out_shape[0].channels
  1705. if with_fpn:
  1706. neck = ppdet.modeling.FPN(
  1707. in_channels=[i.channels for i in backbone.out_shape],
  1708. out_channel=fpn_num_channels,
  1709. spatial_scales=[
  1710. 1.0 / i.stride for i in backbone.out_shape
  1711. ])
  1712. rpn_in_channel = neck.out_shape[0].channels
  1713. anchor_generator_cfg = {
  1714. 'aspect_ratios': aspect_ratios,
  1715. 'anchor_sizes': anchor_sizes,
  1716. 'strides': [4, 8, 16, 32, 64]
  1717. }
  1718. train_proposal_cfg = {
  1719. 'min_size': 0.0,
  1720. 'nms_thresh': .7,
  1721. 'pre_nms_top_n': 2000,
  1722. 'post_nms_top_n': 1000,
  1723. 'topk_after_collect': True
  1724. }
  1725. test_proposal_cfg = {
  1726. 'min_size': 0.0,
  1727. 'nms_thresh': .7,
  1728. 'pre_nms_top_n': 1000
  1729. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1730. 'post_nms_top_n': test_post_nms_top_n
  1731. }
  1732. bb_head = ppdet.modeling.TwoFCHead(
  1733. in_channel=neck.out_shape[0].channels, out_channel=1024)
  1734. bb_roi_extractor_cfg = {
  1735. 'resolution': 7,
  1736. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1737. 'sampling_ratio': 0,
  1738. 'aligned': True
  1739. }
  1740. with_pool = False
  1741. m_head = ppdet.modeling.MaskFeat(
  1742. in_channel=neck.out_shape[0].channels,
  1743. out_channel=256,
  1744. num_convs=4)
  1745. m_roi_extractor_cfg = {
  1746. 'resolution': 14,
  1747. 'spatial_scale': [1. / i.stride for i in neck.out_shape],
  1748. 'sampling_ratio': 0,
  1749. 'aligned': True
  1750. }
  1751. mask_assigner = MaskAssigner(
  1752. num_classes=num_classes, mask_resolution=28)
  1753. share_bbox_feat = False
  1754. else:
  1755. neck = None
  1756. anchor_generator_cfg = {
  1757. 'aspect_ratios': aspect_ratios,
  1758. 'anchor_sizes': anchor_sizes,
  1759. 'strides': [16]
  1760. }
  1761. train_proposal_cfg = {
  1762. 'min_size': 0.0,
  1763. 'nms_thresh': .7,
  1764. 'pre_nms_top_n': 12000,
  1765. 'post_nms_top_n': 2000,
  1766. 'topk_after_collect': False
  1767. }
  1768. test_proposal_cfg = {
  1769. 'min_size': 0.0,
  1770. 'nms_thresh': .7,
  1771. 'pre_nms_top_n': 6000
  1772. if test_pre_nms_top_n is None else test_pre_nms_top_n,
  1773. 'post_nms_top_n': test_post_nms_top_n
  1774. }
  1775. bb_head = ppdet.modeling.Res5Head()
  1776. bb_roi_extractor_cfg = {
  1777. 'resolution': 14,
  1778. 'spatial_scale':
  1779. [1. / i.stride for i in backbone.out_shape],
  1780. 'sampling_ratio': 0,
  1781. 'aligned': True
  1782. }
  1783. with_pool = True
  1784. m_head = ppdet.modeling.MaskFeat(
  1785. in_channel=bb_head.out_shape[0].channels,
  1786. out_channel=256,
  1787. num_convs=0)
  1788. m_roi_extractor_cfg = {
  1789. 'resolution': 14,
  1790. 'spatial_scale':
  1791. [1. / i.stride for i in backbone.out_shape],
  1792. 'sampling_ratio': 0,
  1793. 'aligned': True
  1794. }
  1795. mask_assigner = MaskAssigner(
  1796. num_classes=num_classes, mask_resolution=14)
  1797. share_bbox_feat = True
  1798. rpn_target_assign_cfg = {
  1799. 'batch_size_per_im': rpn_batch_size_per_im,
  1800. 'fg_fraction': rpn_fg_fraction,
  1801. 'negative_overlap': .3,
  1802. 'positive_overlap': .7,
  1803. 'use_random': True
  1804. }
  1805. rpn_head = ppdet.modeling.RPNHead(
  1806. anchor_generator=anchor_generator_cfg,
  1807. rpn_target_assign=rpn_target_assign_cfg,
  1808. train_proposal=train_proposal_cfg,
  1809. test_proposal=test_proposal_cfg,
  1810. in_channel=rpn_in_channel)
  1811. bbox_assigner = BBoxAssigner(num_classes=num_classes)
  1812. bbox_head = ppdet.modeling.BBoxHead(
  1813. head=bb_head,
  1814. in_channel=bb_head.out_shape[0].channels,
  1815. roi_extractor=bb_roi_extractor_cfg,
  1816. with_pool=with_pool,
  1817. bbox_assigner=bbox_assigner,
  1818. num_classes=num_classes)
  1819. mask_head = ppdet.modeling.MaskHead(
  1820. head=m_head,
  1821. roi_extractor=m_roi_extractor_cfg,
  1822. mask_assigner=mask_assigner,
  1823. share_bbox_feat=share_bbox_feat,
  1824. num_classes=num_classes)
  1825. bbox_post_process = ppdet.modeling.BBoxPostProcess(
  1826. num_classes=num_classes,
  1827. decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
  1828. nms=ppdet.modeling.MultiClassNMS(
  1829. score_threshold=score_threshold,
  1830. keep_top_k=keep_top_k,
  1831. nms_threshold=nms_threshold))
  1832. mask_post_process = ppdet.modeling.MaskPostProcess(
  1833. binary_thresh=.5)
  1834. params.update({
  1835. 'backbone': backbone,
  1836. 'neck': neck,
  1837. 'rpn_head': rpn_head,
  1838. 'bbox_head': bbox_head,
  1839. 'mask_head': mask_head,
  1840. 'bbox_post_process': bbox_post_process,
  1841. 'mask_post_process': mask_post_process
  1842. })
  1843. self.with_fpn = with_fpn
  1844. super(MaskRCNN, self).__init__(
  1845. model_name='MaskRCNN', num_classes=num_classes, **params)
  1846. def train(self,
  1847. num_epochs,
  1848. train_dataset,
  1849. train_batch_size=64,
  1850. eval_dataset=None,
  1851. optimizer=None,
  1852. save_interval_epochs=1,
  1853. log_interval_steps=10,
  1854. save_dir='output',
  1855. pretrain_weights='IMAGENET',
  1856. learning_rate=.001,
  1857. warmup_steps=0,
  1858. warmup_start_lr=0.0,
  1859. lr_decay_epochs=(216, 243),
  1860. lr_decay_gamma=0.1,
  1861. metric=None,
  1862. use_ema=False,
  1863. early_stop=False,
  1864. early_stop_patience=5,
  1865. use_vdl=True,
  1866. resume_checkpoint=None):
  1867. """
  1868. Train the model.
  1869. Args:
  1870. num_epochs(int): The number of epochs.
  1871. train_dataset(paddlex.dataset): Training dataset.
  1872. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  1873. eval_dataset(paddlex.dataset, optional):
  1874. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  1875. optimizer(paddle.optimizer.Optimizer or None, optional):
  1876. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  1877. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  1878. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  1879. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  1880. pretrain_weights(str or None, optional):
  1881. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  1882. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  1883. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  1884. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  1885. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  1886. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  1887. metric({'VOC', 'COCO', None}, optional):
  1888. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  1889. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  1890. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  1891. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  1892. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  1893. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  1894. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  1895. `pretrain_weights` can be set simultaneously. Defaults to None.
  1896. """
  1897. if train_dataset.pos_num < len(train_dataset.file_list):
  1898. train_dataset.num_workers = 0
  1899. if train_batch_size != 1:
  1900. train_batch_size = 1
  1901. logging.warning(
  1902. "Training RCNN models with negative samples only support batch size equals to 1 "
  1903. "on a single gpu/cpu card, `train_batch_size` is forcibly set to 1."
  1904. )
  1905. nranks = paddle.distributed.get_world_size()
  1906. local_rank = paddle.distributed.get_rank()
  1907. # single card training
  1908. if nranks < 2 or local_rank == 0:
  1909. super(MaskRCNN, self).train(
  1910. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1911. optimizer, save_interval_epochs, log_interval_steps,
  1912. save_dir, pretrain_weights, learning_rate, warmup_steps,
  1913. warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric,
  1914. use_ema, early_stop, early_stop_patience, use_vdl,
  1915. resume_checkpoint)
  1916. else:
  1917. super(MaskRCNN, self).train(
  1918. num_epochs, train_dataset, train_batch_size, eval_dataset,
  1919. optimizer, save_interval_epochs, log_interval_steps, save_dir,
  1920. pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
  1921. lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
  1922. early_stop_patience, use_vdl, resume_checkpoint)
  1923. def _compose_batch_transform(self, transforms, mode='train'):
  1924. if mode == 'train':
  1925. default_batch_transforms = [
  1926. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1927. ]
  1928. collate_batch = False
  1929. else:
  1930. default_batch_transforms = [
  1931. _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
  1932. ]
  1933. collate_batch = True
  1934. custom_batch_transforms = []
  1935. for i, op in enumerate(transforms.transforms):
  1936. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  1937. if mode != 'train':
  1938. raise Exception(
  1939. "{} cannot be present in the {} transforms. ".format(
  1940. op.__class__.__name__, mode) +
  1941. "Please check the {} transforms.".format(mode))
  1942. custom_batch_transforms.insert(0, copy.deepcopy(op))
  1943. batch_transforms = BatchCompose(
  1944. custom_batch_transforms + default_batch_transforms,
  1945. collate_batch=collate_batch)
  1946. return batch_transforms
  1947. def _fix_transforms_shape(self, image_shape):
  1948. if hasattr(self, 'test_transforms'):
  1949. if self.test_transforms is not None:
  1950. has_resize_op = False
  1951. resize_op_idx = -1
  1952. normalize_op_idx = len(self.test_transforms.transforms)
  1953. for idx, op in enumerate(self.test_transforms.transforms):
  1954. name = op.__class__.__name__
  1955. if name == 'ResizeByShort':
  1956. has_resize_op = True
  1957. resize_op_idx = idx
  1958. if name == 'Normalize':
  1959. normalize_op_idx = idx
  1960. if not has_resize_op:
  1961. self.test_transforms.transforms.insert(
  1962. normalize_op_idx,
  1963. Resize(
  1964. target_size=image_shape,
  1965. keep_ratio=True,
  1966. interp='CUBIC'))
  1967. else:
  1968. self.test_transforms.transforms[resize_op_idx] = Resize(
  1969. target_size=image_shape,
  1970. keep_ratio=True,
  1971. interp='CUBIC')
  1972. self.test_transforms.transforms.append(
  1973. Padding(im_padding_value=[0., 0., 0.]))
  1974. def _get_test_inputs(self, image_shape):
  1975. if image_shape is not None:
  1976. image_shape = self._check_image_shape(image_shape)
  1977. self._fix_transforms_shape(image_shape[-2:])
  1978. else:
  1979. image_shape = [None, 3, -1, -1]
  1980. if self.with_fpn:
  1981. self.test_transforms.transforms.append(
  1982. Padding(im_padding_value=[0., 0., 0.]))
  1983. self.fixed_input_shape = image_shape
  1984. return self._define_input_spec(image_shape)