classifier.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import math
import os.path as osp
from collections import OrderedDict

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.static import InputSpec

from paddlex.utils import logging, TrainingStats
from paddlex.cv.models.base import BaseModel
from paddlex.cv.transforms import arrange_transforms
from paddlex.cv.transforms.operators import Resize
from paddlex.ppcls import arch
from paddlex.ppcls.loss import CELoss

__all__ = [
    "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152",
    "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet50_vd_ssld",
    "ResNet101_vd", "ResNet101_vd_ssld", "ResNet152_vd", "ResNet200_vd",
    "AlexNet", "DarkNet53", "MobileNetV1", "MobileNetV2", "MobileNetV3_small",
    "MobileNetV3_small_ssld", "MobileNetV3_large", "MobileNetV3_large_ssld",
    "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201", "DenseNet264",
    "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C", "HRNet_W44_C",
    "HRNet_W48_C", "HRNet_W64_C", "Xception41", "Xception65", "Xception71",
    "ShuffleNetV2", "ShuffleNetV2_swish", "PPLCNet", "PPLCNet_ssld"
]


class BaseClassifier(BaseModel):
    """Parent class of all classification models.

    Args:
        model_name (str, optional): Name of classification model. Defaults to 'ResNet50'.
        num_classes (int, optional): The number of target classes. Defaults to 1000.
    """

    def __init__(self, model_name='ResNet50', num_classes=1000, **params):
        self.init_params = locals()
        self.init_params.update(params)
        if 'lr_mult_list' in self.init_params:
            del self.init_params['lr_mult_list']
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        super(BaseClassifier, self).__init__('classifier')
        if not hasattr(arch, model_name):
            raise Exception("ERROR: There's no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.labels = None
        self.num_classes = num_classes
        for k, v in params.items():
            setattr(self, k, v)
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)

    def build_net(self, **params):
        with paddle.utils.unique_name.guard():
            net = arch.__dict__[self.model_name](class_num=self.num_classes,
                                                 **params)
        return net

    def build_loss(self, label_smoothing=None):
        if isinstance(label_smoothing, bool):
            label_smoothing = .1 if label_smoothing else None
        self.loss_func = CELoss(epsilon=label_smoothing)
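
    # How `build_loss` resolves `label_smoothing` into the CELoss epsilon,
    # following the branches above (a worked example, not extra behaviour):
    #   build_loss(label_smoothing=True)   -> CELoss(epsilon=0.1)
    #   build_loss(label_smoothing=False)  -> CELoss(epsilon=None)   # no smoothing
    #   build_loss(label_smoothing=0.05)   -> CELoss(epsilon=0.05)
    #   build_loss(label_smoothing=None)   -> CELoss(epsilon=None)   # no smoothing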

    def _fix_transforms_shape(self, image_shape):
        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                self.test_transforms.transforms.append(
                    Resize(target_size=image_shape))

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [1, 3] + image_shape
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec

    def run(self, net, inputs, mode):
        net_out = net(inputs[0])
        softmax_out = net_out if self.status == 'Infer' else F.softmax(net_out)
        if mode == 'test':
            outputs = OrderedDict([('prediction', softmax_out)])
        elif mode == 'eval':
            pred = softmax_out
            gt = inputs[1]
            labels = inputs[1].reshape([-1, 1])
            acc1 = paddle.metric.accuracy(softmax_out, label=labels)
            k = min(5, self.num_classes)
            acck = paddle.metric.accuracy(softmax_out, label=labels, k=k)
            # multi-card evaluation
            if paddle.distributed.get_world_size() > 1:
                acc1 = paddle.distributed.all_reduce(
                    acc1, op=paddle.distributed.ReduceOp.
                    SUM) / paddle.distributed.get_world_size()
                acck = paddle.distributed.all_reduce(
                    acck, op=paddle.distributed.ReduceOp.
                    SUM) / paddle.distributed.get_world_size()
                pred = list()
                gt = list()
                paddle.distributed.all_gather(pred, softmax_out)
                paddle.distributed.all_gather(gt, inputs[1])
                pred = paddle.concat(pred, axis=0)
                gt = paddle.concat(gt, axis=0)
            outputs = OrderedDict([('acc1', acc1), ('acc{}'.format(k), acck),
                                   ('prediction', pred), ('labels', gt)])
        else:
            # mode == 'train'
            labels = inputs[1].reshape([-1, 1])
            loss = self.loss_func(net_out, inputs[1])['CELoss']
            acc1 = paddle.metric.accuracy(softmax_out, label=labels, k=1)
            k = min(5, self.num_classes)
            acck = paddle.metric.accuracy(softmax_out, label=labels, k=k)
            outputs = OrderedDict([('loss', loss), ('acc1', acc1),
                                   ('acc{}'.format(k), acck)])
        return outputs

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          warmup_steps,
                          warmup_start_lr,
                          lr_decay_epochs,
                          lr_decay_gamma,
                          num_steps_each_epoch,
                          reg_coeff=1e-04,
                          scheduler='Piecewise',
                          num_epochs=None):
        if scheduler.lower() == 'piecewise':
            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
                    0] * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= lr_decay_epochs[0] * num_steps_each_epoch. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "Either `warmup_steps` should be less than {} or `lr_decay_epochs[0]` "
                    "should be greater than {}; please modify 'warmup_steps' or "
                    "'lr_decay_epochs' in the train() function.".format(
                        lr_decay_epochs[0] * num_steps_each_epoch,
                        warmup_steps // num_steps_each_epoch),
                    exit=True)
            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
            values = [
                learning_rate * (lr_decay_gamma**i)
                for i in range(len(lr_decay_epochs) + 1)
            ]
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
        elif scheduler.lower() == 'cosine':
            if num_epochs is None:
                logging.error(
                    "`num_epochs` must be set while using the cosine annealing decay "
                    "scheduler, but received {}.".format(num_epochs),
                    exit=False)
            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= num_epochs * num_steps_each_epoch. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "`warmup_steps` must be less than the total number of steps ({}); "
                    "please modify 'num_epochs' or 'warmup_steps' in the train() function.".
                    format(num_epochs * num_steps_each_epoch),
                    exit=True)
            T_max = num_epochs * num_steps_each_epoch - warmup_steps
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=learning_rate,
                T_max=T_max,
                eta_min=0.0,
                last_epoch=-1)
        else:
            logging.error(
                "Invalid learning rate scheduler: {}!".format(scheduler),
                exit=True)

        if warmup_steps > 0:
            scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=scheduler,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)

        optimizer = paddle.optimizer.Momentum(
            scheduler,
            momentum=.9,
            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
            parameters=parameters)
        return optimizer
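
    # A worked example of the 'Piecewise' branch above, assuming illustrative
    # values learning_rate=0.025, lr_decay_gamma=0.1, lr_decay_epochs=(30, 60, 90)
    # and num_steps_each_epoch=100:
    #   boundaries = [3000, 6000, 9000]
    #   values     = [0.025, 0.0025, 0.00025, 0.000025]
    # i.e. the learning rate is multiplied by lr_decay_gamma at each boundary step;
    # when warmup_steps > 0, the scheduler is additionally wrapped in LinearWarmup.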

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.025,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(30, 60, 90),
              lr_decay_gamma=0.1,
              label_smoothing=None,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights(str or None, optional):
                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
                At most one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                Defaults to 'IMAGENET'.
            learning_rate(float, optional): Learning rate for training. Defaults to .025.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(List[int] or Tuple[int], optional):
                Epoch milestones for learning rate decay. Defaults to (30, 60, 90).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            label_smoothing(float, bool or None, optional): Whether to adopt label smoothing.
                If float, the value is used as the epsilon coefficient of label smoothing. If False or None,
                label smoothing is not adopted. Otherwise, label smoothing is adopted with epsilon equal to 0.1.
                Defaults to None.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                `pretrain_weights` can be set simultaneously. Defaults to None.
        """
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
                exit=True)
        self.labels = train_dataset.labels

        # build optimizer if not defined
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            self.optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch)
        else:
            self.optimizer = optimizer

        # build loss
        self.build_loss(label_smoothing)

        # initialize weights
        if pretrain_weights is not None and not osp.exists(pretrain_weights):
            if pretrain_weights not in ['IMAGENET']:
                logging.warning(
                    "Path of pretrain_weights('{}') does not exist!".format(
                        pretrain_weights))
                logging.warning(
                    "pretrain_weights is forcibly set to 'IMAGENET'. "
                    "If you don't want to use pretrained weights, "
                    "set pretrain_weights to None.")
                pretrain_weights = 'IMAGENET'
        elif pretrain_weights is not None and osp.exists(pretrain_weights):
            if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                logging.error(
                    "Invalid pretrained weights. Please specify a '.pdparams' file.",
                    exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        self.net_initialize(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint)

        # start the training loop
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)
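
    # A minimal training sketch. The dataset and transform constructors below are
    # assumptions about the surrounding paddlex API (illustrative only) and may
    # need adapting to the installed version:
    #
    #   import paddlex as pdx
    #   from paddlex import transforms as T
    #
    #   train_transforms = T.Compose([T.RandomCrop(crop_size=224), T.Normalize()])
    #   train_dataset = pdx.datasets.ImageNet(
    #       data_dir='my_dataset',
    #       file_list='my_dataset/train_list.txt',
    #       label_list='my_dataset/labels.txt',
    #       transforms=train_transforms)
    #   model = ResNet50(num_classes=len(train_dataset.labels))
    #   model.train(num_epochs=10, train_dataset=train_dataset,
    #               train_batch_size=64, learning_rate=.025, save_dir='output')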

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=64,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=10,
                          save_dir='output',
                          learning_rate=.000025,
                          warmup_steps=0,
                          warmup_start_lr=0.0,
                          lr_decay_epochs=(30, 60, 90),
                          lr_decay_gamma=0.1,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate(float, optional): Learning rate for training. Defaults to .000025.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(List[int] or Tuple[int], optional):
                Epoch milestones for learning rate decay. Defaults to (30, 60, 90).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            quant_config(dict or None, optional): Quantization configuration. If None, a default rule-of-thumb
                configuration will be used. Defaults to None.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware
                training from. If None, no training checkpoint will be resumed. Defaults to None.
        """
        self._prepare_qat(quant_config)
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
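
    # A rough quantization-aware training sketch: fine-tune an already trained
    # model with a small learning rate (note that `pretrain_weights` is forced
    # to None in the call above, since the weights come from the current model):
    #
    #   model.train(num_epochs=10, train_dataset=train_dataset, save_dir='output')
    #   model.quant_aware_train(num_epochs=5, train_dataset=train_dataset,
    #                           learning_rate=.000025, save_dir='output_quant')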

    def evaluate(self, eval_dataset, batch_size=1, return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset(paddlex.dataset): Evaluation dataset.
            batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
            return_details(bool, optional): Whether to return evaluation details. Defaults to False.

        Returns:
            collections.OrderedDict with key-value pairs: {"acc1": `top 1 accuracy`, "acc5": `top 5 accuracy`}.
            If `return_details` is True, a dict with keys 'true_labels' and 'pred_scores' is also returned.
        """
        # add the arrange operation to the transforms
        arrange_transforms(
            model_type=self.model_type,
            transforms=eval_dataset.transforms,
            mode='eval')
        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize the parallel environment if it has not been initialized yet.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
                paddle.distributed.init_parallel_env()

        self.eval_data_loader = self.build_data_loader(
            eval_dataset, batch_size=batch_size, mode='eval')
        eval_metrics = TrainingStats()
        if return_details:
            true_labels = list()
            pred_scores = list()

        logging.info(
            "Start to evaluate (total_samples={}, total_steps={})...".format(
                eval_dataset.num_samples,
                math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
        with paddle.no_grad():
            for step, data in enumerate(self.eval_data_loader()):
                outputs = self.run(self.net, data, mode='eval')
                if return_details:
                    true_labels.extend(outputs['labels'].tolist())
                    pred_scores.extend(outputs['prediction'].tolist())
                outputs.pop('prediction')
                outputs.pop('labels')
                eval_metrics.update(outputs)
        if return_details:
            eval_details = {
                'true_labels': true_labels,
                'pred_scores': pred_scores
            }
            return eval_metrics.get(), eval_details
        else:
            return eval_metrics.get()
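
    # An evaluation sketch (metric values below are purely illustrative);
    # `eval_dataset` is assumed to be built like the training dataset, with
    # evaluation transforms:
    #
    #   metrics = model.evaluate(eval_dataset, batch_size=8)
    #   # OrderedDict([('acc1', 0.76), ('acc5', 0.93)])
    #   metrics, details = model.evaluate(eval_dataset, return_details=True)
    #   # details == {'true_labels': [...], 'pred_scores': [...]}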

    def predict(self, img_file, transforms=None, topk=1):
        """
        Do inference.

        Args:
            img_file(List[np.ndarray or str], str or np.ndarray):
                Image path or decoded image data in BGR format, or a list of these,
                in which case all images are predicted as a mini-batch.
            transforms(paddlex.transforms.Compose or None, optional):
                Transforms for inputs. If None, the transforms for the evaluation process will be used.
                Defaults to None.
            topk(int, optional): Keep the topk results in the prediction. Defaults to 1.

        Returns:
            If img_file is a string or np.ndarray, the result is a list of `topk` dicts with key-value pairs:
            {"category_id": `category_id`, "category": `category`, "score": `score`}.
            If img_file is a list, the result is a list of such lists, one per image, where each dict contains:
                category_id(int): the predicted category ID
                category(str): category name
                score(float): confidence
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise Exception("transforms need to be defined, but got None.")
        if transforms is None:
            transforms = self.test_transforms
        true_topk = min(self.num_classes, topk)
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file
        im = self._preprocess(images, transforms)
        self.net.eval()
        with paddle.no_grad():
            outputs = self.run(self.net, im, mode='test')
        prediction = outputs['prediction'].numpy()
        prediction = self._postprocess(prediction, true_topk)
        if isinstance(img_file, (str, np.ndarray)):
            prediction = prediction[0]
        return prediction
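
    # A prediction sketch. Inputs are image paths or decoded BGR arrays
    # (e.g. from cv2.imread); the category names below are illustrative:
    #
    #   result = model.predict('test.jpg', topk=5)
    #   # [{'category_id': 3, 'category': 'daisy', 'score': 0.98}, ...]
    #   results = model.predict(['a.jpg', 'b.jpg'])   # one list of dicts per image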

    def _preprocess(self, images, transforms, to_tensor=True):
        arrange_transforms(
            model_type=self.model_type, transforms=transforms, mode='test')
        batch_im = list()
        for im in images:
            sample = {'image': im}
            batch_im.append(transforms(sample))
        if to_tensor:
            batch_im = paddle.to_tensor(batch_im)
        else:
            batch_im = np.asarray(batch_im)
        return batch_im,

    def _postprocess(self, results, true_topk):
        preds = list()
        for i, pred in enumerate(results):
            pred_label = np.argsort(pred)[::-1][:true_topk]
            preds.append([{
                'category_id': l,
                'category': self.labels[l],
                'score': results[i][l]
            } for l in pred_label])
        return preds


class ResNet18(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet18, self).__init__(
            model_name='ResNet18', num_classes=num_classes, **params)


class ResNet34(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet34, self).__init__(
            model_name='ResNet34', num_classes=num_classes, **params)


class ResNet50(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet50, self).__init__(
            model_name='ResNet50', num_classes=num_classes, **params)


class ResNet101(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet101, self).__init__(
            model_name='ResNet101', num_classes=num_classes, **params)


class ResNet152(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet152, self).__init__(
            model_name='ResNet152', num_classes=num_classes, **params)


class ResNet18_vd(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet18_vd, self).__init__(
            model_name='ResNet18_vd', num_classes=num_classes, **params)


class ResNet34_vd(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet34_vd, self).__init__(
            model_name='ResNet34_vd', num_classes=num_classes, **params)


class ResNet50_vd(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet50_vd, self).__init__(
            model_name='ResNet50_vd', num_classes=num_classes, **params)


class ResNet50_vd_ssld(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet50_vd_ssld, self).__init__(
            model_name='ResNet50_vd',
            num_classes=num_classes,
            lr_mult_list=[.1, .1, .2, .2, .3],
            **params)
        self.model_name = 'ResNet50_vd_ssld'


class ResNet101_vd(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet101_vd, self).__init__(
            model_name='ResNet101_vd', num_classes=num_classes, **params)


class ResNet101_vd_ssld(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet101_vd_ssld, self).__init__(
            model_name='ResNet101_vd',
            num_classes=num_classes,
            lr_mult_list=[.1, .1, .2, .2, .3],
            **params)
        self.model_name = 'ResNet101_vd_ssld'


class ResNet152_vd(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet152_vd, self).__init__(
            model_name='ResNet152_vd', num_classes=num_classes, **params)


class ResNet200_vd(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ResNet200_vd, self).__init__(
            model_name='ResNet200_vd', num_classes=num_classes, **params)


class AlexNet(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(AlexNet, self).__init__(
            model_name='AlexNet', num_classes=num_classes, **params)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [None, 3] + image_shape
        else:
            image_shape = [None, 3, 224, 224]
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '[None, 3, 224, 224]. Please check that the image shape after '
                'transforms is [3, 224, 224]; if not, fixed_input_shape should '
                'be specified manually.'.format(self.__class__.__name__))
        self._fix_transforms_shape(image_shape[-2:])
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec


class DarkNet53(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(DarkNet53, self).__init__(
            model_name='DarkNet53', num_classes=num_classes, **params)


class MobileNetV1(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1.0, **params):
        supported_scale = [.25, .5, .75, 1.0]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by MobileNetV1, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        if scale == 1:
            model_name = 'MobileNetV1'
        else:
            model_name = 'MobileNetV1_x' + str(scale).replace('.', '_')
        self.scale = scale
        super(MobileNetV1, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)


class MobileNetV2(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1.0, **params):
        supported_scale = [.25, .5, .75, 1.0, 1.5, 2.0]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by MobileNetV2, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        if scale == 1:
            model_name = 'MobileNetV2'
        else:
            model_name = 'MobileNetV2_x' + str(scale).replace('.', '_')
        super(MobileNetV2, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)


class MobileNetV3_small(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1.0, **params):
        supported_scale = [.35, .5, .75, 1.0, 1.25]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by MobileNetV3_small, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.', '_')
        super(MobileNetV3_small, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)


class MobileNetV3_small_ssld(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1.0, **params):
        supported_scale = [.35, 1.0]
        if scale not in supported_scale:
            logging.warning(
                "scale={} is not supported by MobileNetV3_small_ssld, "
                "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.', '_')
        super(MobileNetV3_small_ssld, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)
        self.model_name = model_name + '_ssld'


class MobileNetV3_large(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1.0, **params):
        supported_scale = [.35, .5, .75, 1.0, 1.25]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by MobileNetV3_large, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        model_name = 'MobileNetV3_large_x' + str(float(scale)).replace('.', '_')
        super(MobileNetV3_large, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)


class MobileNetV3_large_ssld(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(MobileNetV3_large_ssld, self).__init__(
            model_name='MobileNetV3_large_x1_0',
            num_classes=num_classes,
            **params)
        self.model_name = 'MobileNetV3_large_x1_0_ssld'


class DenseNet121(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(DenseNet121, self).__init__(
            model_name='DenseNet121', num_classes=num_classes, **params)


class DenseNet161(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(DenseNet161, self).__init__(
            model_name='DenseNet161', num_classes=num_classes, **params)


class DenseNet169(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(DenseNet169, self).__init__(
            model_name='DenseNet169', num_classes=num_classes, **params)


class DenseNet201(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(DenseNet201, self).__init__(
            model_name='DenseNet201', num_classes=num_classes, **params)


class DenseNet264(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(DenseNet264, self).__init__(
            model_name='DenseNet264', num_classes=num_classes, **params)


class HRNet_W18_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W18_C, self).__init__(
            model_name='HRNet_W18_C', num_classes=num_classes, **params)


class HRNet_W30_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W30_C, self).__init__(
            model_name='HRNet_W30_C', num_classes=num_classes, **params)


class HRNet_W32_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W32_C, self).__init__(
            model_name='HRNet_W32_C', num_classes=num_classes, **params)


class HRNet_W40_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W40_C, self).__init__(
            model_name='HRNet_W40_C', num_classes=num_classes, **params)


class HRNet_W44_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W44_C, self).__init__(
            model_name='HRNet_W44_C', num_classes=num_classes, **params)


class HRNet_W48_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W48_C, self).__init__(
            model_name='HRNet_W48_C', num_classes=num_classes, **params)


class HRNet_W64_C(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(HRNet_W64_C, self).__init__(
            model_name='HRNet_W64_C', num_classes=num_classes, **params)


class Xception41(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(Xception41, self).__init__(
            model_name='Xception41', num_classes=num_classes, **params)


class Xception65(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(Xception65, self).__init__(
            model_name='Xception65', num_classes=num_classes, **params)


class Xception71(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(Xception71, self).__init__(
            model_name='Xception71', num_classes=num_classes, **params)


class ShuffleNetV2(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1.0, **params):
        supported_scale = [.25, .33, .5, 1.0, 1.5, 2.0]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by ShuffleNetV2, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        model_name = 'ShuffleNetV2_x' + str(float(scale)).replace('.', '_')
        super(ShuffleNetV2, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [None, 3] + image_shape
        else:
            image_shape = [None, 3, 224, 224]
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '[None, 3, 224, 224]. Please check that the image shape after '
                'transforms is [3, 224, 224]; if not, fixed_input_shape should '
                'be specified manually.'.format(self.__class__.__name__))
        self._fix_transforms_shape(image_shape[-2:])
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec


class ShuffleNetV2_swish(BaseClassifier):
    def __init__(self, num_classes=1000, **params):
        super(ShuffleNetV2_swish, self).__init__(
            model_name='ShuffleNetV2_x1_5', num_classes=num_classes, **params)

    def _get_test_inputs(self, image_shape):
        if image_shape is not None:
            if len(image_shape) == 2:
                image_shape = [None, 3] + image_shape
        else:
            image_shape = [None, 3, 224, 224]
            logging.warning(
                '[Important!!!] When exporting an inference model for {}, '
                'if fixed_input_shape is not set, it will be forcibly set to '
                '[None, 3, 224, 224]. Please check that the image shape after '
                'transforms is [3, 224, 224]; if not, fixed_input_shape should '
                'be specified manually.'.format(self.__class__.__name__))
        self._fix_transforms_shape(image_shape[-2:])
        self.fixed_input_shape = image_shape
        input_spec = [
            InputSpec(
                shape=image_shape, name='image', dtype='float32')
        ]
        return input_spec


class PPLCNet(BaseClassifier):
    def __init__(self, num_classes=1000, scale=1., **params):
        supported_scale = [.25, .35, .5, .75, 1., 1.5, 2., 2.5]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by PPLCNet, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        model_name = 'PPLCNet_x' + str(float(scale)).replace('.', '_')
        super(PPLCNet, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.1,
              warmup_steps=0,
              warmup_start_lr=0.0,
              lr_decay_epochs=(30, 60, 90),
              lr_decay_gamma=0.1,
              label_smoothing=None,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              resume_checkpoint=None):
        """
        Train the model.

        Args:
            num_epochs(int): The number of epochs.
            train_dataset(paddlex.dataset): Training dataset.
            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
            eval_dataset(paddlex.dataset, optional):
                Evaluation dataset. If None, the model will not be evaluated during the training process.
                Defaults to None.
            optimizer(paddle.optimizer.Optimizer or None, optional):
                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights(str or None, optional):
                None or name/path of pretrained weights. If None, no pretrained weights will be loaded.
                At most one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                Defaults to 'IMAGENET'.
            learning_rate(float, optional): Learning rate for training. Defaults to .1.
            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
            lr_decay_epochs(List[int] or Tuple[int], optional):
                Epoch milestones for learning rate decay. Defaults to (30, 60, 90).
            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
            label_smoothing(float, bool or None, optional): Whether to adopt label smoothing.
                If float, the value is used as the epsilon coefficient of label smoothing. If False or None,
                label smoothing is not adopted. Otherwise, label smoothing is adopted with epsilon equal to 0.1.
                Defaults to None.
            early_stop(bool, optional): Whether to adopt the early stopping strategy. Defaults to False.
            early_stop_patience(int, optional): Early stopping patience. Defaults to 5.
            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                `pretrain_weights` can be set simultaneously. Defaults to None.
        """
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            optimizer = self.default_optimizer(
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch,
                reg_coeff=3e-5,
                scheduler='Cosine',
                num_epochs=num_epochs)
        super(PPLCNet, self).train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=pretrain_weights,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            label_smoothing=label_smoothing,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)
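
    # Note on the override above: unlike the base-class default ('Piecewise' decay
    # with reg_coeff=1e-4), PPLCNet builds its default optimizer with cosine
    # annealing and a smaller L2 coefficient (3e-5); passing an `optimizer`
    # explicitly bypasses this behaviour.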


class PPLCNet_ssld(PPLCNet):
    def __init__(self, num_classes=1000, scale=1., **params):
        supported_scale = [.5, 1., 2.5]
        if scale not in supported_scale:
            logging.warning("scale={} is not supported by PPLCNet_ssld, "
                            "scale is forcibly set to 1.0".format(scale))
            scale = 1.0
        model_name = 'PPLCNet_x' + str(float(scale)).replace('.', '_')
        super(PPLCNet, self).__init__(
            model_name=model_name, num_classes=num_classes, **params)
        self.model_name = model_name + '_ssld'