segmenter.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. from collections import OrderedDict
  18. import paddle
  19. import paddle.nn.functional as F
  20. from paddle.static import InputSpec
  21. import paddleseg
  22. import paddlex
  23. from paddlex.cv.transforms import arrange_transforms
  24. from paddlex.utils import get_single_card_bs, DisablePrint
  25. import paddlex.utils.logging as logging
  26. from .base import BaseModel
  27. from .utils import seg_metrics as metrics
  28. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  29. from paddlex.cv.transforms import Decode, Resize
  30. __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  31. class BaseSegmenter(BaseModel):
  32. def __init__(self,
  33. model_name,
  34. num_classes=2,
  35. use_mixed_loss=False,
  36. **params):
  37. self.init_params = locals()
  38. super(BaseSegmenter, self).__init__('segmenter')
  39. if not hasattr(paddleseg.models, model_name):
  40. raise Exception("ERROR: There's no model named {}.".format(
  41. model_name))
  42. self.model_name = model_name
  43. self.num_classes = num_classes
  44. self.use_mixed_loss = use_mixed_loss
  45. self.losses = None
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. self.find_unused_parameters = True
  49. def build_net(self, **params):
  50. # TODO: when using paddle.utils.unique_name.guard,
  51. # DeepLabv3p and HRNet will raise a error
  52. net = paddleseg.models.__dict__[self.model_name](
  53. num_classes=self.num_classes, **params)
  54. return net
  55. def _fix_transforms_shape(self, image_shape):
  56. if hasattr(self, 'test_transforms'):
  57. if self.test_transforms is not None:
  58. has_resize_op = False
  59. resize_op_idx = -1
  60. normalize_op_idx = len(self.test_transforms.transforms)
  61. for idx, op in enumerate(self.test_transforms.transforms):
  62. name = op.__class__.__name__
  63. if name == 'Normalize':
  64. normalize_op_idx = idx
  65. if 'Resize' in name:
  66. has_resize_op = True
  67. resize_op_idx = idx
  68. if not has_resize_op:
  69. self.test_transforms.transforms.insert(
  70. normalize_op_idx, Resize(target_size=image_shape))
  71. else:
  72. self.test_transforms.transforms[resize_op_idx] = Resize(
  73. target_size=image_shape)
  74. def _get_test_inputs(self, image_shape):
  75. if image_shape is not None:
  76. if len(image_shape) == 2:
  77. image_shape = [None, 3] + image_shape
  78. self._fix_transforms_shape(image_shape[-2:])
  79. else:
  80. image_shape = [None, 3, -1, -1]
  81. input_spec = [
  82. InputSpec(
  83. shape=image_shape, name='image', dtype='float32')
  84. ]
  85. return input_spec
  86. def run(self, net, inputs, mode):
  87. net_out = net(inputs[0])
  88. logit = net_out[0]
  89. outputs = OrderedDict()
  90. if mode == 'test':
  91. origin_shape = inputs[1]
  92. score_map = self._postprocess(
  93. logit, origin_shape, transforms=inputs[2])
  94. label_map = paddle.argmax(
  95. score_map, axis=1, keepdim=True, dtype='int32')
  96. score_map = paddle.max(score_map, axis=1, keepdim=True)
  97. score_map = paddle.squeeze(score_map)
  98. label_map = paddle.squeeze(label_map)
  99. outputs = {'label_map': label_map, 'score_map': score_map}
  100. if mode == 'eval':
  101. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  102. label = inputs[1]
  103. origin_shape = [label.shape[-2:]]
  104. # TODO: 替换cv2后postprocess移出run
  105. pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
  106. intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
  107. pred, label, self.num_classes)
  108. outputs['intersect_area'] = intersect_area
  109. outputs['pred_area'] = pred_area
  110. outputs['label_area'] = label_area
  111. outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
  112. self.num_classes)
  113. if mode == 'train':
  114. loss_list = metrics.loss_computation(
  115. logits_list=net_out, labels=inputs[1], losses=self.losses)
  116. loss = sum(loss_list)
  117. outputs['loss'] = loss
  118. return outputs
  119. def default_loss(self):
  120. if isinstance(self.use_mixed_loss, bool):
  121. if self.use_mixed_loss:
  122. losses = [
  123. paddleseg.models.CrossEntropyLoss(),
  124. paddleseg.models.LovaszSoftmaxLoss()
  125. ]
  126. coef = [.8, .2]
  127. loss_type = [
  128. paddleseg.models.MixedLoss(
  129. losses=losses, coef=coef),
  130. ]
  131. else:
  132. loss_type = [paddleseg.models.CrossEntropyLoss()]
  133. else:
  134. losses, coef = list(zip(*self.use_mixed_loss))
  135. if not set(losses).issubset(
  136. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  137. raise ValueError(
  138. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  139. )
  140. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  141. loss_type = [
  142. paddleseg.models.MixedLoss(
  143. losses=losses, coef=list(coef))
  144. ]
  145. if self.model_name == 'FastSCNN':
  146. loss_type *= 2
  147. loss_coef = [1.0, 0.4]
  148. elif self.model_name == 'BiSeNetV2':
  149. loss_type *= 5
  150. loss_coef = [1.0] * 5
  151. else:
  152. loss_coef = [1.0]
  153. losses = {'types': loss_type, 'coef': loss_coef}
  154. return losses
  155. def default_optimizer(self,
  156. parameters,
  157. learning_rate,
  158. num_epochs,
  159. num_steps_each_epoch,
  160. lr_decay_power=0.9):
  161. decay_step = num_epochs * num_steps_each_epoch
  162. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  163. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  164. optimizer = paddle.optimizer.Momentum(
  165. learning_rate=lr_scheduler,
  166. parameters=parameters,
  167. momentum=0.9,
  168. weight_decay=4e-5)
  169. return optimizer
  170. def train(self,
  171. num_epochs,
  172. train_dataset,
  173. train_batch_size=2,
  174. eval_dataset=None,
  175. optimizer=None,
  176. save_interval_epochs=1,
  177. log_interval_steps=2,
  178. save_dir='output',
  179. pretrain_weights='CITYSCAPES',
  180. learning_rate=0.01,
  181. lr_decay_power=0.9,
  182. early_stop=False,
  183. early_stop_patience=5,
  184. use_vdl=True):
  185. """
  186. Train the model.
  187. Args:
  188. num_epochs(int): The number of epochs.
  189. train_dataset(paddlex.dataset): Training dataset.
  190. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  191. eval_dataset(paddlex.dataset, optional):
  192. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  193. optimizer(paddle.optimizer.Optimizer or None, optional):
  194. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  195. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  196. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  197. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  198. pretrain_weights(str or None, optional):
  199. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
  200. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  201. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  202. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  203. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  204. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  205. """
  206. self.labels = train_dataset.labels
  207. if self.losses is None:
  208. self.losses = self.default_loss()
  209. if optimizer is None:
  210. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  211. self.optimizer = self.default_optimizer(
  212. self.net.parameters(), learning_rate, num_epochs,
  213. num_steps_each_epoch, lr_decay_power)
  214. else:
  215. self.optimizer = optimizer
  216. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  217. if pretrain_weights not in seg_pretrain_weights_dict[
  218. self.model_name]:
  219. logging.warning(
  220. "Path of pretrain_weights('{}') does not exist!".format(
  221. pretrain_weights))
  222. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  223. "If don't want to use pretrain weights, "
  224. "set pretrain_weights to be None.".format(
  225. seg_pretrain_weights_dict[self.model_name][
  226. 0]))
  227. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  228. 0]
  229. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  230. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  231. logging.error(
  232. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  233. exit=True)
  234. pretrained_dir = osp.join(save_dir, 'pretrain')
  235. self.net_initialize(
  236. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  237. self.train_loop(
  238. num_epochs=num_epochs,
  239. train_dataset=train_dataset,
  240. train_batch_size=train_batch_size,
  241. eval_dataset=eval_dataset,
  242. save_interval_epochs=save_interval_epochs,
  243. log_interval_steps=log_interval_steps,
  244. save_dir=save_dir,
  245. early_stop=early_stop,
  246. early_stop_patience=early_stop_patience,
  247. use_vdl=use_vdl)
  248. def quant_aware_train(self,
  249. num_epochs,
  250. train_dataset,
  251. train_batch_size=2,
  252. eval_dataset=None,
  253. optimizer=None,
  254. save_interval_epochs=1,
  255. log_interval_steps=2,
  256. save_dir='output',
  257. learning_rate=0.0001,
  258. lr_decay_power=0.9,
  259. early_stop=False,
  260. early_stop_patience=5,
  261. use_vdl=True,
  262. quant_config=None):
  263. """
  264. Quantization-aware training.
  265. Args:
  266. num_epochs(int): The number of epochs.
  267. train_dataset(paddlex.dataset): Training dataset.
  268. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  269. eval_dataset(paddlex.dataset, optional):
  270. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  271. optimizer(paddle.optimizer.Optimizer or None, optional):
  272. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  273. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  274. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  275. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  276. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  277. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  278. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  279. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  280. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  281. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  282. configuration will be used. Defaults to None.
  283. """
  284. self._prepare_qat(quant_config)
  285. self.train(
  286. num_epochs=num_epochs,
  287. train_dataset=train_dataset,
  288. train_batch_size=train_batch_size,
  289. eval_dataset=eval_dataset,
  290. optimizer=optimizer,
  291. save_interval_epochs=save_interval_epochs,
  292. log_interval_steps=log_interval_steps,
  293. save_dir=save_dir,
  294. pretrain_weights=None,
  295. learning_rate=learning_rate,
  296. lr_decay_power=lr_decay_power,
  297. early_stop=early_stop,
  298. early_stop_patience=early_stop_patience,
  299. use_vdl=use_vdl)
  300. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  301. """
  302. Evaluate the model.
  303. Args:
  304. eval_dataset(paddlex.dataset): Evaluation dataset.
  305. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  306. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  307. Returns:
  308. collections.OrderedDict with key-value pairs:
  309. {"miou": `mean intersection over union`,
  310. "category_iou": `category-wise mean intersection over union`,
  311. "oacc": `overall accuracy`,
  312. "category_acc": `category-wise accuracy`,
  313. "kappa": ` kappa coefficient`,
  314. "category_F1-score": `F1 score`}.
  315. """
  316. arrange_transforms(
  317. model_type=self.model_type,
  318. transforms=eval_dataset.transforms,
  319. mode='eval')
  320. self.net.eval()
  321. nranks = paddle.distributed.get_world_size()
  322. local_rank = paddle.distributed.get_rank()
  323. if nranks > 1:
  324. # Initialize parallel environment if not done.
  325. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  326. ):
  327. paddle.distributed.init_parallel_env()
  328. batch_size_each_card = get_single_card_bs(batch_size)
  329. if batch_size_each_card > 1:
  330. batch_size_each_card = 1
  331. batch_size = batch_size_each_card * paddlex.env_info['num']
  332. logging.warning(
  333. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  334. "during evaluation, so batch_size " \
  335. "is forcibly set to {}.".format(batch_size))
  336. self.eval_data_loader = self.build_data_loader(
  337. eval_dataset, batch_size=batch_size, mode='eval')
  338. intersect_area_all = 0
  339. pred_area_all = 0
  340. label_area_all = 0
  341. conf_mat_all = []
  342. logging.info(
  343. "Start to evaluate(total_samples={}, total_steps={})...".format(
  344. eval_dataset.num_samples,
  345. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  346. with paddle.no_grad():
  347. for step, data in enumerate(self.eval_data_loader):
  348. data.append(eval_dataset.transforms.transforms)
  349. outputs = self.run(self.net, data, 'eval')
  350. pred_area = outputs['pred_area']
  351. label_area = outputs['label_area']
  352. intersect_area = outputs['intersect_area']
  353. conf_mat = outputs['conf_mat']
  354. # Gather from all ranks
  355. if nranks > 1:
  356. intersect_area_list = []
  357. pred_area_list = []
  358. label_area_list = []
  359. conf_mat_list = []
  360. paddle.distributed.all_gather(intersect_area_list,
  361. intersect_area)
  362. paddle.distributed.all_gather(pred_area_list, pred_area)
  363. paddle.distributed.all_gather(label_area_list, label_area)
  364. paddle.distributed.all_gather(conf_mat_list, conf_mat)
  365. # Some image has been evaluated and should be eliminated in last iter
  366. if (step + 1) * nranks > len(eval_dataset):
  367. valid = len(eval_dataset) - step * nranks
  368. intersect_area_list = intersect_area_list[:valid]
  369. pred_area_list = pred_area_list[:valid]
  370. label_area_list = label_area_list[:valid]
  371. conf_mat_list = conf_mat_list[:valid]
  372. intersect_area_all += sum(intersect_area_list)
  373. pred_area_all += sum(pred_area_list)
  374. label_area_all += sum(label_area_list)
  375. conf_mat_all.extend(conf_mat_list)
  376. else:
  377. intersect_area_all = intersect_area_all + intersect_area
  378. pred_area_all = pred_area_all + pred_area
  379. label_area_all = label_area_all + label_area
  380. conf_mat_all.append(conf_mat)
  381. class_iou, miou = paddleseg.utils.metrics.mean_iou(
  382. intersect_area_all, pred_area_all, label_area_all)
  383. # TODO 确认是按oacc还是macc
  384. class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
  385. pred_area_all)
  386. kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
  387. pred_area_all, label_area_all)
  388. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  389. label_area_all)
  390. eval_metrics = OrderedDict(
  391. zip([
  392. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  393. 'category_F1-score'
  394. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  395. if return_details:
  396. conf_mat = sum(conf_mat_all)
  397. eval_details = {'confusion_matrix': conf_mat.tolist()}
  398. return eval_metrics, eval_details
  399. return eval_metrics
  400. def predict(self, img_file, transforms=None):
  401. """
  402. Do inference.
  403. Args:
  404. Args:
  405. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  406. Image path or decoded image data in a BGR format, which also could constitute a list,
  407. meaning all images to be predicted as a mini-batch.
  408. transforms(paddlex.transforms.Compose or None, optional):
  409. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  410. Returns:
  411. If img_file is a string or np.array, the result is a dict with key-value pairs:
  412. {"label map": `label map`, "score_map": `score map`}.
  413. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  414. label_map(np.ndarray): the predicted label map
  415. score_map(np.ndarray): the prediction score map
  416. """
  417. if transforms is None and not hasattr(self, 'test_transforms'):
  418. raise Exception("transforms need to be defined, now is None.")
  419. if transforms is None:
  420. transforms = self.test_transforms
  421. if isinstance(img_file, (str, np.ndarray)):
  422. images = [img_file]
  423. else:
  424. images = img_file
  425. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  426. self.model_type)
  427. self.net.eval()
  428. data = (batch_im, batch_origin_shape, transforms.transforms)
  429. outputs = self.run(self.net, data, 'test')
  430. label_map = outputs['label_map']
  431. label_map = label_map.numpy().astype('uint8')
  432. score_map = outputs['score_map']
  433. score_map = score_map.numpy().astype('float32')
  434. prediction = [{
  435. 'label_map': l,
  436. 'score_map': s
  437. } for l, s in zip(label_map, score_map)]
  438. if isinstance(img_file, (str, np.ndarray)):
  439. prediction = prediction[0]
  440. return prediction
  441. def _preprocess(self, images, transforms, model_type):
  442. arrange_transforms(
  443. model_type=model_type, transforms=transforms, mode='test')
  444. batch_im = list()
  445. batch_ori_shape = list()
  446. for im in images:
  447. sample = {'image': im}
  448. if isinstance(sample['image'], str):
  449. sample = Decode(to_rgb=False)(sample)
  450. ori_shape = sample['image'].shape[:2]
  451. im = transforms(sample)[0]
  452. batch_im.append(im)
  453. batch_ori_shape.append(ori_shape)
  454. batch_im = paddle.to_tensor(batch_im)
  455. return batch_im, batch_ori_shape
  456. @staticmethod
  457. def get_transforms_shape_info(batch_ori_shape, transforms):
  458. batch_restore_list = list()
  459. for ori_shape in batch_ori_shape:
  460. restore_list = list()
  461. h, w = ori_shape[0], ori_shape[1]
  462. for op in transforms:
  463. if op.__class__.__name__ in ['Resize', 'ResizeByShort']:
  464. restore_list.append(('resize', (h, w)))
  465. h, w = op.target_size
  466. if op.__class__.__name__ in ['Padding']:
  467. restore_list.append(('padding', (h, w)))
  468. h, w = op.target_size
  469. batch_restore_list.append(restore_list)
  470. return batch_restore_list
  471. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  472. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  473. batch_origin_shape, transforms)
  474. results = list()
  475. for pred, restore_list in zip(batch_pred, batch_restore_list):
  476. pred = paddle.unsqueeze(pred, axis=0)
  477. for item in restore_list[::-1]:
  478. # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
  479. h, w = item[1][0], item[1][1]
  480. if item[0] == 'resize':
  481. pred = F.interpolate(pred, (h, w), mode='nearest')
  482. elif item[0] == 'padding':
  483. pred = pred[:, :, 0:h, 0:w]
  484. else:
  485. pass
  486. results.append(pred)
  487. batch_pred = paddle.concat(results, axis=0)
  488. return batch_pred
  489. class UNet(BaseSegmenter):
  490. def __init__(self,
  491. num_classes=2,
  492. use_mixed_loss=False,
  493. use_deconv=False,
  494. align_corners=False):
  495. params = {'use_deconv': use_deconv, 'align_corners': align_corners}
  496. super(UNet, self).__init__(
  497. model_name='UNet',
  498. num_classes=num_classes,
  499. use_mixed_loss=use_mixed_loss,
  500. **params)
  501. class DeepLabV3P(BaseSegmenter):
  502. def __init__(self,
  503. num_classes=2,
  504. backbone='ResNet50_vd',
  505. use_mixed_loss=False,
  506. output_stride=8,
  507. backbone_indices=(0, 3),
  508. aspp_ratios=(1, 12, 24, 36),
  509. aspp_out_channels=256,
  510. align_corners=False):
  511. self.backbone_name = backbone
  512. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  513. raise ValueError(
  514. "backbone: {} is not supported. Please choose one of "
  515. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  516. with DisablePrint():
  517. backbone = getattr(paddleseg.models, backbone)(
  518. output_stride=output_stride)
  519. params = {
  520. 'backbone': backbone,
  521. 'backbone_indices': backbone_indices,
  522. 'aspp_ratios': aspp_ratios,
  523. 'aspp_out_channels': aspp_out_channels,
  524. 'align_corners': align_corners
  525. }
  526. super(DeepLabV3P, self).__init__(
  527. model_name='DeepLabV3P',
  528. num_classes=num_classes,
  529. use_mixed_loss=use_mixed_loss,
  530. **params)
  531. class FastSCNN(BaseSegmenter):
  532. def __init__(self,
  533. num_classes=2,
  534. use_mixed_loss=False,
  535. align_corners=False):
  536. params = {'align_corners': align_corners}
  537. super(FastSCNN, self).__init__(
  538. model_name='FastSCNN',
  539. num_classes=num_classes,
  540. use_mixed_loss=use_mixed_loss,
  541. **params)
  542. class HRNet(BaseSegmenter):
  543. def __init__(self,
  544. num_classes=2,
  545. width=48,
  546. use_mixed_loss=False,
  547. align_corners=False):
  548. if width not in (18, 48):
  549. raise ValueError(
  550. "width={} is not supported, please choose from [18, 48]".
  551. format(width))
  552. self.backbone_name = 'HRNet_W{}'.format(width)
  553. with DisablePrint():
  554. backbone = getattr(paddleseg.models, self.backbone_name)(
  555. align_corners=align_corners)
  556. params = {'backbone': backbone, 'align_corners': align_corners}
  557. super(HRNet, self).__init__(
  558. model_name='FCN',
  559. num_classes=num_classes,
  560. use_mixed_loss=use_mixed_loss,
  561. **params)
  562. self.model_name = 'HRNet'
  563. class BiSeNetV2(BaseSegmenter):
  564. def __init__(self,
  565. num_classes=2,
  566. use_mixed_loss=False,
  567. align_corners=False):
  568. params = {'align_corners': align_corners}
  569. super(BiSeNetV2, self).__init__(
  570. model_name='BiSeNetV2',
  571. num_classes=num_classes,
  572. use_mixed_loss=use_mixed_loss,
  573. **params)