segmenter.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. from collections import OrderedDict
  18. import paddle
  19. import paddle.nn.functional as F
  20. from paddle.static import InputSpec
  21. import paddleseg
  22. import paddlex
  23. from paddlex.cv.transforms import arrange_transforms
  24. from paddlex.utils import get_single_card_bs, DisablePrint
  25. import paddlex.utils.logging as logging
  26. from .base import BaseModel
  27. from .utils import seg_metrics as metrics
  28. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  29. from paddlex.cv.transforms import Decode
# Public model classes exported by this module (``from ... import *``);
# the shared base class BaseSegmenter is not listed here.
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  31. class BaseSegmenter(BaseModel):
  32. def __init__(self,
  33. model_name,
  34. num_classes=2,
  35. use_mixed_loss=False,
  36. **params):
  37. self.init_params = locals()
  38. super(BaseSegmenter, self).__init__('segmenter')
  39. if not hasattr(paddleseg.models, model_name):
  40. raise Exception("ERROR: There's no model named {}.".format(
  41. model_name))
  42. self.model_name = model_name
  43. self.num_classes = num_classes
  44. self.use_mixed_loss = use_mixed_loss
  45. self.losses = None
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. self.find_unused_parameters = True
  49. def build_net(self, **params):
  50. # TODO: when using paddle.utils.unique_name.guard,
  51. # DeepLabv3p and HRNet will raise a error
  52. net = paddleseg.models.__dict__[self.model_name](
  53. num_classes=self.num_classes, **params)
  54. return net
  55. def get_test_inputs(self, image_shape):
  56. input_spec = [
  57. InputSpec(
  58. shape=[None, 3] + image_shape, name='image', dtype='float32')
  59. ]
  60. return input_spec
  61. def run(self, net, inputs, mode):
  62. net_out = net(inputs[0])
  63. logit = net_out[0]
  64. outputs = OrderedDict()
  65. if mode == 'test':
  66. origin_shape = inputs[1]
  67. score_map = self._postprocess(
  68. logit, origin_shape, transforms=inputs[2])
  69. label_map = paddle.argmax(
  70. score_map, axis=1, keepdim=True, dtype='int32')
  71. score_map = paddle.max(score_map, axis=1, keepdim=True)
  72. score_map = paddle.squeeze(score_map)
  73. label_map = paddle.squeeze(label_map)
  74. outputs = {'label_map': label_map, 'score_map': score_map}
  75. if mode == 'eval':
  76. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  77. label = inputs[1]
  78. origin_shape = [label.shape[-2:]]
  79. # TODO: 替换cv2后postprocess移出run
  80. pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
  81. intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
  82. pred, label, self.num_classes)
  83. outputs['intersect_area'] = intersect_area
  84. outputs['pred_area'] = pred_area
  85. outputs['label_area'] = label_area
  86. outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
  87. self.num_classes)
  88. if mode == 'train':
  89. loss_list = metrics.loss_computation(
  90. logits_list=net_out, labels=inputs[1], losses=self.losses)
  91. loss = sum(loss_list)
  92. outputs['loss'] = loss
  93. return outputs
  94. def default_loss(self):
  95. if isinstance(self.use_mixed_loss, bool):
  96. if self.use_mixed_loss:
  97. losses = [
  98. paddleseg.models.CrossEntropyLoss(),
  99. paddleseg.models.LovaszSoftmaxLoss()
  100. ]
  101. coef = [.8, .2]
  102. loss_type = [
  103. paddleseg.models.MixedLoss(
  104. losses=losses, coef=coef),
  105. ]
  106. else:
  107. loss_type = [paddleseg.models.CrossEntropyLoss()]
  108. else:
  109. losses, coef = list(zip(*self.use_mixed_loss))
  110. if not set(losses).issubset(
  111. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  112. raise ValueError(
  113. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  114. )
  115. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  116. loss_type = [
  117. paddleseg.models.MixedLoss(
  118. losses=losses, coef=list(coef))
  119. ]
  120. if self.model_name == 'FastSCNN':
  121. loss_type *= 2
  122. loss_coef = [1.0, 0.4]
  123. elif self.model_name == 'BiSeNetV2':
  124. loss_type *= 5
  125. loss_coef = [1.0] * 5
  126. else:
  127. loss_coef = [1.0]
  128. losses = {'types': loss_type, 'coef': loss_coef}
  129. return losses
  130. def default_optimizer(self,
  131. parameters,
  132. learning_rate,
  133. num_epochs,
  134. num_steps_each_epoch,
  135. lr_decay_power=0.9):
  136. decay_step = num_epochs * num_steps_each_epoch
  137. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  138. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  139. optimizer = paddle.optimizer.Momentum(
  140. learning_rate=lr_scheduler,
  141. parameters=parameters,
  142. momentum=0.9,
  143. weight_decay=4e-5)
  144. return optimizer
  145. def train(self,
  146. num_epochs,
  147. train_dataset,
  148. train_batch_size=2,
  149. eval_dataset=None,
  150. optimizer=None,
  151. save_interval_epochs=1,
  152. log_interval_steps=2,
  153. save_dir='output',
  154. pretrain_weights='CITYSCAPES',
  155. learning_rate=0.01,
  156. lr_decay_power=0.9,
  157. early_stop=False,
  158. early_stop_patience=5,
  159. use_vdl=True):
  160. """
  161. Train the model.
  162. Args:
  163. num_epochs(int): The number of epochs.
  164. train_dataset(paddlex.dataset): Training dataset.
  165. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  166. eval_dataset(paddlex.dataset, optional):
  167. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  168. optimizer(paddle.optimizer.Optimizer or None, optional):
  169. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  170. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  171. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  172. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  173. pretrain_weights(str or None, optional):
  174. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  175. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  176. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  177. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  178. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  179. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  180. """
  181. self.labels = train_dataset.labels
  182. if self.losses is None:
  183. self.losses = self.default_loss()
  184. if optimizer is None:
  185. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  186. self.optimizer = self.default_optimizer(
  187. self.net.parameters(), learning_rate, num_epochs,
  188. num_steps_each_epoch, lr_decay_power)
  189. else:
  190. self.optimizer = optimizer
  191. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  192. if pretrain_weights not in seg_pretrain_weights_dict[
  193. self.model_name]:
  194. logging.warning(
  195. "Path of pretrain_weights('{}') does not exist!".format(
  196. pretrain_weights))
  197. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  198. "If don't want to use pretrain weights, "
  199. "set pretrain_weights to be None.".format(
  200. seg_pretrain_weights_dict[self.model_name][
  201. 0]))
  202. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  203. 0]
  204. pretrained_dir = osp.join(save_dir, 'pretrain')
  205. self.net_initialize(
  206. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  207. self.train_loop(
  208. num_epochs=num_epochs,
  209. train_dataset=train_dataset,
  210. train_batch_size=train_batch_size,
  211. eval_dataset=eval_dataset,
  212. save_interval_epochs=save_interval_epochs,
  213. log_interval_steps=log_interval_steps,
  214. save_dir=save_dir,
  215. early_stop=early_stop,
  216. early_stop_patience=early_stop_patience,
  217. use_vdl=use_vdl)
  218. def quant_aware_train(self,
  219. num_epochs,
  220. train_dataset,
  221. train_batch_size=2,
  222. eval_dataset=None,
  223. optimizer=None,
  224. save_interval_epochs=1,
  225. log_interval_steps=2,
  226. save_dir='output',
  227. learning_rate=0.0001,
  228. lr_decay_power=0.9,
  229. early_stop=False,
  230. early_stop_patience=5,
  231. use_vdl=True,
  232. quant_config=None):
  233. """
  234. Quantization-aware training.
  235. Args:
  236. num_epochs(int): The number of epochs.
  237. train_dataset(paddlex.dataset): Training dataset.
  238. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  239. eval_dataset(paddlex.dataset, optional):
  240. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  241. optimizer(paddle.optimizer.Optimizer or None, optional):
  242. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  243. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  244. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  245. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  246. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  247. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  248. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  249. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  250. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  251. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  252. configuration will be used. Defaults to None.
  253. """
  254. self._prepare_qat(quant_config)
  255. self.train(
  256. num_epochs=num_epochs,
  257. train_dataset=train_dataset,
  258. train_batch_size=train_batch_size,
  259. eval_dataset=eval_dataset,
  260. optimizer=optimizer,
  261. save_interval_epochs=save_interval_epochs,
  262. log_interval_steps=log_interval_steps,
  263. save_dir=save_dir,
  264. pretrain_weights=None,
  265. learning_rate=learning_rate,
  266. lr_decay_power=lr_decay_power,
  267. early_stop=early_stop,
  268. early_stop_patience=early_stop_patience,
  269. use_vdl=use_vdl)
  270. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  271. """
  272. Evaluate the model.
  273. Args:
  274. eval_dataset(paddlex.dataset): Evaluation dataset.
  275. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  276. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  277. Returns:
  278. collections.OrderedDict with key-value pairs:
  279. {"miou": `mean intersection over union`,
  280. "category_iou": `category-wise mean intersection over union`,
  281. "oacc": `overall accuracy`,
  282. "category_acc": `category-wise accuracy`,
  283. "kappa": ` kappa coefficient`,
  284. "category_F1-score": `F1 score`}.
  285. """
  286. arrange_transforms(
  287. model_type=self.model_type,
  288. transforms=eval_dataset.transforms,
  289. mode='eval')
  290. self.net.eval()
  291. nranks = paddle.distributed.get_world_size()
  292. local_rank = paddle.distributed.get_rank()
  293. if nranks > 1:
  294. # Initialize parallel environment if not done.
  295. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  296. ):
  297. paddle.distributed.init_parallel_env()
  298. batch_size_each_card = get_single_card_bs(batch_size)
  299. if batch_size_each_card > 1:
  300. batch_size_each_card = 1
  301. batch_size = batch_size_each_card * paddlex.env_info['num']
  302. logging.warning(
  303. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  304. "during evaluation, so batch_size " \
  305. "is forcibly set to {}.".format(batch_size))
  306. self.eval_data_loader = self.build_data_loader(
  307. eval_dataset, batch_size=batch_size, mode='eval')
  308. intersect_area_all = 0
  309. pred_area_all = 0
  310. label_area_all = 0
  311. conf_mat_all = []
  312. logging.info(
  313. "Start to evaluate(total_samples={}, total_steps={})...".format(
  314. eval_dataset.num_samples,
  315. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  316. with paddle.no_grad():
  317. for step, data in enumerate(self.eval_data_loader):
  318. data.append(eval_dataset.transforms.transforms)
  319. outputs = self.run(self.net, data, 'eval')
  320. pred_area = outputs['pred_area']
  321. label_area = outputs['label_area']
  322. intersect_area = outputs['intersect_area']
  323. conf_mat = outputs['conf_mat']
  324. # Gather from all ranks
  325. if nranks > 1:
  326. intersect_area_list = []
  327. pred_area_list = []
  328. label_area_list = []
  329. conf_mat_list = []
  330. paddle.distributed.all_gather(intersect_area_list,
  331. intersect_area)
  332. paddle.distributed.all_gather(pred_area_list, pred_area)
  333. paddle.distributed.all_gather(label_area_list, label_area)
  334. paddle.distributed.all_gather(conf_mat_list, conf_mat)
  335. # Some image has been evaluated and should be eliminated in last iter
  336. if (step + 1) * nranks > len(eval_dataset):
  337. valid = len(eval_dataset) - step * nranks
  338. intersect_area_list = intersect_area_list[:valid]
  339. pred_area_list = pred_area_list[:valid]
  340. label_area_list = label_area_list[:valid]
  341. conf_mat_list = conf_mat_list[:valid]
  342. intersect_area_all += sum(intersect_area_list)
  343. pred_area_all += sum(pred_area_list)
  344. label_area_all += sum(label_area_list)
  345. conf_mat_all.extend(conf_mat_list)
  346. else:
  347. intersect_area_all = intersect_area_all + intersect_area
  348. pred_area_all = pred_area_all + pred_area
  349. label_area_all = label_area_all + label_area
  350. conf_mat_all.append(conf_mat)
  351. class_iou, miou = paddleseg.utils.metrics.mean_iou(
  352. intersect_area_all, pred_area_all, label_area_all)
  353. # TODO 确认是按oacc还是macc
  354. class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
  355. pred_area_all)
  356. kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
  357. pred_area_all, label_area_all)
  358. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  359. label_area_all)
  360. eval_metrics = OrderedDict(
  361. zip([
  362. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  363. 'category_F1-score'
  364. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  365. if return_details:
  366. conf_mat = sum(conf_mat_all).numpy()
  367. eval_details = {'confusion_matrix': conf_mat.tolist()}
  368. return eval_metrics, eval_details
  369. return eval_metrics
  370. def predict(self, img_file, transforms=None):
  371. """
  372. Do inference.
  373. Args:
  374. Args:
  375. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  376. Image path or decoded image data in a BGR format, which also could constitute a list,
  377. meaning all images to be predicted as a mini-batch.
  378. transforms(paddlex.transforms.Compose or None, optional):
  379. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  380. Returns:
  381. If img_file is a string or np.array, the result is a dict with key-value pairs:
  382. {"label map": `label map`, "score_map": `score map`}.
  383. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  384. label_map(np.ndarray): the predicted label map
  385. score_map(np.ndarray): the prediction score map
  386. """
  387. if transforms is None and not hasattr(self, 'test_transforms'):
  388. raise Exception("transforms need to be defined, now is None.")
  389. if transforms is None:
  390. transforms = self.test_transforms
  391. if isinstance(img_file, (str, np.ndarray)):
  392. images = [img_file]
  393. else:
  394. images = img_file
  395. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  396. self.model_type)
  397. self.net.eval()
  398. data = (batch_im, batch_origin_shape, transforms.transforms)
  399. outputs = self.run(self.net, data, 'test')
  400. label_map = outputs['label_map']
  401. label_map = label_map.numpy().astype('uint8')
  402. score_map = outputs['score_map']
  403. score_map = score_map.numpy().astype('float32')
  404. return {'label_map': label_map, 'score_map': score_map}
  405. def _preprocess(self, images, transforms, model_type):
  406. arrange_transforms(
  407. model_type=model_type, transforms=transforms, mode='test')
  408. batch_im = list()
  409. batch_ori_shape = list()
  410. for im in images:
  411. sample = {'image': im}
  412. if isinstance(sample['image'], str):
  413. sample = Decode(to_rgb=False)(sample)
  414. ori_shape = sample['image'].shape[:2]
  415. im = transforms(sample)[0]
  416. batch_im.append(im)
  417. batch_ori_shape.append(ori_shape)
  418. batch_im = paddle.to_tensor(batch_im)
  419. return batch_im, batch_ori_shape
  420. @staticmethod
  421. def get_transforms_shape_info(batch_ori_shape, transforms):
  422. batch_restore_list = list()
  423. for ori_shape in batch_ori_shape:
  424. restore_list = list()
  425. h, w = ori_shape[0], ori_shape[1]
  426. for op in transforms:
  427. if op.__class__.__name__ in ['Resize', 'ResizeByShort']:
  428. restore_list.append(('resize', (h, w)))
  429. h, w = op.target_size
  430. if op.__class__.__name__ in ['Padding']:
  431. restore_list.append(('padding', (h, w)))
  432. h, w = op.target_size
  433. batch_restore_list.append(restore_list)
  434. return batch_restore_list
  435. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  436. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  437. batch_origin_shape, transforms)
  438. results = list()
  439. for pred, restore_list in zip(batch_pred, batch_restore_list):
  440. pred = paddle.unsqueeze(pred, axis=0)
  441. for item in restore_list[::-1]:
  442. # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
  443. h, w = item[1][0], item[1][1]
  444. if item[0] == 'resize':
  445. pred = F.interpolate(pred, (h, w), mode='nearest')
  446. elif item[0] == 'padding':
  447. pred = pred[:, :, 0:h, 0:w]
  448. else:
  449. pass
  450. results.append(pred)
  451. batch_pred = paddle.concat(results, axis=0)
  452. return batch_pred
  453. class UNet(BaseSegmenter):
  454. def __init__(self,
  455. num_classes=2,
  456. use_mixed_loss=False,
  457. use_deconv=False,
  458. align_corners=False):
  459. params = {'use_deconv': use_deconv, 'align_corners': align_corners}
  460. super(UNet, self).__init__(
  461. model_name='UNet',
  462. num_classes=num_classes,
  463. use_mixed_loss=use_mixed_loss,
  464. **params)
  465. class DeepLabV3P(BaseSegmenter):
  466. def __init__(self,
  467. num_classes=2,
  468. backbone='ResNet50_vd',
  469. use_mixed_loss=False,
  470. output_stride=8,
  471. backbone_indices=(0, 3),
  472. aspp_ratios=(1, 12, 24, 36),
  473. aspp_out_channels=256,
  474. align_corners=False):
  475. self.backbone_name = backbone
  476. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  477. raise ValueError(
  478. "backbone: {} is not supported. Please choose one of "
  479. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  480. with DisablePrint():
  481. backbone = getattr(paddleseg.models, backbone)(
  482. output_stride=output_stride)
  483. params = {
  484. 'backbone': backbone,
  485. 'backbone_indices': backbone_indices,
  486. 'aspp_ratios': aspp_ratios,
  487. 'aspp_out_channels': aspp_out_channels,
  488. 'align_corners': align_corners
  489. }
  490. super(DeepLabV3P, self).__init__(
  491. model_name='DeepLabV3P',
  492. num_classes=num_classes,
  493. use_mixed_loss=use_mixed_loss,
  494. **params)
  495. class FastSCNN(BaseSegmenter):
  496. def __init__(self,
  497. num_classes=2,
  498. use_mixed_loss=False,
  499. align_corners=False):
  500. params = {'align_corners': align_corners}
  501. super(FastSCNN, self).__init__(
  502. model_name='FastSCNN',
  503. num_classes=num_classes,
  504. use_mixed_loss=use_mixed_loss,
  505. **params)
  506. class HRNet(BaseSegmenter):
  507. def __init__(self,
  508. num_classes=2,
  509. width=48,
  510. use_mixed_loss=False,
  511. align_corners=False):
  512. if width not in (18, 48):
  513. raise ValueError(
  514. "width={} is not supported, please choose from [18, 48]".
  515. format(width))
  516. self.backbone_name = 'HRNet_W{}'.format(width)
  517. with DisablePrint():
  518. backbone = getattr(paddleseg.models, self.backbone_name)(
  519. align_corners=align_corners)
  520. params = {'backbone': backbone, 'align_corners': align_corners}
  521. super(HRNet, self).__init__(
  522. model_name='FCN',
  523. num_classes=num_classes,
  524. use_mixed_loss=use_mixed_loss,
  525. **params)
  526. self.model_name = 'HRNet'
  527. class BiSeNetV2(BaseSegmenter):
  528. def __init__(self,
  529. num_classes=2,
  530. use_mixed_loss=False,
  531. align_corners=False):
  532. params = {'align_corners': align_corners}
  533. super(BiSeNetV2, self).__init__(
  534. model_name='BiSeNetV2',
  535. num_classes=num_classes,
  536. use_mixed_loss=use_mixed_loss,
  537. **params)