# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import

import collections
import copy
import os
import os.path as osp

import numpy as np
import paddle
from paddle.static import InputSpec

import paddlex.ppdet as ppdet
from paddlex.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
import paddlex
import paddlex.utils.logging as logging
from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
    _BatchPadding, _Gt2YoloTarget
from paddlex.cv.transforms import arrange_transforms
from .base import BaseModel
from .utils.det_metrics import VOCMetric, COCOMetric
from paddlex.ppdet.optimizer import ModelEMA
from paddlex.utils.checkpoint import det_pretrain_weights_dict

__all__ = [
    "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN",
    "PicoDet"
]


class BaseDetector(BaseModel):
    def __init__(self, model_name, num_classes=80, **params):
        self.init_params.update(locals())
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseDetector, self).__init__('detector')
        if not hasattr(ppdet.modeling, model_name):
            raise Exception("ERROR: There's no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)

    def build_net(self, **params):
        with paddle.utils.unique_name.guard():
            net = ppdet.modeling.__dict__[self.model_name](**params)
        return net

    def _fix_transforms_shape(self, image_shape):
        raise NotImplementedError("_fix_transforms_shape: not implemented!")

    def _define_input_spec(self, image_shape):
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image', dtype='float32'),
            "im_shape": InputSpec(
                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
            "scale_factor": InputSpec(
                shape=[image_shape[0], 2],
                name='scale_factor',
                dtype='float32')
        }]
        return input_spec

    def _check_image_shape(self, image_shape):
        if len(image_shape) == 2:
            image_shape = [1, 3] + image_shape
        if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
            raise Exception(
                "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                format(image_shape[-2:]))
        return image_shape

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)

    def _get_backbone(self, backbone_name, **params):
        backbone = getattr(ppdet.modeling, backbone_name)(**params)
        return backbone

    def run(self, net, inputs, mode):
        net_out = net(inputs)
        if mode in ['train', 'eval']:
            outputs = net_out
        else:
            outputs = dict()
            for key in net_out:
                outputs[key] = net_out[key].numpy()
        return outputs
    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          warmup_steps,
                          warmup_start_lr,
                          lr_decay_epochs,
                          lr_decay_gamma,
                          num_steps_each_epoch,
                          reg_coeff=1e-04,
                          scheduler='Piecewise',
                          num_epochs=None):
        if scheduler.lower() == 'piecewise':
            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
                    0] * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= lr_decay_epochs[0] * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "Either `warmup_steps` should be less than {} or `lr_decay_epochs[0]` "
                    "should be greater than {}. Please modify 'warmup_steps' or "
                    "'lr_decay_epochs' in the train() function.".
                    format(lr_decay_epochs[0] * num_steps_each_epoch,
                           warmup_steps // num_steps_each_epoch),
                    exit=True)
            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
            values = [(lr_decay_gamma**i) * learning_rate
                      for i in range(len(lr_decay_epochs) + 1)]
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
        elif scheduler.lower() == 'cosine':
            if num_epochs is None:
                logging.error(
                    "`num_epochs` must be set when using the cosine annealing decay "
                    "scheduler, but received {}.".format(num_epochs),
                    exit=False)
            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= num_epochs * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "`warmup_steps` must be less than the total number of steps ({}). "
                    "Please modify 'num_epochs' or 'warmup_steps' in the train() function.".
                    format(num_epochs * num_steps_each_epoch),
                    exit=True)
            T_max = num_epochs * num_steps_each_epoch - warmup_steps
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=learning_rate,
                T_max=T_max,
                eta_min=0.0,
                last_epoch=-1)
        else:
            logging.error(
                "Invalid learning rate scheduler: {}!".format(scheduler),
                exit=True)
        if warmup_steps > 0:
            scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=scheduler,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)
        optimizer = paddle.optimizer.Momentum(
            scheduler,
            momentum=.9,
            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
            parameters=parameters)
        return optimizer
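
    # --- Scheduler sketch (comments only; not executed) ----------------------
    # A minimal sketch of what default_optimizer() builds for the 'Piecewise'
    # scheduler with warmup. The concrete numbers below are hypothetical
    # examples, not defaults taken from this module:
    #
    #   import paddle
    #   steps_per_epoch = 100
    #   lr_decay_epochs = (216, 243)
    #   boundaries = [e * steps_per_epoch for e in lr_decay_epochs]  # [21600, 24300]
    #   values = [0.001 * (0.1 ** i) for i in range(3)]              # [1e-3, 1e-4, 1e-5]
    #   sched = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
    #   sched = paddle.optimizer.lr.LinearWarmup(
    #       learning_rate=sched, warmup_steps=500, start_lr=0.0, end_lr=0.001)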

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.
        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights(str or None, optional):
                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
                Defaults to 'IMAGENET'.
            learning_rate(float, optional): Learning rate for training. Defaults to .001.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            metric({'VOC', 'COCO', None}, optional):
                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
            use_ema(bool, optional): Whether to use the exponential moving average strategy. Defaults to False.
            early_stop(bool, optional): Whether to adopt the early stop strategy. Defaults to False.
            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                `pretrain_weights` can be set simultaneously. Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        if train_dataset.__class__.__name__ == 'VOCDetection':
            train_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif train_dataset.__class__.__name__ == 'CocoDetection':
            if self.__class__.__name__ == 'MaskRCNN':
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                train_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }
        if metric is None:
            if eval_dataset.__class__.__name__ == 'VOCDetection':
                self.metric = 'voc'
            elif eval_dataset.__class__.__name__ == 'CocoDetection':
                self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()
        self.labels = train_dataset.labels
        self.num_max_boxes = train_dataset.num_max_boxes
        train_dataset.batch_transforms = self._compose_batch_transform(
            train_dataset.transforms, mode='train')

        # Build the default optimizer if none is provided.
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            self.optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch)
        else:
            self.optimizer = optimizer

        # Initialize network weights.
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])]:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                pretrain_weights = det_pretrain_weights_dict['_'.join(
                    [self.model_name, self.backbone_name])][0]
                logging.warning("Pretrain_weights is forcibly set to '{}'. "
                                "If you don't want to use pretrain weights, "
                                "set pretrain_weights to be None.".format(
                                    pretrain_weights))
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrain weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=(pretrain_weights == 'IMAGENET' and
                                 'ESNet_' in self.backbone_name))
        if use_ema:
            ema = ModelEMA(model=self.net, decay=.9998, use_thres_step=True)
        else:
            ema = None

        # Start the train loop.
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            ema=ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
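
    # --- Usage sketch (comments only; not executed) ---------------------------
    # A minimal, hedged example of driving train() through a concrete subclass.
    # The dataset paths ('insects/...') are hypothetical placeholders:
    #
    #   import paddlex as pdx
    #   from paddlex import transforms as T
    #
    #   train_transforms = T.Compose([
    #       T.RandomHorizontalFlip(), T.Resize(target_size=608),
    #       T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    #   ])
    #   train_dataset = pdx.datasets.VOCDetection(
    #       data_dir='insects', file_list='insects/train_list.txt',
    #       label_list='insects/labels.txt', transforms=train_transforms)
    #   model = pdx.det.YOLOv3(num_classes=len(train_dataset.labels))
    #   model.train(num_epochs=270, train_dataset=train_dataset,
    #               train_batch_size=8, learning_rate=0.001, save_dir='output')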

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=64,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=10,
                          save_dir='output',
                          learning_rate=.00001,
                          warmup_steps=0,
                          warmup_start_lr=0.0,
                          lr_decay_epochs=(216, 243),
                          lr_decay_gamma=0.1,
                          metric=None,
                          use_ema=False,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.
        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate(float, optional): Learning rate for training. Defaults to .00001.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            metric({'VOC', 'COCO', None}, optional):
                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
            use_ema(bool, optional): Whether to use the exponential moving average strategy. Defaults to False.
            early_stop(bool, optional): Whether to adopt the early stop strategy. Defaults to False.
            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware
                training from. If None, no training checkpoint will be resumed. Defaults to None.
            quant_config(dict or None, optional): Quantization configuration. If None, a default rule-of-thumb
                configuration will be used. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
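
    # --- Usage sketch (comments only; not executed) ---------------------------
    # quant_aware_train() wraps train(): it first calls self._prepare_qat() to
    # insert fake-quantization ops into the network, then fine-tunes with
    # pretrain_weights=None. A hedged example, assuming `model` and
    # `train_dataset` were built as in the train() sketch above:
    #
    #   model.quant_aware_train(
    #       num_epochs=30, train_dataset=train_dataset, train_batch_size=8,
    #       learning_rate=1e-5, save_dir='output_quant')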

    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 metric=None,
                 return_details=False):
        """
        Evaluate the model.
        Args:
            eval_dataset(paddlex.dataset): Evaluation dataset.
            batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
            metric({'VOC', 'COCO', None}, optional):
                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
            return_details(bool, optional): Whether to return evaluation details. Defaults to False.
        Returns:
            collections.OrderedDict with key-value pairs: {"mAP(0.50, 11point)": `mean average precision`}.
        """
        if metric is None:
            if not hasattr(self, 'metric'):
                if eval_dataset.__class__.__name__ == 'VOCDetection':
                    self.metric = 'voc'
                elif eval_dataset.__class__.__name__ == 'CocoDetection':
                    self.metric = 'coco'
        else:
            assert metric.lower() in ['coco', 'voc'], \
                "Evaluation metric {} is not supported. Please choose from 'COCO' and 'VOC'.".format(metric)
            self.metric = metric.lower()
        if self.metric == 'voc':
            eval_dataset.data_fields = {
                'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                'difficult'
            }
        elif self.metric == 'coco':
            if self.__class__.__name__ == 'MaskRCNN':
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'gt_poly', 'is_crowd'
                }
            else:
                eval_dataset.data_fields = {
                    'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
                    'is_crowd'
                }
        eval_dataset.batch_transforms = self._compose_batch_transform(
            eval_dataset.transforms, mode='eval')
        arrange_transforms(
            model_type=self.model_type,
            transforms=eval_dataset.transforms,
            mode='eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been done yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()
        if batch_size > 1:
            logging.warning(
                "Detector only supports single-card evaluation with batch_size=1, "
                "so batch_size is forcibly set to 1.")
            batch_size = 1
        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset, batch_size=batch_size, mode='eval')
            is_bbox_normalized = False
            if eval_dataset.batch_transforms is not None:
                is_bbox_normalized = any(
                    isinstance(t, _NormalizeBox)
                    for t in eval_dataset.batch_transforms.batch_transforms)
            if self.metric == 'voc':
                eval_metric = VOCMetric(
                    labels=eval_dataset.labels,
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    is_bbox_normalized=is_bbox_normalized,
                    classwise=False)
            else:
                eval_metric = COCOMetric(
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    classwise=False)
            scores = collections.OrderedDict()
            logging.info(
                "Start to evaluate(total_samples={}, total_steps={})...".
                format(eval_dataset.num_samples, eval_dataset.num_samples))
            with paddle.no_grad():
                for step, data in enumerate(self.eval_data_loader):
                    outputs = self.run(self.net, data, 'eval')
                    eval_metric.update(data, outputs)
                eval_metric.accumulate()
                self.eval_details = eval_metric.details
                scores.update(eval_metric.get())
                eval_metric.reset()
            if return_details:
                return scores, self.eval_details
            return scores
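
    # --- Usage sketch (comments only; not executed) ---------------------------
    # Evaluating a trained model on a held-out set; paths are hypothetical
    # placeholders. With a VOC-format dataset the returned OrderedDict contains
    # "mAP(0.50, 11point)"; with a COCO-format dataset it contains COCO-style APs.
    #
    #   eval_dataset = pdx.datasets.VOCDetection(
    #       data_dir='insects', file_list='insects/val_list.txt',
    #       label_list='insects/labels.txt', transforms=eval_transforms)
    #   scores = model.evaluate(eval_dataset, batch_size=1)
    #   print(scores)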

    def predict(self, img_file, transforms=None):
        """
        Do inference.
        Args:
            img_file(List[np.ndarray or str], str or np.ndarray):
                Image path or decoded image data in BGR format. A list of paths or
                arrays is treated as a mini-batch of images to be predicted.
            transforms(paddlex.transforms.Compose or None, optional):
                Transforms for inputs. If None, the transforms for the evaluation process
                will be used. Defaults to None.
        Returns:
            If img_file is a string or np.ndarray, the result is a list of dicts with key-value pairs:
            {"category_id": `category_id`, "category": `category`, "bbox": `[x, y, w, h]`, "score": `score`}.
            If img_file is a list, the result is a list of such lists, one per image. Each dict has the fields:
                category_id(int): the predicted category ID. 0 represents the first category in the
                    dataset, and so on.
                category(str): category name.
                bbox(list): bounding box in [x, y, w, h] format.
                score(float): confidence.
                mask(dict): Only for the instance segmentation task. Mask of the object in RLE format.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise Exception("transforms must be defined, but received None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        batch_samples = self._preprocess(images, transforms)
        self.net.eval()
        outputs = self.run(self.net, batch_samples, 'test')
        prediction = self._postprocess(outputs)
        if isinstance(img_file, (str, np.ndarray)):
            prediction = prediction[0]
        return prediction
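
    # --- Usage sketch (comments only; not executed) ---------------------------
    # Running inference with a loaded model; 'test.jpg' is a placeholder path.
    # paddlex.load_model() restores test_transforms, so transforms can be omitted.
    #
    #   import paddlex as pdx
    #   model = pdx.load_model('output/best_model')
    #   result = model.predict('test.jpg')
    #   for det in result:
    #       print(det['category'], det['score'], det['bbox'])  # bbox = [x, y, w, h]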

    def _preprocess(self, images, transforms, to_tensor=True):
        arrange_transforms(
            model_type=self.model_type, transforms=transforms, mode='test')
        batch_samples = list()
        for im in images:
            sample = {'image': im}
            batch_samples.append(transforms(sample))
        batch_transforms = self._compose_batch_transform(transforms, 'test')
        batch_samples = batch_transforms(batch_samples)
        if to_tensor:
            for k in batch_samples:
                batch_samples[k] = paddle.to_tensor(batch_samples[k])
        return batch_samples

    def _postprocess(self, batch_pred):
        infer_result = {}
        if 'bbox' in batch_pred:
            # `bbox` holds all detections across the batch as rows of
            # [class_id, score, xmin, ymin, xmax, ymax]; `bbox_num[i]` tells
            # how many of them belong to image i.
            bboxes = batch_pred['bbox']
            bbox_nums = batch_pred['bbox_num']
            det_res = []
            k = 0
            for i in range(len(bbox_nums)):
                det_nums = bbox_nums[i]
                for j in range(det_nums):
                    dt = bboxes[k]
                    k = k + 1
                    num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
                    if int(num_id) < 0:
                        continue
                    category = self.labels[int(num_id)]
                    # Convert [xmin, ymin, xmax, ymax] to [x, y, w, h].
                    w = xmax - xmin
                    h = ymax - ymin
                    bbox = [xmin, ymin, w, h]
                    dt_res = {
                        'category_id': int(num_id),
                        'category': category,
                        'bbox': bbox,
                        'score': score
                    }
                    det_res.append(dt_res)
            infer_result['bbox'] = det_res
        if 'mask' in batch_pred:
            masks = batch_pred['mask']
            bboxes = batch_pred['bbox']
            mask_nums = batch_pred['bbox_num']
            seg_res = []
            k = 0
            for i in range(len(mask_nums)):
                det_nums = mask_nums[i]
                for j in range(det_nums):
                    mask = masks[k].astype(np.uint8)
                    score = float(bboxes[k][1])
                    label = int(bboxes[k][0])
                    k = k + 1
                    if label == -1:
                        continue
                    category = self.labels[int(label)]
                    sg_res = {
                        'category_id': int(label),
                        'category': category,
                        'mask': mask.astype('uint8'),
                        'score': score
                    }
                    seg_res.append(sg_res)
            infer_result['mask'] = seg_res
        # Split the flat detection list back into per-image result lists.
        bbox_num = batch_pred['bbox_num']
        results = []
        start = 0
        for num in bbox_num:
            end = start + num
            curr_res = infer_result['bbox'][start:end]
            if 'mask' in infer_result:
                mask_res = infer_result['mask'][start:end]
                for box, mask in zip(curr_res, mask_res):
                    box.update(mask)
            results.append(curr_res)
            start = end
        return results
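
    # --- Worked example (comments only) ---------------------------------------
    # How _postprocess() regroups flat predictions, assuming a batch of two
    # images where the network returned bbox_num = [2, 1]: detections 0-1
    # belong to image 0 and detection 2 to image 1, so the method returns
    # [[det0, det1], [det2]], each det converted to
    # {'category_id', 'category', 'bbox': [x, y, w, h], 'score'}.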


class PicoDet(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ESNet_m',
                 nms_score_threshold=.025,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=.6,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3',
                'ResNet18_vd'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet', 'MobileNetV3', 'ResNet18_vd')".
                format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            if backbone == 'ESNet_s':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=.75,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5,
                        0.5, 0.5, 0.5
                    ])
                neck_out_channels = 96
                head_num_convs = 2
            elif backbone == 'ESNet_m':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=1.0,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
                        0.625, 1.0, 0.625, 0.75
                    ])
                neck_out_channels = 128
                head_num_convs = 4
            elif backbone == 'ESNet_l':
                backbone = self._get_backbone(
                    'ESNet',
                    scale=1.25,
                    feature_maps=[4, 11, 14],
                    act="hard_swish",
                    channel_ratio=[
                        0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5,
                        0.625, 1.0, 0.625, 0.75
                    ])
                neck_out_channels = 160
                head_num_convs = 4
            elif backbone == 'LCNet':
                backbone = self._get_backbone(
                    'LCNet', scale=1.5, feature_maps=[3, 4, 5])
                neck_out_channels = 128
                head_num_convs = 4
            elif backbone == 'MobileNetV3':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    scale=1.0,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[7, 13, 16])
                neck_out_channels = 128
                head_num_convs = 4
            else:
                backbone = self._get_backbone(
                    'ResNet',
                    depth=18,
                    variant='d',
                    return_idx=[1, 2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
                neck_out_channels = 128
                head_num_convs = 4
            neck = ppdet.modeling.CSPPAN(
                in_channels=[i.channels for i in backbone.out_shape],
                out_channels=neck_out_channels,
                num_features=4,
                num_csp_blocks=1,
                use_depthwise=True)
            head_conv_feat = ppdet.modeling.PicoFeat(
                feat_in=neck_out_channels,
                feat_out=neck_out_channels,
                num_fpn_stride=4,
                num_convs=head_num_convs,
                norm_type='bn',
                share_cls_reg=True)
            loss_class = ppdet.modeling.VarifocalLoss(
                use_sigmoid=True, iou_weighted=True, loss_weight=1.0)
            loss_dfl = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
            loss_bbox = ppdet.modeling.GIoULoss(loss_weight=2.0)
            assigner = ppdet.modeling.SimOTAAssigner(
                candidate_topk=10, iou_weight=6, num_classes=num_classes)
            nms = ppdet.modeling.MultiClassNMS(
                nms_top_k=nms_topk,
                keep_top_k=nms_keep_topk,
                score_threshold=nms_score_threshold,
                nms_threshold=nms_iou_threshold)
            head = ppdet.modeling.PicoHead(
                conv_feat=head_conv_feat,
                num_classes=num_classes,
                fpn_stride=[8, 16, 32, 64],
                prior_prob=0.01,
                reg_max=7,
                cell_offset=.5,
                loss_class=loss_class,
                loss_dfl=loss_dfl,
                loss_bbox=loss_bbox,
                assigner=assigner,
                feat_in_chan=neck_out_channels,
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'head': head,
            })
        super(PicoDet, self).__init__(
            model_name='PicoDet', num_classes=num_classes, **params)
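
    # --- Usage sketch (comments only; not executed) ---------------------------
    # Constructing a PicoDet variant; the backbone string selects the scale and
    # the neck/head widths resolved above. A hedged example:
    #
    #   import paddlex as pdx
    #   model = pdx.det.PicoDet(
    #       num_classes=20, backbone='ESNet_m',
    #       nms_score_threshold=0.025, nms_iou_threshold=0.6)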

    def _compose_batch_transform(self, transforms, mode='train'):
        default_batch_transforms = [_BatchPadding(pad_to_stride=32)]
        if mode == 'eval':
            collate_batch = True
        else:
            collate_batch = False
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'Resize':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape, interp='CUBIC'))
            else:
                self.test_transforms.transforms[
                    resize_op_idx].target_size = image_shape

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 320, 320]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to {}. '
                'Please ensure the image shape after transforms is {}; if not, '
                'fixed_input_shape should be specified manually.'
                .format(self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.
        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights(str or None, optional):
                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
                Defaults to 'IMAGENET'.
            learning_rate(float, optional): Learning rate for training. Defaults to .001.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            metric({'VOC', 'COCO', None}, optional):
                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
            use_ema(bool, optional): Whether to use the exponential moving average strategy. Defaults to False.
            early_stop(bool, optional): Whether to adopt the early stop strategy. Defaults to False.
            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                `pretrain_weights` can be set simultaneously. Defaults to None.
        """
        # PicoDet defaults to a cosine learning rate schedule with a smaller
        # weight decay than the base detector.
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch,
                reg_coeff=4e-05,
                scheduler='Cosine',
                num_epochs=num_epochs)
        super(PicoDet, self).train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=pretrain_weights,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)


class YOLOv3(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV1',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 ignore_threshold=0.7,
                 nms_score_threshold=0.01,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 label_smooth=False,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3',
                'MobileNetV3_ssld', 'DarkNet53', 'ResNet50_vd_dcn', 'ResNet34'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('MobileNetV1', 'MobileNetV1_ssld', 'MobileNetV3', 'MobileNetV3_ssld', 'DarkNet53', "
                "'ResNet50_vd_dcn', 'ResNet34')".format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            # Use synchronized BN for multi-GPU training (except at export stage).
            if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                    'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            if 'MobileNetV1' in backbone:
                norm_type = 'bn'
                backbone = self._get_backbone('MobileNet', norm_type=norm_type)
            elif 'MobileNetV3' in backbone:
                backbone = self._get_backbone(
                    'MobileNetV3',
                    norm_type=norm_type,
                    feature_maps=[7, 13, 16])
            elif backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    norm_type=norm_type,
                    variant='d',
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False)
            elif backbone == 'ResNet34':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=34,
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            else:
                backbone = self._get_backbone('DarkNet', norm_type=norm_type)
            neck = ppdet.modeling.YOLOv3FPN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape])
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                label_smooth=label_smooth)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold))
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
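
    # --- Usage sketch (comments only; not executed) ---------------------------
    # The nine default anchors are grouped by anchor_masks onto the three FPN
    # levels (downsample ratios 32/16/8). A hedged example with custom anchors,
    # e.g. obtained from k-means clustering on the training boxes:
    #
    #   import paddlex as pdx
    #   model = pdx.det.YOLOv3(
    #       num_classes=20, backbone='DarkNet53',
    #       anchors=[[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
    #                [72, 146], [142, 110], [192, 243], [459, 401]],
    #       anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]])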

    def _compose_batch_transform(self, transforms, mode='train'):
        if mode == 'train':
            default_batch_transforms = [
                _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
                _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
                _Gt2YoloTarget(
                    anchor_masks=self.anchor_masks,
                    anchors=self.anchors,
                    downsample_ratios=getattr(self, 'downsample_ratios',
                                              [32, 16, 8]),
                    num_classes=self.num_classes)
            ]
        else:
            default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
        if mode == 'eval' and self.metric == 'voc':
            collate_batch = False
        else:
            collate_batch = True
        custom_batch_transforms = []
        for i, op in enumerate(transforms.transforms):
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=collate_batch)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'Resize':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape, interp='CUBIC'))
            else:
                self.test_transforms.transforms[
                    resize_op_idx].target_size = image_shape


class FasterRCNN(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50',
                 with_fpn=True,
                 with_dcn=False,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 anchor_sizes=[[32], [64], [128], [256], [512]],
                 keep_top_k=100,
                 nms_threshold=0.5,
                 score_threshold=0.05,
                 fpn_num_channels=256,
                 rpn_batch_size_per_im=256,
                 rpn_fg_fraction=0.5,
                 test_pre_nms_top_n=None,
                 test_post_nms_top_n=1000,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
                'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
        self.backbone_name = backbone
        if params.get('with_net', True):
            dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
            if backbone == 'HRNet_W18':
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; 'with_fpn' is forcibly set to True.".
                        format(backbone))
                    with_fpn = True
                if with_dcn:
                    logging.warning(
                        "Backbone {} should be used with DCN disabled; 'with_dcn' is forcibly set to False.".
                        format(backbone))
                backbone = self._get_backbone(
                    'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
            elif backbone == 'ResNet50_vd_ssld':
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; 'with_fpn' is forcibly set to True.".
                        format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    lr_mult_list=[0.05, 0.05, 0.1, 0.15],
                    dcn_v2_stages=dcn_v2_stages)
            elif 'ResNet50' in backbone:
                if with_fpn:
                    backbone = self._get_backbone(
                        'ResNet',
                        variant='d' if '_vd' in backbone else 'b',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[0, 1, 2, 3],
                        num_stages=4,
                        dcn_v2_stages=dcn_v2_stages)
                else:
                    if with_dcn:
                        logging.warning(
                            "Backbone {} without FPN should be used with DCN disabled; 'with_dcn' is forcibly set to False.".
                            format(backbone))
                    backbone = self._get_backbone(
                        'ResNet',
                        variant='d' if '_vd' in backbone else 'b',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[2],
                        num_stages=3)
            elif 'ResNet34' in backbone:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; 'with_fpn' is forcibly set to True.".
                        format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    depth=34,
                    variant='d' if 'vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; 'with_fpn' is forcibly set to True.".
                        format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    depth=101,
                    variant='d' if 'vd' in backbone else 'b',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            rpn_in_channel = backbone.out_shape[0].channels
            if with_fpn:
                self.backbone_name = self.backbone_name + '_fpn'
                if 'HRNet' in self.backbone_name:
                    neck = ppdet.modeling.HRFPN(
                        in_channels=[i.channels for i in backbone.out_shape],
                        out_channel=fpn_num_channels,
                        spatial_scales=[
                            1.0 / i.stride for i in backbone.out_shape
                        ],
                        share_conv=False)
                else:
                    neck = ppdet.modeling.FPN(
                        in_channels=[i.channels for i in backbone.out_shape],
                        out_channel=fpn_num_channels,
                        spatial_scales=[
                            1.0 / i.stride for i in backbone.out_shape
                        ])
                rpn_in_channel = neck.out_shape[0].channels
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [4, 8, 16, 32, 64]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 2000,
                    'post_nms_top_n': 1000,
                    'topk_after_collect': True
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 1000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                head = ppdet.modeling.TwoFCHead(
                    in_channel=neck.out_shape[0].channels, out_channel=1024)
                roi_extractor_cfg = {
                    'resolution': 7,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = False
            else:
                neck = None
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [16]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 12000,
                    'post_nms_top_n': 2000,
                    'topk_after_collect': False
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 6000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                head = ppdet.modeling.Res5Head()
                roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = True
            rpn_target_assign_cfg = {
                'batch_size_per_im': rpn_batch_size_per_im,
                'fg_fraction': rpn_fg_fraction,
                'negative_overlap': .3,
                'positive_overlap': .7,
                'use_random': True
            }
            rpn_head = ppdet.modeling.RPNHead(
                anchor_generator=anchor_generator_cfg,
                rpn_target_assign=rpn_target_assign_cfg,
                train_proposal=train_proposal_cfg,
                test_proposal=test_proposal_cfg,
                in_channel=rpn_in_channel)
            bbox_assigner = BBoxAssigner(num_classes=num_classes)
            bbox_head = ppdet.modeling.BBoxHead(
                head=head,
                in_channel=head.out_shape[0].channels,
                roi_extractor=roi_extractor_cfg,
                with_pool=with_pool,
                bbox_assigner=bbox_assigner,
                num_classes=num_classes)
            bbox_post_process = ppdet.modeling.BBoxPostProcess(
                num_classes=num_classes,
                decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=score_threshold,
                    keep_top_k=keep_top_k,
                    nms_threshold=nms_threshold))
            params.update({
                'backbone': backbone,
                'neck': neck,
                'rpn_head': rpn_head,
                'bbox_head': bbox_head,
                'bbox_post_process': bbox_post_process
            })
        else:
            if backbone not in {'ResNet50', 'ResNet50_vd'}:
                with_fpn = True
        self.with_fpn = with_fpn
        super(FasterRCNN, self).__init__(
            model_name='FasterRCNN', num_classes=num_classes, **params)
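
    # --- Usage sketch (comments only; not executed) ---------------------------
    # Constructing a FasterRCNN; for backbones other than ResNet50 and
    # ResNet50_vd, with_fpn is forcibly enabled above. A hedged example:
    #
    #   import paddlex as pdx
    #   model = pdx.det.FasterRCNN(
    #       num_classes=20, backbone='ResNet50_vd_ssld', with_fpn=True,
    #       test_pre_nms_top_n=None, test_post_nms_top_n=1000)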

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
  1244. """
  1245. Train the model.
  1246. Args:
  1247. num_epochs(int): The number of epochs.
  1248. train_dataset(paddlex.dataset): Training dataset.
  1249. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  1250. eval_dataset(paddlex.dataset, optional):
  1251. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  1252. optimizer(paddle.optimizer.Optimizer or None, optional):
  1253. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  1254. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  1255. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  1256. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  1257. pretrain_weights(str or None, optional):
  1258. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  1259. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  1260. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  1261. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  1262. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  1263. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  1264. metric({'VOC', 'COCO', None}, optional):
  1265. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  1266. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  1267. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  1268. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  1269. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  1270. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  1271. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  1272. `pretrain_weights` can be set simultaneously. Defaults to None.
  1273. """
        if train_dataset.pos_num < len(train_dataset.file_list):
            # The dataset contains negative (background-only) samples, so
            # multi-process data loading is disabled.
            train_dataset.num_workers = 0
        super(FasterRCNN, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            optimizer, save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
            lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
            early_stop_patience, use_vdl, resume_checkpoint)
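
    # Hedged training sketch (the dataset constructor and paths below are
    # illustrative assumptions; consult the PaddleX dataset docs for the
    # exact API):
    #
    #   train_ds = pdx.datasets.VOCDetection(
    #       data_dir='dataset', file_list='dataset/train_list.txt',
    #       label_list='dataset/labels.txt', transforms=train_transforms)
    #   model.train(num_epochs=12, train_dataset=train_ds, train_batch_size=2,
    #               learning_rate=.0025, lr_decay_epochs=(8, 11),
    #               save_dir='output/faster_rcnn')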

    def _compose_batch_transform(self, transforms, mode='train'):
        # The same batch padding applies in both training and inference;
        # padding to a stride of 32 is only needed when an FPN is attached.
        default_batch_transforms = [
            _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
        ]
        custom_batch_transforms = []
        for op in transforms.transforms:
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=False)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'ResizeByShort':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC'))
            else:
                self.test_transforms.transforms[resize_op_idx] = Resize(
                    target_size=image_shape, keep_ratio=True, interp='CUBIC')
            self.test_transforms.transforms.append(
                Padding(im_padding_value=[0., 0., 0.]))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
            if self.with_fpn:
                self.test_transforms.transforms.append(
                    Padding(im_padding_value=[0., 0., 0.]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)


class PPYOLO(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=None,
                 anchor_masks=None,
                 use_coord_conv=True,
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large',
                'MobileNetV3_small'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50_vd_dcn', 'ResNet18_vd', 'MobileNetV3_large', "
                "'MobileNetV3_small')".format(backbone))
        self.backbone_name = backbone
        self.downsample_ratios = [
            32, 16, 8
        ] if backbone == 'ResNet50_vd_dcn' else [32, 16]
        if params.get('with_net', True):
            if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                    'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            if anchors is None and anchor_masks is None:
                if 'MobileNetV3' in backbone:
                    anchors = [[11, 18], [34, 47], [51, 126], [115, 71],
                               [120, 195], [254, 235]]
                    anchor_masks = [[3, 4, 5], [0, 1, 2]]
                elif backbone == 'ResNet50_vd_dcn':
                    anchors = [[10, 13], [16, 30], [33, 23], [30, 61],
                               [62, 45], [59, 119], [116, 90], [156, 198],
                               [373, 326]]
                    anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
                else:
                    anchors = [[10, 14], [23, 27], [37, 58], [81, 82],
                               [135, 169], [344, 319]]
                    anchor_masks = [[3, 4, 5], [0, 1, 2]]
            elif anchors is None or anchor_masks is None:
                raise ValueError(
                    "Please define both anchors and anchor_masks.")
            if backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'ResNet18_vd':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=18,
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[2, 3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'MobileNetV3_large':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    model_name='large',
                    norm_type=norm_type,
                    scale=1,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[13, 16])
            elif backbone == 'MobileNetV3_small':
                backbone = self._get_backbone(
                    'MobileNetV3',
                    model_name='small',
                    norm_type=norm_type,
                    scale=1,
                    with_extra_blocks=False,
                    extra_block_filters=[],
                    feature_maps=[9, 12])
            neck = ppdet.modeling.PPYOLOFPN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape],
                coord_conv=use_coord_conv,
                drop_block=use_drop_block,
                spp=use_spp,
                conv_block_num=0
                if ('MobileNetV3' in self.backbone_name or
                    self.backbone_name == 'ResNet18_vd') else 2)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.05
                    if 'MobileNetV3' in self.backbone_name else .01,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.005
                    if 'MobileNetV3' in self.backbone_name else .01,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLO'
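
    # Construction sketch (illustrative). `anchors` and `anchor_masks` must
    # be given together or omitted together; otherwise ValueError is raised
    # above:
    #
    #   model = PPYOLO(num_classes=20, backbone='MobileNetV3_large')
    #   model = PPYOLO(num_classes=20,
    #                  backbone='MobileNetV3_large',
    #                  anchors=[[10, 14], [23, 27], [37, 58], [81, 82],
    #                           [135, 169], [344, 319]],
    #                  anchor_masks=[[3, 4, 5], [0, 1, 2]])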

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 608, 608]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to {}. '
                'Please ensure image shape after transforms is {}; if not, '
                'fixed_input_shape should be specified manually.'.format(
                    self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)


class PPYOLOTiny(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='MobileNetV3',
                 anchors=[[10, 15], [24, 36], [72, 42], [35, 87], [102, 96],
                          [60, 170], [220, 125], [128, 222], [264, 266]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=False,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.5,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=False,
                 nms_score_threshold=0.005,
                 nms_topk=1000,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone != 'MobileNetV3':
            logging.warning(
                "PPYOLOTiny only supports MobileNetV3 as backbone. "
                "Backbone is forcibly set to MobileNetV3.")
        self.backbone_name = 'MobileNetV3'
        self.downsample_ratios = [32, 16, 8]
        if params.get('with_net', True):
            if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                    'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            backbone = self._get_backbone(
                'MobileNetV3',
                model_name='large',
                norm_type=norm_type,
                scale=.5,
                with_extra_blocks=False,
                extra_block_filters=[],
                feature_maps=[7, 13, 16])
            neck = ppdet.modeling.PPYOLOTinyFPN(
                detection_block_channels=[160, 128, 96],
                in_channels=[i.channels for i in backbone.out_shape],
                spp=use_spp,
                drop_block=use_drop_block)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.05,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.005,
                    downsample_ratio=32,
                    clip_bbox=True,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLOTiny'
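
    # Illustrative note: the backbone is pinned to MobileNetV3 (large, scale
    # 0.5) and export defaults to a 320x320 input; a minimal construction
    # sketch:
    #
    #   model = PPYOLOTiny(num_classes=20)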

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 320, 320]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to {}. '
                'Please ensure image shape after transforms is {}; if not, '
                'fixed_input_shape should be specified manually.'.format(
                    self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)


class PPYOLOv2(YOLOv3):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd_dcn',
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 use_iou_aware=True,
                 use_spp=True,
                 use_drop_block=True,
                 scale_x_y=1.05,
                 ignore_threshold=0.7,
                 label_smooth=False,
                 use_iou_loss=True,
                 use_matrix_nms=True,
                 nms_score_threshold=0.01,
                 nms_topk=-1,
                 nms_keep_topk=100,
                 nms_iou_threshold=0.45,
                 **params):
        self.init_params = locals()
        if backbone not in {'ResNet50_vd_dcn', 'ResNet101_vd_dcn'}:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50_vd_dcn', 'ResNet101_vd_dcn')".format(backbone))
        self.backbone_name = backbone
        self.downsample_ratios = [32, 16, 8]
        if params.get('with_net', True):
            if paddlex.env_info['place'] == 'gpu' and paddlex.env_info[
                    'num'] > 1 and not os.environ.get('PADDLEX_EXPORT_STAGE'):
                norm_type = 'sync_bn'
            else:
                norm_type = 'bn'
            if backbone == 'ResNet50_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            elif backbone == 'ResNet101_vd_dcn':
                backbone = self._get_backbone(
                    'ResNet',
                    depth=101,
                    variant='d',
                    norm_type=norm_type,
                    return_idx=[1, 2, 3],
                    dcn_v2_stages=[3],
                    freeze_at=-1,
                    freeze_norm=False,
                    norm_decay=0.)
            neck = ppdet.modeling.PPYOLOPAN(
                norm_type=norm_type,
                in_channels=[i.channels for i in backbone.out_shape],
                drop_block=use_drop_block,
                block_size=3,
                keep_prob=.9,
                spp=use_spp)
            loss = ppdet.modeling.YOLOv3Loss(
                num_classes=num_classes,
                ignore_thresh=ignore_threshold,
                downsample=self.downsample_ratios,
                label_smooth=label_smooth,
                scale_x_y=scale_x_y,
                iou_loss=ppdet.modeling.IouLoss(
                    loss_weight=2.5, loss_square=True)
                if use_iou_loss else None,
                iou_aware_loss=ppdet.modeling.IouAwareLoss(loss_weight=1.0)
                if use_iou_aware else None)
            yolo_head = ppdet.modeling.YOLOv3Head(
                in_channels=[i.channels for i in neck.out_shape],
                anchors=anchors,
                anchor_masks=anchor_masks,
                num_classes=num_classes,
                loss=loss,
                iou_aware=use_iou_aware,
                iou_aware_factor=.5)
            if use_matrix_nms:
                nms = ppdet.modeling.MatrixNMS(
                    keep_top_k=nms_keep_topk,
                    score_threshold=nms_score_threshold,
                    post_threshold=.01,
                    nms_top_k=nms_topk,
                    background_label=-1)
            else:
                nms = ppdet.modeling.MultiClassNMS(
                    score_threshold=nms_score_threshold,
                    nms_top_k=nms_topk,
                    keep_top_k=nms_keep_topk,
                    nms_threshold=nms_iou_threshold)
            post_process = ppdet.modeling.BBoxPostProcess(
                decode=ppdet.modeling.YOLOBox(
                    num_classes=num_classes,
                    conf_thresh=.01,
                    downsample_ratio=32,
                    clip_bbox=True,
                    scale_x_y=scale_x_y),
                nms=nms)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'yolo_head': yolo_head,
                'post_process': post_process
            })
        super(YOLOv3, self).__init__(
            model_name='YOLOv3', num_classes=num_classes, **params)
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.model_name = 'PPYOLOv2'
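
    # Construction sketch (illustrative): both supported backbones enable
    # DCNv2 in the last stage; ResNet101_vd_dcn trades speed for accuracy:
    #
    #   model = PPYOLOv2(num_classes=20, backbone='ResNet50_vd_dcn')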

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, 640, 640]
            if getattr(self, 'test_transforms', None):
                for idx, op in enumerate(self.test_transforms.transforms):
                    name = op.__class__.__name__
                    if name == 'Resize':
                        image_shape = [None, 3] + list(
                            self.test_transforms.transforms[idx].target_size)
            logging.warning(
                '[Important!!!] When exporting inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to {}. '
                'Please ensure image shape after transforms is {}; if not, '
                'fixed_input_shape should be specified manually.'.format(
                    self.__class__.__name__, image_shape, image_shape[1:]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)


class MaskRCNN(BaseDetector):
    def __init__(self,
                 num_classes=80,
                 backbone='ResNet50_vd',
                 with_fpn=True,
                 with_dcn=False,
                 aspect_ratios=[0.5, 1.0, 2.0],
                 anchor_sizes=[[32], [64], [128], [256], [512]],
                 keep_top_k=100,
                 nms_threshold=0.5,
                 score_threshold=0.05,
                 fpn_num_channels=256,
                 rpn_batch_size_per_im=256,
                 rpn_fg_fraction=0.5,
                 test_pre_nms_top_n=None,
                 test_post_nms_top_n=1000,
                 **params):
        self.init_params = locals()
        if backbone not in {
                'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101',
                'ResNet101_vd'
        }:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet101', "
                "'ResNet101_vd')".format(backbone))
        self.backbone_name = backbone + '_fpn' if with_fpn else backbone
        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
        if params.get('with_net', True):
            if backbone == 'ResNet50':
                if with_fpn:
                    backbone = self._get_backbone(
                        'ResNet',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[0, 1, 2, 3],
                        num_stages=4,
                        dcn_v2_stages=dcn_v2_stages)
                else:
                    if with_dcn:
                        logging.warning(
                            "Backbone {} does not support DCN; 'with_dcn' is "
                            "forcibly set to False.".format(backbone))
                    backbone = self._get_backbone(
                        'ResNet',
                        norm_type='bn',
                        freeze_at=0,
                        return_idx=[2],
                        num_stages=3)
            elif 'ResNet50_vd' in backbone:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d',
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    lr_mult_list=[0.05, 0.05, 0.1, 0.15]
                    if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
                    dcn_v2_stages=dcn_v2_stages)
            else:
                if not with_fpn:
                    logging.warning(
                        "Backbone {} should be used with FPN enabled; "
                        "'with_fpn' is forcibly set to True.".format(backbone))
                    with_fpn = True
                backbone = self._get_backbone(
                    'ResNet',
                    variant='d' if '_vd' in backbone else 'b',
                    depth=101,
                    norm_type='bn',
                    freeze_at=0,
                    return_idx=[0, 1, 2, 3],
                    num_stages=4,
                    dcn_v2_stages=dcn_v2_stages)
            rpn_in_channel = backbone.out_shape[0].channels
            if with_fpn:
                neck = ppdet.modeling.FPN(
                    in_channels=[i.channels for i in backbone.out_shape],
                    out_channel=fpn_num_channels,
                    spatial_scales=[
                        1.0 / i.stride for i in backbone.out_shape
                    ])
                rpn_in_channel = neck.out_shape[0].channels
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [4, 8, 16, 32, 64]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 2000,
                    'post_nms_top_n': 1000,
                    'topk_after_collect': True
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 1000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                bb_head = ppdet.modeling.TwoFCHead(
                    in_channel=neck.out_shape[0].channels, out_channel=1024)
                bb_roi_extractor_cfg = {
                    'resolution': 7,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = False
                m_head = ppdet.modeling.MaskFeat(
                    in_channel=neck.out_shape[0].channels,
                    out_channel=256,
                    num_convs=4)
                m_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale': [1. / i.stride for i in neck.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                mask_assigner = MaskAssigner(
                    num_classes=num_classes, mask_resolution=28)
                share_bbox_feat = False
            else:
                neck = None
                anchor_generator_cfg = {
                    'aspect_ratios': aspect_ratios,
                    'anchor_sizes': anchor_sizes,
                    'strides': [16]
                }
                train_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 12000,
                    'post_nms_top_n': 2000,
                    'topk_after_collect': False
                }
                test_proposal_cfg = {
                    'min_size': 0.0,
                    'nms_thresh': .7,
                    'pre_nms_top_n': 6000
                    if test_pre_nms_top_n is None else test_pre_nms_top_n,
                    'post_nms_top_n': test_post_nms_top_n
                }
                bb_head = ppdet.modeling.Res5Head()
                bb_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                with_pool = True
                m_head = ppdet.modeling.MaskFeat(
                    in_channel=bb_head.out_shape[0].channels,
                    out_channel=256,
                    num_convs=0)
                m_roi_extractor_cfg = {
                    'resolution': 14,
                    'spatial_scale':
                    [1. / i.stride for i in backbone.out_shape],
                    'sampling_ratio': 0,
                    'aligned': True
                }
                mask_assigner = MaskAssigner(
                    num_classes=num_classes, mask_resolution=14)
                share_bbox_feat = True
            rpn_target_assign_cfg = {
                'batch_size_per_im': rpn_batch_size_per_im,
                'fg_fraction': rpn_fg_fraction,
                'negative_overlap': .3,
                'positive_overlap': .7,
                'use_random': True
            }
            rpn_head = ppdet.modeling.RPNHead(
                anchor_generator=anchor_generator_cfg,
                rpn_target_assign=rpn_target_assign_cfg,
                train_proposal=train_proposal_cfg,
                test_proposal=test_proposal_cfg,
                in_channel=rpn_in_channel)
            bbox_assigner = BBoxAssigner(num_classes=num_classes)
            bbox_head = ppdet.modeling.BBoxHead(
                head=bb_head,
                in_channel=bb_head.out_shape[0].channels,
                roi_extractor=bb_roi_extractor_cfg,
                with_pool=with_pool,
                bbox_assigner=bbox_assigner,
                num_classes=num_classes)
            mask_head = ppdet.modeling.MaskHead(
                head=m_head,
                roi_extractor=m_roi_extractor_cfg,
                mask_assigner=mask_assigner,
                share_bbox_feat=share_bbox_feat,
                num_classes=num_classes)
            bbox_post_process = ppdet.modeling.BBoxPostProcess(
                num_classes=num_classes,
                decode=ppdet.modeling.RCNNBox(num_classes=num_classes),
                nms=ppdet.modeling.MultiClassNMS(
                    score_threshold=score_threshold,
                    keep_top_k=keep_top_k,
                    nms_threshold=nms_threshold))
            mask_post_process = ppdet.modeling.MaskPostProcess(
                binary_thresh=.5)
            params.update({
                'backbone': backbone,
                'neck': neck,
                'rpn_head': rpn_head,
                'bbox_head': bbox_head,
                'mask_head': mask_head,
                'bbox_post_process': bbox_post_process,
                'mask_post_process': mask_post_process
            })

        self.with_fpn = with_fpn
        super(MaskRCNN, self).__init__(
            model_name='MaskRCNN', num_classes=num_classes, **params)
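
    # Construction sketch (illustrative): every backbone except plain
    # ResNet50 forces with_fpn=True, so the call below yields the FPN
    # variant either way:
    #
    #   model = MaskRCNN(num_classes=20, backbone='ResNet50_vd', with_fpn=True)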

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
  2005. """
  2006. Train the model.
  2007. Args:
  2008. num_epochs(int): The number of epochs.
  2009. train_dataset(paddlex.dataset): Training dataset.
  2010. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  2011. eval_dataset(paddlex.dataset, optional):
  2012. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  2013. optimizer(paddle.optimizer.Optimizer or None, optional):
  2014. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  2015. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  2016. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  2017. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  2018. pretrain_weights(str or None, optional):
  2019. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  2020. learning_rate(float, optional): Learning rate for training. Defaults to .001.
  2021. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  2022. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  2023. lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
  2024. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
  2025. metric({'VOC', 'COCO', None}, optional):
  2026. Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
  2027. use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
  2028. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  2029. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  2030. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  2031. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  2032. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  2033. `pretrain_weights` can be set simultaneously. Defaults to None.
  2034. """
        if train_dataset.pos_num < len(train_dataset.file_list):
            # The dataset contains negative (background-only) samples, so
            # multi-process data loading is disabled.
            train_dataset.num_workers = 0
        super(MaskRCNN, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            optimizer, save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
            lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
            early_stop_patience, use_vdl, resume_checkpoint)
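
    # Hedged training sketch (the dataset constructor and paths below are
    # illustrative assumptions; MaskRCNN needs instance masks, i.e. a
    # COCO-format dataset):
    #
    #   train_ds = pdx.datasets.CocoDetection(
    #       data_dir='dataset/images', ann_file='dataset/train.json',
    #       transforms=train_transforms)
    #   model.train(num_epochs=12, train_dataset=train_ds, train_batch_size=2,
    #               learning_rate=.0025, save_dir='output/mask_rcnn')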

    def _compose_batch_transform(self, transforms, mode='train'):
        # The same batch padding applies in both training and inference;
        # padding to a stride of 32 is only needed when an FPN is attached.
        default_batch_transforms = [
            _BatchPadding(pad_to_stride=32 if self.with_fpn else -1)
        ]
        custom_batch_transforms = []
        for op in transforms.transforms:
            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                if mode != 'train':
                    raise Exception(
                        "{} cannot be present in the {} transforms. ".format(
                            op.__class__.__name__, mode) +
                        "Please check the {} transforms.".format(mode))
                custom_batch_transforms.insert(0, copy.deepcopy(op))
        batch_transforms = BatchCompose(
            custom_batch_transforms + default_batch_transforms,
            collate_batch=False)
        return batch_transforms

    def _fix_transforms_shape(self, image_shape):
        if getattr(self, 'test_transforms', None):
            has_resize_op = False
            resize_op_idx = -1
            normalize_op_idx = len(self.test_transforms.transforms)
            for idx, op in enumerate(self.test_transforms.transforms):
                name = op.__class__.__name__
                if name == 'ResizeByShort':
                    has_resize_op = True
                    resize_op_idx = idx
                if name == 'Normalize':
                    normalize_op_idx = idx
            if not has_resize_op:
                self.test_transforms.transforms.insert(
                    normalize_op_idx,
                    Resize(
                        target_size=image_shape,
                        keep_ratio=True,
                        interp='CUBIC'))
            else:
                self.test_transforms.transforms[resize_op_idx] = Resize(
                    target_size=image_shape, keep_ratio=True, interp='CUBIC')
            self.test_transforms.transforms.append(
                Padding(im_padding_value=[0., 0., 0.]))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
            if self.with_fpn:
                self.test_transforms.transforms.append(
                    Padding(im_padding_value=[0., 0., 0.]))
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)