segmenter.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. from collections import OrderedDict
  18. import paddle
  19. import paddle.nn.functional as F
  20. from paddle.static import InputSpec
  21. import paddleseg
  22. import paddlex
  23. from paddlex.cv.transforms import arrange_transforms
  24. from paddlex.utils import get_single_card_bs, DisablePrint
  25. import paddlex.utils.logging as logging
  26. from .base import BaseModel
  27. from .utils import seg_metrics as metrics
  28. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  29. from paddlex.cv.transforms import Decode, Resize
  30. __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  31. class BaseSegmenter(BaseModel):
  32. def __init__(self,
  33. model_name,
  34. num_classes=2,
  35. use_mixed_loss=False,
  36. **params):
  37. self.init_params = locals()
  38. super(BaseSegmenter, self).__init__('segmenter')
  39. if not hasattr(paddleseg.models, model_name):
  40. raise Exception("ERROR: There's no model named {}.".format(
  41. model_name))
  42. self.model_name = model_name
  43. self.num_classes = num_classes
  44. self.use_mixed_loss = use_mixed_loss
  45. self.losses = None
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. self.find_unused_parameters = True
  49. def build_net(self, **params):
  50. # TODO: when using paddle.utils.unique_name.guard,
  51. # DeepLabv3p and HRNet will raise a error
  52. net = paddleseg.models.__dict__[self.model_name](
  53. num_classes=self.num_classes, **params)
  54. return net
  55. def _fix_transforms_shape(self, image_shape):
  56. if hasattr(self, 'test_transforms'):
  57. if self.test_transforms is not None:
  58. has_resize_op = False
  59. resize_op_idx = -1
  60. normalize_op_idx = len(self.test_transforms.transforms)
  61. for idx, op in enumerate(self.test_transforms.transforms):
  62. name = op.__class__.__name__
  63. if name == 'Normalize':
  64. normalize_op_idx = idx
  65. if 'Resize' in name:
  66. has_resize_op = True
  67. resize_op_idx = idx
  68. if not has_resize_op:
  69. self.test_transforms.transforms.insert(
  70. normalize_op_idx, Resize(target_size=image_shape))
  71. else:
  72. self.test_transforms.transforms[resize_op_idx] = Resize(
  73. target_size=image_shape)
  74. def _get_test_inputs(self, image_shape):
  75. if image_shape is not None:
  76. if len(image_shape) == 2:
  77. image_shape = [1, 3] + image_shape
  78. self._fix_transforms_shape(image_shape[-2:])
  79. else:
  80. image_shape = [None, 3, -1, -1]
  81. self.fixed_input_shape = image_shape
  82. input_spec = [
  83. InputSpec(
  84. shape=image_shape, name='image', dtype='float32')
  85. ]
  86. return input_spec
  87. def run(self, net, inputs, mode):
  88. net_out = net(inputs[0])
  89. logit = net_out[0]
  90. outputs = OrderedDict()
  91. if mode == 'test':
  92. origin_shape = inputs[1]
  93. score_map = self._postprocess(
  94. logit, origin_shape, transforms=inputs[2])
  95. label_map = paddle.argmax(
  96. score_map, axis=1, keepdim=True, dtype='int32')
  97. score_map = paddle.max(score_map, axis=1, keepdim=True)
  98. score_map = paddle.squeeze(score_map)
  99. label_map = paddle.squeeze(label_map)
  100. outputs = {'label_map': label_map, 'score_map': score_map}
  101. if mode == 'eval':
  102. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  103. label = inputs[1]
  104. origin_shape = [label.shape[-2:]]
  105. # TODO: 替换cv2后postprocess移出run
  106. pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
  107. intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
  108. pred, label, self.num_classes)
  109. outputs['intersect_area'] = intersect_area
  110. outputs['pred_area'] = pred_area
  111. outputs['label_area'] = label_area
  112. outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
  113. self.num_classes)
  114. if mode == 'train':
  115. loss_list = metrics.loss_computation(
  116. logits_list=net_out, labels=inputs[1], losses=self.losses)
  117. loss = sum(loss_list)
  118. outputs['loss'] = loss
  119. return outputs
  120. def default_loss(self):
  121. if isinstance(self.use_mixed_loss, bool):
  122. if self.use_mixed_loss:
  123. losses = [
  124. paddleseg.models.CrossEntropyLoss(),
  125. paddleseg.models.LovaszSoftmaxLoss()
  126. ]
  127. coef = [.8, .2]
  128. loss_type = [
  129. paddleseg.models.MixedLoss(
  130. losses=losses, coef=coef),
  131. ]
  132. else:
  133. loss_type = [paddleseg.models.CrossEntropyLoss()]
  134. else:
  135. losses, coef = list(zip(*self.use_mixed_loss))
  136. if not set(losses).issubset(
  137. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  138. raise ValueError(
  139. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  140. )
  141. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  142. loss_type = [
  143. paddleseg.models.MixedLoss(
  144. losses=losses, coef=list(coef))
  145. ]
  146. if self.model_name == 'FastSCNN':
  147. loss_type *= 2
  148. loss_coef = [1.0, 0.4]
  149. elif self.model_name == 'BiSeNetV2':
  150. loss_type *= 5
  151. loss_coef = [1.0] * 5
  152. else:
  153. loss_coef = [1.0]
  154. losses = {'types': loss_type, 'coef': loss_coef}
  155. return losses
  156. def default_optimizer(self,
  157. parameters,
  158. learning_rate,
  159. num_epochs,
  160. num_steps_each_epoch,
  161. lr_decay_power=0.9):
  162. decay_step = num_epochs * num_steps_each_epoch
  163. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  164. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  165. optimizer = paddle.optimizer.Momentum(
  166. learning_rate=lr_scheduler,
  167. parameters=parameters,
  168. momentum=0.9,
  169. weight_decay=4e-5)
  170. return optimizer
  171. def train(self,
  172. num_epochs,
  173. train_dataset,
  174. train_batch_size=2,
  175. eval_dataset=None,
  176. optimizer=None,
  177. save_interval_epochs=1,
  178. log_interval_steps=2,
  179. save_dir='output',
  180. pretrain_weights='CITYSCAPES',
  181. learning_rate=0.01,
  182. lr_decay_power=0.9,
  183. early_stop=False,
  184. early_stop_patience=5,
  185. use_vdl=True):
  186. """
  187. Train the model.
  188. Args:
  189. num_epochs(int): The number of epochs.
  190. train_dataset(paddlex.dataset): Training dataset.
  191. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  192. eval_dataset(paddlex.dataset, optional):
  193. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  194. optimizer(paddle.optimizer.Optimizer or None, optional):
  195. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  196. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  197. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  198. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  199. pretrain_weights(str or None, optional):
  200. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
  201. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  202. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  203. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  204. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  205. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  206. """
  207. self.labels = train_dataset.labels
  208. if self.losses is None:
  209. self.losses = self.default_loss()
  210. if optimizer is None:
  211. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  212. self.optimizer = self.default_optimizer(
  213. self.net.parameters(), learning_rate, num_epochs,
  214. num_steps_each_epoch, lr_decay_power)
  215. else:
  216. self.optimizer = optimizer
  217. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  218. if pretrain_weights not in seg_pretrain_weights_dict[
  219. self.model_name]:
  220. logging.warning(
  221. "Path of pretrain_weights('{}') does not exist!".format(
  222. pretrain_weights))
  223. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  224. "If don't want to use pretrain weights, "
  225. "set pretrain_weights to be None.".format(
  226. seg_pretrain_weights_dict[self.model_name][
  227. 0]))
  228. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  229. 0]
  230. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  231. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  232. logging.error(
  233. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  234. exit=True)
  235. pretrained_dir = osp.join(save_dir, 'pretrain')
  236. self.net_initialize(
  237. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  238. self.train_loop(
  239. num_epochs=num_epochs,
  240. train_dataset=train_dataset,
  241. train_batch_size=train_batch_size,
  242. eval_dataset=eval_dataset,
  243. save_interval_epochs=save_interval_epochs,
  244. log_interval_steps=log_interval_steps,
  245. save_dir=save_dir,
  246. early_stop=early_stop,
  247. early_stop_patience=early_stop_patience,
  248. use_vdl=use_vdl)
  249. def quant_aware_train(self,
  250. num_epochs,
  251. train_dataset,
  252. train_batch_size=2,
  253. eval_dataset=None,
  254. optimizer=None,
  255. save_interval_epochs=1,
  256. log_interval_steps=2,
  257. save_dir='output',
  258. learning_rate=0.0001,
  259. lr_decay_power=0.9,
  260. early_stop=False,
  261. early_stop_patience=5,
  262. use_vdl=True,
  263. quant_config=None):
  264. """
  265. Quantization-aware training.
  266. Args:
  267. num_epochs(int): The number of epochs.
  268. train_dataset(paddlex.dataset): Training dataset.
  269. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  270. eval_dataset(paddlex.dataset, optional):
  271. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  272. optimizer(paddle.optimizer.Optimizer or None, optional):
  273. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  274. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  275. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  276. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  277. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  278. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  279. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  280. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  281. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  282. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  283. configuration will be used. Defaults to None.
  284. """
  285. self._prepare_qat(quant_config)
  286. self.train(
  287. num_epochs=num_epochs,
  288. train_dataset=train_dataset,
  289. train_batch_size=train_batch_size,
  290. eval_dataset=eval_dataset,
  291. optimizer=optimizer,
  292. save_interval_epochs=save_interval_epochs,
  293. log_interval_steps=log_interval_steps,
  294. save_dir=save_dir,
  295. pretrain_weights=None,
  296. learning_rate=learning_rate,
  297. lr_decay_power=lr_decay_power,
  298. early_stop=early_stop,
  299. early_stop_patience=early_stop_patience,
  300. use_vdl=use_vdl)
  301. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  302. """
  303. Evaluate the model.
  304. Args:
  305. eval_dataset(paddlex.dataset): Evaluation dataset.
  306. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  307. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  308. Returns:
  309. collections.OrderedDict with key-value pairs:
  310. {"miou": `mean intersection over union`,
  311. "category_iou": `category-wise mean intersection over union`,
  312. "oacc": `overall accuracy`,
  313. "category_acc": `category-wise accuracy`,
  314. "kappa": ` kappa coefficient`,
  315. "category_F1-score": `F1 score`}.
  316. """
  317. arrange_transforms(
  318. model_type=self.model_type,
  319. transforms=eval_dataset.transforms,
  320. mode='eval')
  321. self.net.eval()
  322. nranks = paddle.distributed.get_world_size()
  323. local_rank = paddle.distributed.get_rank()
  324. if nranks > 1:
  325. # Initialize parallel environment if not done.
  326. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  327. ):
  328. paddle.distributed.init_parallel_env()
  329. batch_size_each_card = get_single_card_bs(batch_size)
  330. if batch_size_each_card > 1:
  331. batch_size_each_card = 1
  332. batch_size = batch_size_each_card * paddlex.env_info['num']
  333. logging.warning(
  334. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  335. "during evaluation, so batch_size " \
  336. "is forcibly set to {}.".format(batch_size))
  337. self.eval_data_loader = self.build_data_loader(
  338. eval_dataset, batch_size=batch_size, mode='eval')
  339. intersect_area_all = 0
  340. pred_area_all = 0
  341. label_area_all = 0
  342. conf_mat_all = []
  343. logging.info(
  344. "Start to evaluate(total_samples={}, total_steps={})...".format(
  345. eval_dataset.num_samples,
  346. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  347. with paddle.no_grad():
  348. for step, data in enumerate(self.eval_data_loader):
  349. data.append(eval_dataset.transforms.transforms)
  350. outputs = self.run(self.net, data, 'eval')
  351. pred_area = outputs['pred_area']
  352. label_area = outputs['label_area']
  353. intersect_area = outputs['intersect_area']
  354. conf_mat = outputs['conf_mat']
  355. # Gather from all ranks
  356. if nranks > 1:
  357. intersect_area_list = []
  358. pred_area_list = []
  359. label_area_list = []
  360. conf_mat_list = []
  361. paddle.distributed.all_gather(intersect_area_list,
  362. intersect_area)
  363. paddle.distributed.all_gather(pred_area_list, pred_area)
  364. paddle.distributed.all_gather(label_area_list, label_area)
  365. paddle.distributed.all_gather(conf_mat_list, conf_mat)
  366. # Some image has been evaluated and should be eliminated in last iter
  367. if (step + 1) * nranks > len(eval_dataset):
  368. valid = len(eval_dataset) - step * nranks
  369. intersect_area_list = intersect_area_list[:valid]
  370. pred_area_list = pred_area_list[:valid]
  371. label_area_list = label_area_list[:valid]
  372. conf_mat_list = conf_mat_list[:valid]
  373. intersect_area_all += sum(intersect_area_list)
  374. pred_area_all += sum(pred_area_list)
  375. label_area_all += sum(label_area_list)
  376. conf_mat_all.extend(conf_mat_list)
  377. else:
  378. intersect_area_all = intersect_area_all + intersect_area
  379. pred_area_all = pred_area_all + pred_area
  380. label_area_all = label_area_all + label_area
  381. conf_mat_all.append(conf_mat)
  382. class_iou, miou = paddleseg.utils.metrics.mean_iou(
  383. intersect_area_all, pred_area_all, label_area_all)
  384. # TODO 确认是按oacc还是macc
  385. class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
  386. pred_area_all)
  387. kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
  388. pred_area_all, label_area_all)
  389. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  390. label_area_all)
  391. eval_metrics = OrderedDict(
  392. zip([
  393. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  394. 'category_F1-score'
  395. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  396. if return_details:
  397. conf_mat = sum(conf_mat_all)
  398. eval_details = {'confusion_matrix': conf_mat.tolist()}
  399. return eval_metrics, eval_details
  400. return eval_metrics
  401. def predict(self, img_file, transforms=None):
  402. """
  403. Do inference.
  404. Args:
  405. Args:
  406. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  407. Image path or decoded image data in a BGR format, which also could constitute a list,
  408. meaning all images to be predicted as a mini-batch.
  409. transforms(paddlex.transforms.Compose or None, optional):
  410. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  411. Returns:
  412. If img_file is a string or np.array, the result is a dict with key-value pairs:
  413. {"label map": `label map`, "score_map": `score map`}.
  414. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  415. label_map(np.ndarray): the predicted label map
  416. score_map(np.ndarray): the prediction score map
  417. """
  418. if transforms is None and not hasattr(self, 'test_transforms'):
  419. raise Exception("transforms need to be defined, now is None.")
  420. if transforms is None:
  421. transforms = self.test_transforms
  422. if isinstance(img_file, (str, np.ndarray)):
  423. images = [img_file]
  424. else:
  425. images = img_file
  426. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  427. self.model_type)
  428. self.net.eval()
  429. data = (batch_im, batch_origin_shape, transforms.transforms)
  430. outputs = self.run(self.net, data, 'test')
  431. label_map = outputs['label_map']
  432. label_map = label_map.numpy().astype('uint8')
  433. score_map = outputs['score_map']
  434. score_map = score_map.numpy().astype('float32')
  435. prediction = [{
  436. 'label_map': l,
  437. 'score_map': s
  438. } for l, s in zip(label_map, score_map)]
  439. if isinstance(img_file, (str, np.ndarray)):
  440. prediction = prediction[0]
  441. return prediction
  442. def _preprocess(self, images, transforms, model_type):
  443. arrange_transforms(
  444. model_type=model_type, transforms=transforms, mode='test')
  445. batch_im = list()
  446. batch_ori_shape = list()
  447. for im in images:
  448. sample = {'image': im}
  449. if isinstance(sample['image'], str):
  450. sample = Decode(to_rgb=False)(sample)
  451. ori_shape = sample['image'].shape[:2]
  452. im = transforms(sample)[0]
  453. batch_im.append(im)
  454. batch_ori_shape.append(ori_shape)
  455. batch_im = paddle.to_tensor(batch_im)
  456. return batch_im, batch_ori_shape
  457. @staticmethod
  458. def get_transforms_shape_info(batch_ori_shape, transforms):
  459. batch_restore_list = list()
  460. for ori_shape in batch_ori_shape:
  461. restore_list = list()
  462. h, w = ori_shape[0], ori_shape[1]
  463. for op in transforms:
  464. if op.__class__.__name__ in ['Resize', 'ResizeByShort']:
  465. restore_list.append(('resize', (h, w)))
  466. h, w = op.target_size
  467. if op.__class__.__name__ in ['Padding']:
  468. restore_list.append(('padding', (h, w)))
  469. h, w = op.target_size
  470. batch_restore_list.append(restore_list)
  471. return batch_restore_list
  472. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  473. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  474. batch_origin_shape, transforms)
  475. results = list()
  476. for pred, restore_list in zip(batch_pred, batch_restore_list):
  477. pred = paddle.unsqueeze(pred, axis=0)
  478. for item in restore_list[::-1]:
  479. # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
  480. h, w = item[1][0], item[1][1]
  481. if item[0] == 'resize':
  482. pred = F.interpolate(pred, (h, w), mode='nearest')
  483. elif item[0] == 'padding':
  484. pred = pred[:, :, 0:h, 0:w]
  485. else:
  486. pass
  487. results.append(pred)
  488. batch_pred = paddle.concat(results, axis=0)
  489. return batch_pred
  490. class UNet(BaseSegmenter):
  491. def __init__(self,
  492. num_classes=2,
  493. use_mixed_loss=False,
  494. use_deconv=False,
  495. align_corners=False):
  496. params = {'use_deconv': use_deconv, 'align_corners': align_corners}
  497. super(UNet, self).__init__(
  498. model_name='UNet',
  499. num_classes=num_classes,
  500. use_mixed_loss=use_mixed_loss,
  501. **params)
  502. class DeepLabV3P(BaseSegmenter):
  503. def __init__(self,
  504. num_classes=2,
  505. backbone='ResNet50_vd',
  506. use_mixed_loss=False,
  507. output_stride=8,
  508. backbone_indices=(0, 3),
  509. aspp_ratios=(1, 12, 24, 36),
  510. aspp_out_channels=256,
  511. align_corners=False):
  512. self.backbone_name = backbone
  513. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  514. raise ValueError(
  515. "backbone: {} is not supported. Please choose one of "
  516. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  517. with DisablePrint():
  518. backbone = getattr(paddleseg.models, backbone)(
  519. output_stride=output_stride)
  520. params = {
  521. 'backbone': backbone,
  522. 'backbone_indices': backbone_indices,
  523. 'aspp_ratios': aspp_ratios,
  524. 'aspp_out_channels': aspp_out_channels,
  525. 'align_corners': align_corners
  526. }
  527. super(DeepLabV3P, self).__init__(
  528. model_name='DeepLabV3P',
  529. num_classes=num_classes,
  530. use_mixed_loss=use_mixed_loss,
  531. **params)
  532. class FastSCNN(BaseSegmenter):
  533. def __init__(self,
  534. num_classes=2,
  535. use_mixed_loss=False,
  536. align_corners=False):
  537. params = {'align_corners': align_corners}
  538. super(FastSCNN, self).__init__(
  539. model_name='FastSCNN',
  540. num_classes=num_classes,
  541. use_mixed_loss=use_mixed_loss,
  542. **params)
  543. class HRNet(BaseSegmenter):
  544. def __init__(self,
  545. num_classes=2,
  546. width=48,
  547. use_mixed_loss=False,
  548. align_corners=False):
  549. if width not in (18, 48):
  550. raise ValueError(
  551. "width={} is not supported, please choose from [18, 48]".
  552. format(width))
  553. self.backbone_name = 'HRNet_W{}'.format(width)
  554. with DisablePrint():
  555. backbone = getattr(paddleseg.models, self.backbone_name)(
  556. align_corners=align_corners)
  557. params = {'backbone': backbone, 'align_corners': align_corners}
  558. super(HRNet, self).__init__(
  559. model_name='FCN',
  560. num_classes=num_classes,
  561. use_mixed_loss=use_mixed_loss,
  562. **params)
  563. self.model_name = 'HRNet'
  564. class BiSeNetV2(BaseSegmenter):
  565. def __init__(self,
  566. num_classes=2,
  567. use_mixed_loss=False,
  568. align_corners=False):
  569. params = {'align_corners': align_corners}
  570. super(BiSeNetV2, self).__init__(
  571. model_name='BiSeNetV2',
  572. num_classes=num_classes,
  573. use_mixed_loss=use_mixed_loss,
  574. **params)