classifier.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import math
  16. import os.path as osp
  17. from collections import OrderedDict
  18. import numpy as np
  19. import paddle
  20. from paddle import to_tensor
  21. import paddle.nn.functional as F
  22. from paddle.static import InputSpec
  23. from paddleslim import QAT
  24. from paddlex.utils import logging, TrainingStats, DisablePrint
  25. from paddlex.cv.models.base import BaseModel
  26. from paddlex.cv.transforms import arrange_transforms
  27. with DisablePrint():
  28. from PaddleClas.ppcls.modeling import architectures
  29. from PaddleClas.ppcls.modeling.loss import CELoss
  30. __all__ = [
  31. "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152",
  32. "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet50_vd_ssld",
  33. "ResNet101_vd", "ResNet101_vd_ssld", "ResNet152_vd", "ResNet200_vd",
  34. "AlexNet", "DarkNet53", "MobileNetV1", "MobileNetV2", "MobileNetV3_small",
  35. "MobileNetV3_small_ssld", "MobileNetV3_large", "MobileNetV3_large_ssld",
  36. "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201", "DenseNet264",
  37. "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C", "HRNet_W44_C",
  38. "HRNet_W48_C", "HRNet_W64_C", "Xception41", "Xception65", "Xception71",
  39. "ShuffleNetV2", "ShuffleNetV2_swish"
  40. ]
  41. class BaseClassifier(BaseModel):
  42. """Parent class of all classification models.
  43. Args:
  44. model_name (str, optional): Name of classification model. Defaults to 'ResNet50'.
  45. num_classes (int, optional): The number of target classes. Defaults to 1000.
  46. """
  47. def __init__(self, model_name='ResNet50', num_classes=1000, **params):
  48. self.init_params = locals()
  49. self.init_params.update(params)
  50. del self.init_params['params']
  51. super(BaseClassifier, self).__init__('classifier')
  52. if not hasattr(architectures, model_name):
  53. raise Exception("ERROR: There's no model named {}.".format(
  54. model_name))
  55. self.model_name = model_name
  56. self.labels = None
  57. self.num_classes = num_classes
  58. for k, v in params.items():
  59. setattr(self, k, v)
  60. self.net = self.build_net(**params)
  61. def build_net(self, **params):
  62. with paddle.utils.unique_name.guard():
  63. net = architectures.__dict__[self.model_name](
  64. class_dim=self.num_classes, **params)
  65. return net
  66. def get_test_inputs(self, image_shape):
  67. input_spec = [
  68. InputSpec(
  69. shape=[None, 3] + image_shape, name='image', dtype='float32')
  70. ]
  71. return input_spec
  72. def run(self, net, inputs, mode):
  73. net_out = net(inputs[0])
  74. softmax_out = F.softmax(net_out)
  75. if mode == 'test':
  76. outputs = OrderedDict([('prediction', softmax_out)])
  77. elif mode == 'eval':
  78. labels = to_tensor(inputs[1].numpy().astype('int64').reshape(-1,
  79. 1))
  80. acc1 = paddle.metric.accuracy(softmax_out, label=labels)
  81. k = min(5, self.num_classes)
  82. acck = paddle.metric.accuracy(softmax_out, label=labels, k=k)
  83. # multi cards eval
  84. if paddle.distributed.get_world_size() > 1:
  85. acc1 = paddle.distributed.all_reduce(
  86. acc1, op=paddle.distributed.ReduceOp.
  87. SUM) / paddle.distributed.get_world_size()
  88. acck = paddle.distributed.all_reduce(
  89. acck, op=paddle.distributed.ReduceOp.
  90. SUM) / paddle.distributed.get_world_size()
  91. outputs = OrderedDict([('acc1', acc1), ('acc{}'.format(k), acck),
  92. ('prediction', softmax_out)])
  93. else:
  94. # mode == 'train'
  95. labels = to_tensor(inputs[1].numpy().astype('int64').reshape(-1,
  96. 1))
  97. loss = CELoss(class_dim=self.num_classes)
  98. loss = loss(net_out, inputs[1])
  99. acc1 = paddle.metric.accuracy(softmax_out, label=labels, k=1)
  100. k = min(5, self.num_classes)
  101. acck = paddle.metric.accuracy(softmax_out, label=labels, k=k)
  102. outputs = OrderedDict([('loss', loss), ('acc1', acc1),
  103. ('acc{}'.format(k), acck)])
  104. return outputs
  105. def default_optimizer(self, parameters, learning_rate, warmup_steps,
  106. warmup_start_lr, lr_decay_epochs, lr_decay_gamma,
  107. num_steps_each_epoch):
  108. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  109. values = [
  110. learning_rate * (lr_decay_gamma**i)
  111. for i in range(len(lr_decay_epochs) + 1)
  112. ]
  113. scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
  114. if warmup_steps > 0:
  115. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  116. logging.error(
  117. "In function train(), parameters should satisfy: "
  118. "warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  119. exit=False)
  120. logging.error(
  121. "See this doc for more information: "
  122. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  123. exit=False)
  124. logging.error(
  125. "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, "
  126. "please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
  127. format(lr_decay_epochs[0] * num_steps_each_epoch,
  128. warmup_steps // num_steps_each_epoch))
  129. scheduler = paddle.optimizer.lr.LinearWarmup(
  130. learning_rate=scheduler,
  131. warmup_steps=warmup_steps,
  132. start_lr=warmup_start_lr,
  133. end_lr=learning_rate)
  134. optimizer = paddle.optimizer.Momentum(
  135. scheduler,
  136. momentum=.9,
  137. weight_decay=paddle.regularizer.L2Decay(coeff=1e-04),
  138. parameters=parameters)
  139. return optimizer
  140. def train(self,
  141. num_epochs,
  142. train_dataset,
  143. train_batch_size=64,
  144. eval_dataset=None,
  145. optimizer=None,
  146. save_interval_epochs=1,
  147. log_interval_steps=10,
  148. save_dir='output',
  149. pretrain_weights='IMAGENET',
  150. learning_rate=.025,
  151. warmup_steps=0,
  152. warmup_start_lr=0.0,
  153. lr_decay_epochs=(30, 60, 90),
  154. lr_decay_gamma=0.1,
  155. early_stop=False,
  156. early_stop_patience=5,
  157. use_vdl=True):
  158. """
  159. Train the model.
  160. Args:
  161. num_epochs(int): The number of epochs.
  162. train_dataset(paddlex.dataset): Training dataset.
  163. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  164. eval_dataset(paddlex.dataset, optional):
  165. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  166. optimizer(paddle.optimizer.Optimizer or None, optional):
  167. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  168. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  169. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  170. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  171. pretrain_weights(str or None, optional):
  172. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  173. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  174. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  175. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  176. lr_decay_epochs(List[int] or Tuple[int], optional):
  177. Epoch milestones for learning rate decay. Defaults to (20, 60, 90).
  178. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay, default .1.
  179. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  180. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  181. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  182. """
  183. self.labels = train_dataset.labels
  184. # build optimizer if not defined
  185. if optimizer is None:
  186. num_steps_each_epoch = len(train_dataset) // train_batch_size
  187. self.optimizer = self.default_optimizer(
  188. parameters=self.net.parameters(),
  189. learning_rate=learning_rate,
  190. warmup_steps=warmup_steps,
  191. warmup_start_lr=warmup_start_lr,
  192. lr_decay_epochs=lr_decay_epochs,
  193. lr_decay_gamma=lr_decay_gamma,
  194. num_steps_each_epoch=num_steps_each_epoch)
  195. else:
  196. self.optimizer = optimizer
  197. # initiate weights
  198. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  199. if pretrain_weights not in ['IMAGENET']:
  200. logging.warning(
  201. "Path of pretrain_weights('{}') does not exist!".format(
  202. pretrain_weights))
  203. logging.warning(
  204. "Pretrain_weights is forcibly set to 'IMAGENET'. "
  205. "If don't want to use pretrain weights, "
  206. "set pretrain_weights to be None.")
  207. pretrain_weights = 'IMAGENET'
  208. pretrained_dir = osp.join(save_dir, 'pretrain')
  209. self.net_initialize(
  210. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  211. # start train loop
  212. self.train_loop(
  213. num_epochs=num_epochs,
  214. train_dataset=train_dataset,
  215. train_batch_size=train_batch_size,
  216. eval_dataset=eval_dataset,
  217. save_interval_epochs=save_interval_epochs,
  218. log_interval_steps=log_interval_steps,
  219. save_dir=save_dir,
  220. early_stop=early_stop,
  221. early_stop_patience=early_stop_patience,
  222. use_vdl=use_vdl)
  223. def quant_aware_train(self,
  224. num_epochs,
  225. train_dataset,
  226. train_batch_size=64,
  227. eval_dataset=None,
  228. optimizer=None,
  229. save_interval_epochs=1,
  230. log_interval_steps=10,
  231. save_dir='output',
  232. pretrain_weights='IMAGENET',
  233. learning_rate=.025,
  234. warmup_steps=0,
  235. warmup_start_lr=0.0,
  236. lr_decay_epochs=(30, 60, 90),
  237. lr_decay_gamma=0.1,
  238. early_stop=False,
  239. early_stop_patience=5,
  240. use_vdl=True,
  241. infer_image_shape=[-1, -1],
  242. quant_config=None):
  243. """
  244. Quantization-aware training.
  245. Args:
  246. num_epochs(int): The number of epochs.
  247. train_dataset(paddlex.dataset): Training dataset.
  248. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
  249. eval_dataset(paddlex.dataset, optional):
  250. Evaluation dataset. If None, the model will not be evaluated during training process. Defaults to None.
  251. optimizer(paddle.optimizer.Optimizer or None, optional):
  252. Optimizer used for training. If None, a default optimizer is used. Defaults to None.
  253. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  254. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  255. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  256. pretrain_weights(str or None, optional):
  257. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  258. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  259. warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
  260. warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0..
  261. lr_decay_epochs(List[int] or Tuple[int], optional):
  262. Epoch milestones for learning rate decay. Defaults to (20, 60, 90).
  263. lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay, default .1.
  264. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  265. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  266. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  267. infer_image_shape(List[int], optional): The shape of input images during inference process, in [w, h] format.
  268. If the shape of images is variable, set `infer_image_shape` to [-1, -1]. Defaults to [-1, -1].
  269. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  270. configuration will be used. Defaults to None.
  271. """
  272. self._prepare_qat(quant_config, infer_image_shape)
  273. self.train(
  274. num_epochs=num_epochs,
  275. train_dataset=train_dataset,
  276. train_batch_size=train_batch_size,
  277. eval_dataset=eval_dataset,
  278. optimizer=optimizer,
  279. save_interval_epochs=save_interval_epochs,
  280. log_interval_steps=log_interval_steps,
  281. save_dir=save_dir,
  282. pretrain_weights=pretrain_weights,
  283. learning_rate=learning_rate,
  284. warmup_steps=warmup_steps,
  285. warmup_start_lr=warmup_start_lr,
  286. lr_decay_epochs=lr_decay_epochs,
  287. lr_decay_gamma=lr_decay_gamma,
  288. early_stop=early_stop,
  289. early_stop_patience=early_stop_patience,
  290. use_vdl=use_vdl)
  291. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  292. """
  293. Evaluate the model.
  294. Args:
  295. eval_dataset(paddlex.dataset): Evaluation dataset.
  296. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  297. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  298. Returns:
  299. collections.OrderedDict with key-value pairs: {"acc1": `top 1 accuracy`, "acc5": `top 5 accuracy`}.
  300. """
  301. # 给transform添加arrange操作
  302. arrange_transforms(
  303. model_type=self.model_type,
  304. transforms=eval_dataset.transforms,
  305. mode='eval')
  306. self.net.eval()
  307. nranks = paddle.distributed.get_world_size()
  308. local_rank = paddle.distributed.get_rank()
  309. if nranks > 1:
  310. # Initialize parallel environment if not done.
  311. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  312. ):
  313. paddle.distributed.init_parallel_env()
  314. self.eval_data_loader = self.build_data_loader(
  315. eval_dataset, batch_size=batch_size, mode='eval')
  316. eval_metrics = TrainingStats()
  317. eval_details = None
  318. if return_details:
  319. eval_details = list()
  320. logging.info(
  321. "Start to evaluate(total_samples={}, total_steps={})...".format(
  322. eval_dataset.num_samples,
  323. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  324. with paddle.no_grad():
  325. for step, data in enumerate(self.eval_data_loader()):
  326. outputs = self.run(self.net, data, mode='eval')
  327. if return_details:
  328. eval_details.append(outputs['prediction'].numpy())
  329. outputs.pop('prediction')
  330. eval_metrics.update(outputs)
  331. if return_details:
  332. return eval_metrics.get(), eval_details
  333. else:
  334. return eval_metrics.get()
  335. def predict(self, img_file, transforms=None, topk=1):
  336. """
  337. Do inference.
  338. Args:
  339. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  340. Image path or decoded image data in a BGR format, which also could constitute a list,
  341. meaning all images to be predicted as a mini-batch.
  342. transforms(paddlex.transforms.Compose or None, optional):
  343. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  344. topk(int, optional): Keep topk results in prediction. Defaults to 1.
  345. Returns:
  346. If img_file is a string or np.array, the result is a dict with key-value pairs:
  347. {"category_id": `category_id`, "category": `category`, "score": `score`}.
  348. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  349. category_id(int): the predicted category ID
  350. category(str): category name
  351. score(float): confidence
  352. """
  353. if transforms is None and not hasattr(self, 'test_transforms'):
  354. raise Exception("transforms need to be defined, now is None.")
  355. if transforms is None:
  356. transforms = self.test_transforms
  357. true_topk = min(self.num_classes, topk)
  358. if isinstance(img_file, (str, np.ndarray)):
  359. images = [img_file]
  360. else:
  361. images = img_file
  362. im = self._preprocess(images, transforms, self.model_type)
  363. self.net.eval()
  364. with paddle.no_grad():
  365. outputs = self.run(self.net, im, mode='test')
  366. prediction = outputs['prediction'].numpy()
  367. prediction = self._postprocess(prediction, true_topk, self.labels)
  368. if isinstance(img_file, (str, np.ndarray)):
  369. prediction = prediction[0]
  370. return prediction
  371. def _preprocess(self, images, transforms, model_type):
  372. arrange_transforms(
  373. model_type=model_type, transforms=transforms, mode='test')
  374. batch_im = list()
  375. for im in images:
  376. sample = {'image': im}
  377. batch_im.append(transforms(sample))
  378. batch_im = to_tensor(batch_im)
  379. return batch_im,
  380. def _postprocess(self, results, true_topk, labels):
  381. preds = list()
  382. for i, pred in enumerate(results):
  383. pred_label = np.argsort(pred)[::-1][:true_topk]
  384. preds.append([{
  385. 'category_id': l,
  386. 'category': labels[l],
  387. 'score': results[i][l]
  388. } for l in pred_label])
  389. return preds
  390. class ResNet18(BaseClassifier):
  391. def __init__(self, num_classes=1000):
  392. super(ResNet18, self).__init__(
  393. model_name='ResNet18', num_classes=num_classes)
  394. class ResNet34(BaseClassifier):
  395. def __init__(self, num_classes=1000):
  396. super(ResNet34, self).__init__(
  397. model_name='ResNet34', num_classes=num_classes)
  398. class ResNet50(BaseClassifier):
  399. def __init__(self, num_classes=1000):
  400. super(ResNet50, self).__init__(
  401. model_name='ResNet50', num_classes=num_classes)
  402. class ResNet101(BaseClassifier):
  403. def __init__(self, num_classes=1000):
  404. super(ResNet101, self).__init__(
  405. model_name='ResNet101', num_classes=num_classes)
  406. class ResNet152(BaseClassifier):
  407. def __init__(self, num_classes=1000):
  408. super(ResNet152, self).__init__(
  409. model_name='ResNet152', num_classes=num_classes)
  410. class ResNet18_vd(BaseClassifier):
  411. def __init__(self, num_classes=1000):
  412. super(ResNet18_vd, self).__init__(
  413. model_name='ResNet18_vd', num_classes=num_classes)
  414. class ResNet34_vd(BaseClassifier):
  415. def __init__(self, num_classes=1000):
  416. super(ResNet34_vd, self).__init__(
  417. model_name='ResNet34_vd', num_classes=num_classes)
  418. class ResNet50_vd(BaseClassifier):
  419. def __init__(self, num_classes=1000):
  420. super(ResNet50_vd, self).__init__(
  421. model_name='ResNet50_vd', num_classes=num_classes)
  422. class ResNet50_vd_ssld(BaseClassifier):
  423. def __init__(self, num_classes=1000):
  424. super(ResNet50_vd_ssld, self).__init__(
  425. model_name='ResNet50_vd',
  426. num_classes=num_classes,
  427. lr_mult_list=[.1, .1, .2, .2, .3])
  428. self.model_name = 'ResNet50_vd_ssld'
  429. class ResNet101_vd(BaseClassifier):
  430. def __init__(self, num_classes=1000):
  431. super(ResNet101_vd, self).__init__(
  432. model_name='ResNet101_vd', num_classes=num_classes)
  433. class ResNet101_vd_ssld(BaseClassifier):
  434. def __init__(self, num_classes=1000):
  435. super(ResNet101_vd_ssld, self).__init__(
  436. model_name='ResNet101_vd',
  437. num_classes=num_classes,
  438. lr_mult_list=[.1, .1, .2, .2, .3])
  439. self.model_name = 'ResNet101_vd_ssld'
  440. class ResNet152_vd(BaseClassifier):
  441. def __init__(self, num_classes=1000):
  442. super(ResNet152_vd, self).__init__(
  443. model_name='ResNet152_vd', num_classes=num_classes)
  444. class ResNet200_vd(BaseClassifier):
  445. def __init__(self, num_classes=1000):
  446. super(ResNet200_vd, self).__init__(
  447. model_name='ResNet200_vd', num_classes=num_classes)
  448. class AlexNet(BaseClassifier):
  449. def __init__(self, num_classes=1000):
  450. super(AlexNet, self).__init__(
  451. model_name='AlexNet', num_classes=num_classes)
  452. def get_test_inputs(self, image_shape):
  453. if image_shape == [-1, -1]:
  454. image_shape = [224, 224]
  455. logging.info('When exporting inference model for {},'.format(
  456. self.__class__.__name__
  457. ) + ' if image_shape is [-1, -1], it will be forcibly set to [224, 224]'
  458. )
  459. input_spec = [
  460. InputSpec(
  461. shape=[None, 3] + image_shape, name='image', dtype='float32')
  462. ]
  463. return input_spec
  464. class DarkNet53(BaseClassifier):
  465. def __init__(self, num_classes=1000):
  466. super(DarkNet53, self).__init__(
  467. model_name='DarkNet53', num_classes=num_classes)
  468. class MobileNetV1(BaseClassifier):
  469. def __init__(self, num_classes=1000, scale=1.0):
  470. supported_scale = [.25, .5, .75, 1.0]
  471. if scale not in supported_scale:
  472. logging.warning("scale={} is not supported by MobileNetV1, "
  473. "scale is forcibly set to 1.0".format(scale))
  474. scale = 1.0
  475. if scale == 1:
  476. model_name = 'MobileNetV1'
  477. else:
  478. model_name = 'MobileNetV1_x' + str(scale).replace('.', '_')
  479. self.scale = scale
  480. super(MobileNetV1, self).__init__(
  481. model_name=model_name, num_classes=num_classes)
  482. class MobileNetV2(BaseClassifier):
  483. def __init__(self, num_classes=1000, scale=1.0):
  484. supported_scale = [.25, .5, .75, 1.0, 1.5, 2.0]
  485. if scale not in supported_scale:
  486. logging.warning("scale={} is not supported by MobileNetV2, "
  487. "scale is forcibly set to 1.0".format(scale))
  488. scale = 1.0
  489. if scale == 1:
  490. model_name = 'MobileNetV2'
  491. else:
  492. model_name = 'MobileNetV2_x' + str(scale).replace('.', '_')
  493. super(MobileNetV2, self).__init__(
  494. model_name=model_name, num_classes=num_classes)
  495. class MobileNetV3_small(BaseClassifier):
  496. def __init__(self, num_classes=1000, scale=1.0):
  497. supported_scale = [.35, .5, .75, 1.0, 1.25]
  498. if scale not in supported_scale:
  499. logging.warning("scale={} is not supported by MobileNetV3_small, "
  500. "scale is forcibly set to 1.0".format(scale))
  501. scale = 1.0
  502. model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
  503. '_')
  504. super(MobileNetV3_small, self).__init__(
  505. model_name=model_name,
  506. num_classes=num_classes,
  507. lr_mult_list=[.1, .1, .2, .2, .3])
  508. class MobileNetV3_small_ssld(BaseClassifier):
  509. def __init__(self, num_classes=1000, scale=1.0):
  510. supported_scale = [.35, 1.0]
  511. if scale not in supported_scale:
  512. logging.warning(
  513. "scale={} is not supported by MobileNetV3_small_ssld, "
  514. "scale is forcibly set to 1.0".format(scale))
  515. scale = 1.0
  516. model_name = 'MobileNetV3_small_x' + str(float(scale)).replace('.',
  517. '_')
  518. super(MobileNetV3_small_ssld, self).__init__(
  519. model_name=model_name, num_classes=num_classes)
  520. self.model_name = model_name + '_ssld'
  521. class MobileNetV3_large(BaseClassifier):
  522. def __init__(self, num_classes=1000, scale=1.0):
  523. supported_scale = [.35, .5, .75, 1.0, 1.25]
  524. if scale not in supported_scale:
  525. logging.warning("scale={} is not supported by MobileNetV3_large, "
  526. "scale is forcibly set to 1.0".format(scale))
  527. scale = 1.0
  528. model_name = 'MobileNetV3_large_x' + str(float(scale)).replace('.',
  529. '_')
  530. super(MobileNetV3_large, self).__init__(
  531. model_name=model_name, num_classes=num_classes)
  532. class MobileNetV3_large_ssld(BaseClassifier):
  533. def __init__(self, num_classes=1000):
  534. super(MobileNetV3_large_ssld, self).__init__(
  535. model_name='MobileNetV3_large_x1_0', num_classes=num_classes)
  536. self.model_name = 'MobileNetV3_large_x1_0_ssld'
  537. class DenseNet121(BaseClassifier):
  538. def __init__(self, num_classes=1000):
  539. super(DenseNet121, self).__init__(
  540. model_name='DenseNet121', num_classes=num_classes)
  541. class DenseNet161(BaseClassifier):
  542. def __init__(self, num_classes=1000):
  543. super(DenseNet161, self).__init__(
  544. model_name='DenseNet161', num_classes=num_classes)
  545. class DenseNet169(BaseClassifier):
  546. def __init__(self, num_classes=1000):
  547. super(DenseNet169, self).__init__(
  548. model_name='DenseNet169', num_classes=num_classes)
  549. class DenseNet201(BaseClassifier):
  550. def __init__(self, num_classes=1000):
  551. super(DenseNet201, self).__init__(
  552. model_name='DenseNet201', num_classes=num_classes)
  553. class DenseNet264(BaseClassifier):
  554. def __init__(self, num_classes=1000):
  555. super(DenseNet264, self).__init__(
  556. model_name='DenseNet264', num_classes=num_classes)
  557. class HRNet_W18_C(BaseClassifier):
  558. def __init__(self, num_classes=1000):
  559. super(HRNet_W18_C, self).__init__(
  560. model_name='HRNet_W18_C', num_classes=num_classes)
  561. class HRNet_W30_C(BaseClassifier):
  562. def __init__(self, num_classes=1000):
  563. super(HRNet_W30_C, self).__init__(
  564. model_name='HRNet_W30_C', num_classes=num_classes)
  565. class HRNet_W32_C(BaseClassifier):
  566. def __init__(self, num_classes=1000):
  567. super(HRNet_W32_C, self).__init__(
  568. model_name='HRNet_W32_C', num_classes=num_classes)
  569. class HRNet_W40_C(BaseClassifier):
  570. def __init__(self, num_classes=1000):
  571. super(HRNet_W40_C, self).__init__(
  572. model_name='HRNet_W40_C', num_classes=num_classes)
  573. class HRNet_W44_C(BaseClassifier):
  574. def __init__(self, num_classes=1000):
  575. super(HRNet_W44_C, self).__init__(
  576. model_name='HRNet_W44_C', num_classes=num_classes)
  577. class HRNet_W48_C(BaseClassifier):
  578. def __init__(self, num_classes=1000):
  579. super(HRNet_W48_C, self).__init__(
  580. model_name='HRNet_W48_C', num_classes=num_classes)
  581. class HRNet_W64_C(BaseClassifier):
  582. def __init__(self, num_classes=1000):
  583. super(HRNet_W64_C, self).__init__(
  584. model_name='HRNet_W64_C', num_classes=num_classes)
  585. class Xception41(BaseClassifier):
  586. def __init__(self, num_classes=1000):
  587. super(Xception41, self).__init__(
  588. model_name='Xception41', num_classes=num_classes)
  589. class Xception65(BaseClassifier):
  590. def __init__(self, num_classes=1000):
  591. super(Xception65, self).__init__(
  592. model_name='Xception65', num_classes=num_classes)
  593. class Xception71(BaseClassifier):
  594. def __init__(self, num_classes=1000):
  595. super(Xception71, self).__init__(
  596. model_name='Xception71', num_classes=num_classes)
  597. class ShuffleNetV2(BaseClassifier):
  598. def __init__(self, num_classes=1000, scale=1.0):
  599. supported_scale = [.25, .33, .5, 1.0, 1.5, 2.0]
  600. if scale not in supported_scale:
  601. logging.warning("scale={} is not supported by ShuffleNetV2, "
  602. "scale is forcibly set to 1.0".format(scale))
  603. scale = 1.0
  604. model_name = 'ShuffleNetV2_x' + str(float(scale)).replace('.', '_')
  605. super(ShuffleNetV2, self).__init__(
  606. model_name=model_name, num_classes=num_classes)
  607. def get_test_inputs(self, image_shape):
  608. if image_shape == [-1, -1]:
  609. image_shape = [224, 224]
  610. logging.info('When exporting inference model for {},'.format(
  611. self.__class__.__name__
  612. ) + ' if image_shape is [-1, -1], it will be forcibly set to [224, 224]'
  613. )
  614. input_spec = [
  615. InputSpec(
  616. shape=[None, 3] + image_shape, name='image', dtype='float32')
  617. ]
  618. return input_spec
  619. class ShuffleNetV2_swish(BaseClassifier):
  620. def __init__(self, num_classes=1000):
  621. super(ShuffleNetV2_swish, self).__init__(
  622. model_name='ShuffleNetV2_x1_5', num_classes=num_classes)
  623. def get_test_inputs(self, image_shape):
  624. if image_shape == [-1, -1]:
  625. image_shape = [224, 224]
  626. logging.info('When exporting inference model for {},'.format(
  627. self.__class__.__name__
  628. ) + ' if image_shape is [-1, -1], it will be forcibly set to [224, 224]'
  629. )
  630. input_spec = [
  631. InputSpec(
  632. shape=[None, 3] + image_shape, name='image', dtype='float32')
  633. ]
  634. return input_spec