# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. import cv2
  18. from collections import OrderedDict
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddle.static import InputSpec
  22. import paddlex.paddleseg as paddleseg
  23. import paddlex
  24. from paddlex.cv.transforms import arrange_transforms
  25. from paddlex.utils import get_single_card_bs, DisablePrint
  26. import paddlex.utils.logging as logging
  27. from .base import BaseModel
  28. from .utils import seg_metrics as metrics
  29. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  30. from paddlex.cv.transforms import Decode, Resize
  31. __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  32. class BaseSegmenter(BaseModel):
  33. def __init__(self,
  34. model_name,
  35. num_classes=2,
  36. use_mixed_loss=False,
  37. **params):
  38. self.init_params = locals()
  39. if 'with_net' in self.init_params:
  40. del self.init_params['with_net']
  41. super(BaseSegmenter, self).__init__('segmenter')
  42. if not hasattr(paddleseg.models, model_name):
  43. raise Exception("ERROR: There's no model named {}.".format(
  44. model_name))
  45. self.model_name = model_name
  46. self.num_classes = num_classes
  47. self.use_mixed_loss = use_mixed_loss
  48. self.losses = None
  49. self.labels = None
  50. if params.get('with_net', True):
  51. params.pop('with_net', None)
  52. self.net = self.build_net(**params)
  53. self.find_unused_parameters = True
  54. def build_net(self, **params):
  55. # TODO: when using paddle.utils.unique_name.guard,
  56. # DeepLabv3p and HRNet will raise a error
  57. net = paddleseg.models.__dict__[self.model_name](
  58. num_classes=self.num_classes, **params)
  59. return net
  60. def _fix_transforms_shape(self, image_shape):
  61. if hasattr(self, 'test_transforms'):
  62. if self.test_transforms is not None:
  63. has_resize_op = False
  64. resize_op_idx = -1
  65. normalize_op_idx = len(self.test_transforms.transforms)
  66. for idx, op in enumerate(self.test_transforms.transforms):
  67. name = op.__class__.__name__
  68. if name == 'Normalize':
  69. normalize_op_idx = idx
  70. if 'Resize' in name:
  71. has_resize_op = True
  72. resize_op_idx = idx
  73. if not has_resize_op:
  74. self.test_transforms.transforms.insert(
  75. normalize_op_idx, Resize(target_size=image_shape))
  76. else:
  77. self.test_transforms.transforms[resize_op_idx] = Resize(
  78. target_size=image_shape)
  79. def _get_test_inputs(self, image_shape):
  80. if image_shape is not None:
  81. if len(image_shape) == 2:
  82. image_shape = [1, 3] + image_shape
  83. self._fix_transforms_shape(image_shape[-2:])
  84. else:
  85. image_shape = [None, 3, -1, -1]
  86. self.fixed_input_shape = image_shape
  87. input_spec = [
  88. InputSpec(
  89. shape=image_shape, name='image', dtype='float32')
  90. ]
  91. return input_spec
  92. def run(self, net, inputs, mode):
  93. net_out = net(inputs[0])
  94. logit = net_out[0]
  95. outputs = OrderedDict()
  96. if mode == 'test':
  97. origin_shape = inputs[1]
  98. if self.status == 'Infer':
  99. label_map_list, score_map_list = self._postprocess(
  100. net_out, origin_shape, transforms=inputs[2])
  101. else:
  102. logit_list = self._postprocess(
  103. logit, origin_shape, transforms=inputs[2])
  104. label_map_list = []
  105. score_map_list = []
  106. for logit in logit_list:
  107. logit = paddle.transpose(logit, perm=[0, 2, 3, 1]) # NHWC
  108. label_map_list.append((paddle.argmax(
  109. logit, axis=-1, keepdim=False, dtype='int32')).squeeze(
  110. ).numpy().astype('int32'))
  111. score_map_list.append(
  112. F.softmax(
  113. logit, axis=-1).squeeze().numpy().astype(
  114. 'float32'))
  115. outputs['label_map'] = label_map_list
  116. outputs['score_map'] = score_map_list
  117. if mode == 'eval':
  118. if self.status == 'Infer':
  119. pred = paddle.unsqueeze(net_out[0], axis=1) # NCHW
  120. else:
  121. pred = paddle.argmax(
  122. logit, axis=1, keepdim=True, dtype='int32')
  123. label = inputs[1]
  124. origin_shape = [label.shape[-2:]]
  125. pred = self._postprocess(
  126. pred, origin_shape, transforms=inputs[2])[0] # NCHW
  127. intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
  128. pred, label, self.num_classes)
  129. outputs['intersect_area'] = intersect_area
  130. outputs['pred_area'] = pred_area
  131. outputs['label_area'] = label_area
  132. outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
  133. self.num_classes)
  134. if mode == 'train':
  135. loss_list = metrics.loss_computation(
  136. logits_list=net_out, labels=inputs[1], losses=self.losses)
  137. loss = sum(loss_list)
  138. outputs['loss'] = loss
  139. return outputs
  140. def default_loss(self):
  141. if isinstance(self.use_mixed_loss, bool):
  142. if self.use_mixed_loss:
  143. losses = [
  144. paddleseg.models.CrossEntropyLoss(),
  145. paddleseg.models.LovaszSoftmaxLoss()
  146. ]
  147. coef = [.8, .2]
  148. loss_type = [
  149. paddleseg.models.MixedLoss(
  150. losses=losses, coef=coef),
  151. ]
  152. else:
  153. loss_type = [paddleseg.models.CrossEntropyLoss()]
  154. else:
  155. losses, coef = list(zip(*self.use_mixed_loss))
  156. if not set(losses).issubset(
  157. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  158. raise ValueError(
  159. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  160. )
  161. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  162. loss_type = [
  163. paddleseg.models.MixedLoss(
  164. losses=losses, coef=list(coef))
  165. ]
  166. if self.model_name == 'FastSCNN':
  167. loss_type *= 2
  168. loss_coef = [1.0, 0.4]
  169. elif self.model_name == 'BiSeNetV2':
  170. loss_type *= 5
  171. loss_coef = [1.0] * 5
  172. else:
  173. loss_coef = [1.0]
  174. losses = {'types': loss_type, 'coef': loss_coef}
  175. return losses
  176. def default_optimizer(self,
  177. parameters,
  178. learning_rate,
  179. num_epochs,
  180. num_steps_each_epoch,
  181. lr_decay_power=0.9):
  182. decay_step = num_epochs * num_steps_each_epoch
  183. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  184. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  185. optimizer = paddle.optimizer.Momentum(
  186. learning_rate=lr_scheduler,
  187. parameters=parameters,
  188. momentum=0.9,
  189. weight_decay=4e-5)
  190. return optimizer
  191. def train(self,
  192. num_epochs,
  193. train_dataset,
  194. train_batch_size=2,
  195. eval_dataset=None,
  196. optimizer=None,
  197. save_interval_epochs=1,
  198. log_interval_steps=2,
  199. save_dir='output',
  200. pretrain_weights='CITYSCAPES',
  201. learning_rate=0.01,
  202. lr_decay_power=0.9,
  203. early_stop=False,
  204. early_stop_patience=5,
  205. use_vdl=True,
  206. resume_checkpoint=None):
  207. """
  208. Train the model.
  209. Args:
  210. num_epochs(int): The number of epochs.
  211. train_dataset(paddlex.dataset): Training dataset.
  212. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  213. eval_dataset(paddlex.dataset, optional):
  214. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  215. optimizer(paddle.optimizer.Optimizer or None, optional):
  216. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  217. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  218. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  219. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  220. pretrain_weights(str or None, optional):
  221. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
  222. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  223. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  224. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  225. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  226. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  227. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  228. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  229. `pretrain_weights` can be set simultaneously. Defaults to None.
  230. """
  231. if self.status == 'Infer':
  232. logging.error(
  233. "Exported inference model does not support training.",
  234. exit=True)
  235. if pretrain_weights is not None and resume_checkpoint is not None:
  236. logging.error(
  237. "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
  238. exit=True)
  239. self.labels = train_dataset.labels
  240. if self.losses is None:
  241. self.losses = self.default_loss()
  242. if optimizer is None:
  243. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  244. self.optimizer = self.default_optimizer(
  245. self.net.parameters(), learning_rate, num_epochs,
  246. num_steps_each_epoch, lr_decay_power)
  247. else:
  248. self.optimizer = optimizer
  249. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  250. if pretrain_weights not in seg_pretrain_weights_dict[
  251. self.model_name]:
  252. logging.warning(
  253. "Path of pretrain_weights('{}') does not exist!".format(
  254. pretrain_weights))
  255. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  256. "If don't want to use pretrain weights, "
  257. "set pretrain_weights to be None.".format(
  258. seg_pretrain_weights_dict[self.model_name][
  259. 0]))
  260. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  261. 0]
  262. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  263. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  264. logging.error(
  265. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  266. exit=True)
  267. pretrained_dir = osp.join(save_dir, 'pretrain')
  268. is_backbone_weights = pretrain_weights == 'IMAGENET'
  269. self.net_initialize(
  270. pretrain_weights=pretrain_weights,
  271. save_dir=pretrained_dir,
  272. resume_checkpoint=resume_checkpoint,
  273. is_backbone_weights=is_backbone_weights)
  274. self.train_loop(
  275. num_epochs=num_epochs,
  276. train_dataset=train_dataset,
  277. train_batch_size=train_batch_size,
  278. eval_dataset=eval_dataset,
  279. save_interval_epochs=save_interval_epochs,
  280. log_interval_steps=log_interval_steps,
  281. save_dir=save_dir,
  282. early_stop=early_stop,
  283. early_stop_patience=early_stop_patience,
  284. use_vdl=use_vdl)
  285. def quant_aware_train(self,
  286. num_epochs,
  287. train_dataset,
  288. train_batch_size=2,
  289. eval_dataset=None,
  290. optimizer=None,
  291. save_interval_epochs=1,
  292. log_interval_steps=2,
  293. save_dir='output',
  294. learning_rate=0.0001,
  295. lr_decay_power=0.9,
  296. early_stop=False,
  297. early_stop_patience=5,
  298. use_vdl=True,
  299. resume_checkpoint=None,
  300. quant_config=None):
  301. """
  302. Quantization-aware training.
  303. Args:
  304. num_epochs(int): The number of epochs.
  305. train_dataset(paddlex.dataset): Training dataset.
  306. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  307. eval_dataset(paddlex.dataset, optional):
  308. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  309. optimizer(paddle.optimizer.Optimizer or None, optional):
  310. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  311. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  312. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  313. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  314. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  315. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  316. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  317. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  318. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  319. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  320. configuration will be used. Defaults to None.
  321. resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
  322. from. If None, no training checkpoint will be resumed. Defaults to None.
  323. """
  324. self._prepare_qat(quant_config)
  325. self.train(
  326. num_epochs=num_epochs,
  327. train_dataset=train_dataset,
  328. train_batch_size=train_batch_size,
  329. eval_dataset=eval_dataset,
  330. optimizer=optimizer,
  331. save_interval_epochs=save_interval_epochs,
  332. log_interval_steps=log_interval_steps,
  333. save_dir=save_dir,
  334. pretrain_weights=None,
  335. learning_rate=learning_rate,
  336. lr_decay_power=lr_decay_power,
  337. early_stop=early_stop,
  338. early_stop_patience=early_stop_patience,
  339. use_vdl=use_vdl,
  340. resume_checkpoint=resume_checkpoint)
  341. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  342. """
  343. Evaluate the model.
  344. Args:
  345. eval_dataset(paddlex.dataset): Evaluation dataset.
  346. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  347. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  348. Returns:
  349. collections.OrderedDict with key-value pairs:
  350. {"miou": `mean intersection over union`,
  351. "category_iou": `category-wise mean intersection over union`,
  352. "oacc": `overall accuracy`,
  353. "category_acc": `category-wise accuracy`,
  354. "kappa": ` kappa coefficient`,
  355. "category_F1-score": `F1 score`}.
  356. """
  357. arrange_transforms(
  358. model_type=self.model_type,
  359. transforms=eval_dataset.transforms,
  360. mode='eval')
  361. self.net.eval()
  362. nranks = paddle.distributed.get_world_size()
  363. local_rank = paddle.distributed.get_rank()
  364. if nranks > 1:
  365. # Initialize parallel environment if not done.
  366. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  367. ):
  368. paddle.distributed.init_parallel_env()
  369. batch_size_each_card = get_single_card_bs(batch_size)
  370. if batch_size_each_card > 1:
  371. batch_size_each_card = 1
  372. batch_size = batch_size_each_card * paddlex.env_info['num']
  373. logging.warning(
  374. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  375. "during evaluation, so batch_size " \
  376. "is forcibly set to {}.".format(batch_size))
  377. self.eval_data_loader = self.build_data_loader(
  378. eval_dataset, batch_size=batch_size, mode='eval')
  379. intersect_area_all = 0
  380. pred_area_all = 0
  381. label_area_all = 0
  382. conf_mat_all = []
  383. logging.info(
  384. "Start to evaluate(total_samples={}, total_steps={})...".format(
  385. eval_dataset.num_samples,
  386. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  387. with paddle.no_grad():
  388. for step, data in enumerate(self.eval_data_loader):
  389. data.append(eval_dataset.transforms.transforms)
  390. outputs = self.run(self.net, data, 'eval')
  391. pred_area = outputs['pred_area']
  392. label_area = outputs['label_area']
  393. intersect_area = outputs['intersect_area']
  394. conf_mat = outputs['conf_mat']
  395. # Gather from all ranks
  396. if nranks > 1:
  397. intersect_area_list = []
  398. pred_area_list = []
  399. label_area_list = []
  400. conf_mat_list = []
  401. paddle.distributed.all_gather(intersect_area_list,
  402. intersect_area)
  403. paddle.distributed.all_gather(pred_area_list, pred_area)
  404. paddle.distributed.all_gather(label_area_list, label_area)
  405. paddle.distributed.all_gather(conf_mat_list, conf_mat)
  406. # Some image has been evaluated and should be eliminated in last iter
  407. if (step + 1) * nranks > len(eval_dataset):
  408. valid = len(eval_dataset) - step * nranks
  409. intersect_area_list = intersect_area_list[:valid]
  410. pred_area_list = pred_area_list[:valid]
  411. label_area_list = label_area_list[:valid]
  412. conf_mat_list = conf_mat_list[:valid]
  413. intersect_area_all += sum(intersect_area_list)
  414. pred_area_all += sum(pred_area_list)
  415. label_area_all += sum(label_area_list)
  416. conf_mat_all.extend(conf_mat_list)
  417. else:
  418. intersect_area_all = intersect_area_all + intersect_area
  419. pred_area_all = pred_area_all + pred_area
  420. label_area_all = label_area_all + label_area
  421. conf_mat_all.append(conf_mat)
  422. class_iou, miou = paddleseg.utils.metrics.mean_iou(
  423. intersect_area_all, pred_area_all, label_area_all)
  424. # TODO 确认是按oacc还是macc
  425. class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
  426. pred_area_all)
  427. kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
  428. pred_area_all, label_area_all)
  429. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  430. label_area_all)
  431. eval_metrics = OrderedDict(
  432. zip([
  433. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  434. 'category_F1-score'
  435. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  436. if return_details:
  437. conf_mat = sum(conf_mat_all)
  438. eval_details = {'confusion_matrix': conf_mat.tolist()}
  439. return eval_metrics, eval_details
  440. return eval_metrics
  441. def predict(self, img_file, transforms=None):
  442. """
  443. Do inference.
  444. Args:
  445. Args:
  446. img_file(List[np.ndarray or str], str or np.ndarray):
  447. Image path or decoded image data in a BGR format, which also could constitute a list,
  448. meaning all images to be predicted as a mini-batch.
  449. transforms(paddlex.transforms.Compose or None, optional):
  450. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  451. Returns:
  452. If img_file is a string or np.array, the result is a dict with key-value pairs:
  453. {"label map": `label map`, "score_map": `score map`}.
  454. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  455. label_map(np.ndarray): the predicted label map (HW)
  456. score_map(np.ndarray): the prediction score map (HWC)
  457. """
  458. if transforms is None and not hasattr(self, 'test_transforms'):
  459. raise Exception("transforms need to be defined, now is None.")
  460. if transforms is None:
  461. transforms = self.test_transforms
  462. if isinstance(img_file, (str, np.ndarray)):
  463. images = [img_file]
  464. else:
  465. images = img_file
  466. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  467. self.model_type)
  468. self.net.eval()
  469. data = (batch_im, batch_origin_shape, transforms.transforms)
  470. outputs = self.run(self.net, data, 'test')
  471. label_map_list = outputs['label_map']
  472. score_map_list = outputs['score_map']
  473. if isinstance(img_file, list):
  474. prediction = [{
  475. 'label_map': l,
  476. 'score_map': s
  477. } for l, s in zip(label_map_list, score_map_list)]
  478. else:
  479. prediction = {
  480. 'label_map': label_map_list[0],
  481. 'score_map': score_map_list[0]
  482. }
  483. return prediction
  484. def _preprocess(self, images, transforms, to_tensor=True):
  485. arrange_transforms(
  486. model_type=self.model_type, transforms=transforms, mode='test')
  487. batch_im = list()
  488. batch_ori_shape = list()
  489. for im in images:
  490. sample = {'image': im}
  491. if isinstance(sample['image'], str):
  492. sample = Decode(to_rgb=False)(sample)
  493. ori_shape = sample['image'].shape[:2]
  494. im = transforms(sample)[0]
  495. batch_im.append(im)
  496. batch_ori_shape.append(ori_shape)
  497. if to_tensor:
  498. batch_im = paddle.to_tensor(batch_im)
  499. else:
  500. batch_im = np.asarray(batch_im)
  501. return batch_im, batch_ori_shape
  502. @staticmethod
  503. def get_transforms_shape_info(batch_ori_shape, transforms):
  504. batch_restore_list = list()
  505. for ori_shape in batch_ori_shape:
  506. restore_list = list()
  507. h, w = ori_shape[0], ori_shape[1]
  508. for op in transforms:
  509. if op.__class__.__name__ == 'Resize':
  510. restore_list.append(('resize', (h, w)))
  511. h, w = op.target_size
  512. elif op.__class__.__name__ == 'ResizeByShort':
  513. restore_list.append(('resize', (h, w)))
  514. im_short_size = min(h, w)
  515. im_long_size = max(h, w)
  516. scale = float(op.short_size) / float(im_short_size)
  517. if 0 < op.max_size < np.round(scale * im_long_size):
  518. scale = float(op.max_size) / float(im_long_size)
  519. h = int(round(h * scale))
  520. w = int(round(w * scale))
  521. elif op.__class__.__name__ == 'ResizeByLong':
  522. restore_list.append(('resize', (h, w)))
  523. im_long_size = max(h, w)
  524. scale = float(op.long_size) / float(im_long_size)
  525. h = int(round(h * scale))
  526. w = int(round(w * scale))
  527. elif op.__class__.__name__ == 'Padding':
  528. if op.target_size:
  529. target_h, target_w = op.target_size
  530. else:
  531. target_h = int(
  532. (np.ceil(h / op.size_divisor) * op.size_divisor))
  533. target_w = int(
  534. (np.ceil(w / op.size_divisor) * op.size_divisor))
  535. if op.pad_mode == -1:
  536. offsets = op.offsets
  537. elif op.pad_mode == 0:
  538. offsets = [0, 0]
  539. elif op.pad_mode == 1:
  540. offsets = [(target_h - h) // 2, (target_w - w) // 2]
  541. else:
  542. offsets = [target_h - h, target_w - w]
  543. restore_list.append(('padding', (h, w), offsets))
  544. h, w = target_h, target_w
  545. batch_restore_list.append(restore_list)
  546. return batch_restore_list
  547. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  548. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  549. batch_origin_shape, transforms)
  550. if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
  551. return self._infer_postprocess(
  552. batch_label_map=batch_pred[0],
  553. batch_score_map=batch_pred[1],
  554. batch_restore_list=batch_restore_list)
  555. results = []
  556. if batch_pred.dtype == paddle.float32:
  557. mode = 'bilinear'
  558. else:
  559. mode = 'nearest'
  560. for pred, restore_list in zip(batch_pred, batch_restore_list):
  561. pred = paddle.unsqueeze(pred, axis=0)
  562. for item in restore_list[::-1]:
  563. h, w = item[1][0], item[1][1]
  564. if item[0] == 'resize':
  565. pred = F.interpolate(
  566. pred, (h, w), mode=mode, data_format='NCHW')
  567. elif item[0] == 'padding':
  568. x, y = item[2]
  569. pred = pred[:, :, y:y + h, x:x + w]
  570. else:
  571. pass
  572. results.append(pred)
  573. return results
  574. def _infer_postprocess(self, batch_label_map, batch_score_map,
  575. batch_restore_list):
  576. label_maps = []
  577. score_maps = []
  578. for label_map, score_map, restore_list in zip(
  579. batch_label_map, batch_score_map, batch_restore_list):
  580. if not isinstance(label_map, np.ndarray):
  581. label_map = paddle.unsqueeze(label_map, axis=[0, 3])
  582. score_map = paddle.unsqueeze(score_map, axis=0)
  583. for item in restore_list[::-1]:
  584. h, w = item[1][0], item[1][1]
  585. if item[0] == 'resize':
  586. if isinstance(label_map, np.ndarray):
  587. label_map = cv2.resize(
  588. label_map, (w, h), interpolation=cv2.INTER_NEAREST)
  589. score_map = cv2.resize(
  590. score_map, (w, h), interpolation=cv2.INTER_LINEAR)
  591. else:
  592. label_map = F.interpolate(
  593. label_map, (h, w),
  594. mode='nearest',
  595. data_format='NHWC')
  596. score_map = F.interpolate(
  597. score_map, (h, w),
  598. mode='bilinear',
  599. data_format='NHWC')
  600. elif item[0] == 'padding':
  601. x, y = item[2]
  602. if isinstance(label_map, np.ndarray):
  603. label_map = label_map[..., y:y + h, x:x + w]
  604. score_map = score_map[..., y:y + h, x:x + w]
  605. else:
  606. label_map = label_map[:, :, y:y + h, x:x + w]
  607. score_map = score_map[:, :, y:y + h, x:x + w]
  608. else:
  609. pass
  610. label_map = label_map.squeeze()
  611. score_map = score_map.squeeze()
  612. if not isinstance(label_map, np.ndarray):
  613. label_map = label_map.numpy()
  614. score_map = score_map.numpy()
  615. label_maps.append(label_map.squeeze())
  616. score_maps.append(score_map.squeeze())
  617. return label_maps, score_maps
  618. class UNet(BaseSegmenter):
  619. def __init__(self,
  620. num_classes=2,
  621. use_mixed_loss=False,
  622. use_deconv=False,
  623. align_corners=False,
  624. **params):
  625. params.update({
  626. 'use_deconv': use_deconv,
  627. 'align_corners': align_corners
  628. })
  629. super(UNet, self).__init__(
  630. model_name='UNet',
  631. num_classes=num_classes,
  632. use_mixed_loss=use_mixed_loss,
  633. **params)
  634. class DeepLabV3P(BaseSegmenter):
  635. def __init__(self,
  636. num_classes=2,
  637. backbone='ResNet50_vd',
  638. use_mixed_loss=False,
  639. output_stride=8,
  640. backbone_indices=(0, 3),
  641. aspp_ratios=(1, 12, 24, 36),
  642. aspp_out_channels=256,
  643. align_corners=False,
  644. **params):
  645. self.backbone_name = backbone
  646. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  647. raise ValueError(
  648. "backbone: {} is not supported. Please choose one of "
  649. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  650. if params.get('with_net', True):
  651. with DisablePrint():
  652. backbone = getattr(paddleseg.models, backbone)(
  653. output_stride=output_stride)
  654. else:
  655. backbone = None
  656. params.update({
  657. 'backbone': backbone,
  658. 'backbone_indices': backbone_indices,
  659. 'aspp_ratios': aspp_ratios,
  660. 'aspp_out_channels': aspp_out_channels,
  661. 'align_corners': align_corners
  662. })
  663. super(DeepLabV3P, self).__init__(
  664. model_name='DeepLabV3P',
  665. num_classes=num_classes,
  666. use_mixed_loss=use_mixed_loss,
  667. **params)
  668. class FastSCNN(BaseSegmenter):
  669. def __init__(self,
  670. num_classes=2,
  671. use_mixed_loss=False,
  672. align_corners=False,
  673. **params):
  674. params.update({'align_corners': align_corners})
  675. super(FastSCNN, self).__init__(
  676. model_name='FastSCNN',
  677. num_classes=num_classes,
  678. use_mixed_loss=use_mixed_loss,
  679. **params)
  680. class HRNet(BaseSegmenter):
  681. def __init__(self,
  682. num_classes=2,
  683. width=48,
  684. use_mixed_loss=False,
  685. align_corners=False,
  686. **params):
  687. if width not in (18, 48):
  688. raise ValueError(
  689. "width={} is not supported, please choose from [18, 48]".
  690. format(width))
  691. self.backbone_name = 'HRNet_W{}'.format(width)
  692. if params.get('with_net', True):
  693. with DisablePrint():
  694. backbone = getattr(paddleseg.models, self.backbone_name)(
  695. align_corners=align_corners)
  696. else:
  697. backbone = None
  698. params.update({'backbone': backbone, 'align_corners': align_corners})
  699. super(HRNet, self).__init__(
  700. model_name='FCN',
  701. num_classes=num_classes,
  702. use_mixed_loss=use_mixed_loss,
  703. **params)
  704. self.model_name = 'HRNet'
  705. class BiSeNetV2(BaseSegmenter):
  706. def __init__(self,
  707. num_classes=2,
  708. use_mixed_loss=False,
  709. align_corners=False,
  710. **params):
  711. params.update({'align_corners': align_corners})
  712. super(BiSeNetV2, self).__init__(
  713. model_name='BiSeNetV2',
  714. num_classes=num_classes,
  715. use_mixed_loss=use_mixed_loss,
  716. **params)