segmenter.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. from collections import OrderedDict
  18. import paddle
  19. import paddle.nn.functional as F
  20. from paddle.static import InputSpec
  21. import paddlex.paddleseg as paddleseg
  22. import paddlex
  23. from paddlex.cv.transforms import arrange_transforms
  24. from paddlex.utils import get_single_card_bs, DisablePrint
  25. import paddlex.utils.logging as logging
  26. from .base import BaseModel
  27. from .utils import seg_metrics as metrics
  28. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  29. from paddlex.cv.transforms import Decode, Resize
# Segmentation model classes exported by this module.
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  31. class BaseSegmenter(BaseModel):
  32. def __init__(self,
  33. model_name,
  34. num_classes=2,
  35. use_mixed_loss=False,
  36. **params):
  37. self.init_params = locals()
  38. if 'with_net' in self.init_params:
  39. del self.init_params['with_net']
  40. super(BaseSegmenter, self).__init__('segmenter')
  41. if not hasattr(paddleseg.models, model_name):
  42. raise Exception("ERROR: There's no model named {}.".format(
  43. model_name))
  44. self.model_name = model_name
  45. self.num_classes = num_classes
  46. self.use_mixed_loss = use_mixed_loss
  47. self.losses = None
  48. self.labels = None
  49. if params.get('with_net', True):
  50. params.pop('with_net', None)
  51. self.net = self.build_net(**params)
  52. self.find_unused_parameters = True
  53. def build_net(self, **params):
  54. # TODO: when using paddle.utils.unique_name.guard,
  55. # DeepLabv3p and HRNet will raise a error
  56. net = paddleseg.models.__dict__[self.model_name](
  57. num_classes=self.num_classes, **params)
  58. return net
  59. def _fix_transforms_shape(self, image_shape):
  60. if hasattr(self, 'test_transforms'):
  61. if self.test_transforms is not None:
  62. has_resize_op = False
  63. resize_op_idx = -1
  64. normalize_op_idx = len(self.test_transforms.transforms)
  65. for idx, op in enumerate(self.test_transforms.transforms):
  66. name = op.__class__.__name__
  67. if name == 'Normalize':
  68. normalize_op_idx = idx
  69. if 'Resize' in name:
  70. has_resize_op = True
  71. resize_op_idx = idx
  72. if not has_resize_op:
  73. self.test_transforms.transforms.insert(
  74. normalize_op_idx, Resize(target_size=image_shape))
  75. else:
  76. self.test_transforms.transforms[resize_op_idx] = Resize(
  77. target_size=image_shape)
  78. def _get_test_inputs(self, image_shape):
  79. if image_shape is not None:
  80. if len(image_shape) == 2:
  81. image_shape = [1, 3] + image_shape
  82. self._fix_transforms_shape(image_shape[-2:])
  83. else:
  84. image_shape = [None, 3, -1, -1]
  85. self.fixed_input_shape = image_shape
  86. input_spec = [
  87. InputSpec(
  88. shape=image_shape, name='image', dtype='float32')
  89. ]
  90. return input_spec
  91. def run(self, net, inputs, mode):
  92. net_out = net(inputs[0])
  93. logit = net_out[0]
  94. outputs = OrderedDict()
  95. if mode == 'test':
  96. origin_shape = inputs[1]
  97. score_map = self._postprocess(
  98. logit, origin_shape, transforms=inputs[2])
  99. label_map = paddle.argmax(
  100. score_map, axis=1, keepdim=True, dtype='int32')
  101. score_map = paddle.max(score_map, axis=1, keepdim=True)
  102. score_map = paddle.squeeze(score_map)
  103. label_map = paddle.squeeze(label_map)
  104. outputs = {'label_map': label_map, 'score_map': score_map}
  105. if mode == 'eval':
  106. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  107. label = inputs[1]
  108. origin_shape = [label.shape[-2:]]
  109. # TODO: 替换cv2后postprocess移出run
  110. pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
  111. intersect_area, pred_area, label_area = paddleseg.utils.metrics.calculate_area(
  112. pred, label, self.num_classes)
  113. outputs['intersect_area'] = intersect_area
  114. outputs['pred_area'] = pred_area
  115. outputs['label_area'] = label_area
  116. outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
  117. self.num_classes)
  118. if mode == 'train':
  119. loss_list = metrics.loss_computation(
  120. logits_list=net_out, labels=inputs[1], losses=self.losses)
  121. loss = sum(loss_list)
  122. outputs['loss'] = loss
  123. return outputs
  124. def default_loss(self):
  125. if isinstance(self.use_mixed_loss, bool):
  126. if self.use_mixed_loss:
  127. losses = [
  128. paddleseg.models.CrossEntropyLoss(),
  129. paddleseg.models.LovaszSoftmaxLoss()
  130. ]
  131. coef = [.8, .2]
  132. loss_type = [
  133. paddleseg.models.MixedLoss(
  134. losses=losses, coef=coef),
  135. ]
  136. else:
  137. loss_type = [paddleseg.models.CrossEntropyLoss()]
  138. else:
  139. losses, coef = list(zip(*self.use_mixed_loss))
  140. if not set(losses).issubset(
  141. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  142. raise ValueError(
  143. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  144. )
  145. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  146. loss_type = [
  147. paddleseg.models.MixedLoss(
  148. losses=losses, coef=list(coef))
  149. ]
  150. if self.model_name == 'FastSCNN':
  151. loss_type *= 2
  152. loss_coef = [1.0, 0.4]
  153. elif self.model_name == 'BiSeNetV2':
  154. loss_type *= 5
  155. loss_coef = [1.0] * 5
  156. else:
  157. loss_coef = [1.0]
  158. losses = {'types': loss_type, 'coef': loss_coef}
  159. return losses
  160. def default_optimizer(self,
  161. parameters,
  162. learning_rate,
  163. num_epochs,
  164. num_steps_each_epoch,
  165. lr_decay_power=0.9):
  166. decay_step = num_epochs * num_steps_each_epoch
  167. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  168. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  169. optimizer = paddle.optimizer.Momentum(
  170. learning_rate=lr_scheduler,
  171. parameters=parameters,
  172. momentum=0.9,
  173. weight_decay=4e-5)
  174. return optimizer
  175. def train(self,
  176. num_epochs,
  177. train_dataset,
  178. train_batch_size=2,
  179. eval_dataset=None,
  180. optimizer=None,
  181. save_interval_epochs=1,
  182. log_interval_steps=2,
  183. save_dir='output',
  184. pretrain_weights='CITYSCAPES',
  185. learning_rate=0.01,
  186. lr_decay_power=0.9,
  187. early_stop=False,
  188. early_stop_patience=5,
  189. use_vdl=True,
  190. resume_checkpoint=None):
  191. """
  192. Train the model.
  193. Args:
  194. num_epochs(int): The number of epochs.
  195. train_dataset(paddlex.dataset): Training dataset.
  196. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  197. eval_dataset(paddlex.dataset, optional):
  198. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  199. optimizer(paddle.optimizer.Optimizer or None, optional):
  200. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  201. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  202. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  203. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  204. pretrain_weights(str or None, optional):
  205. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'CITYSCAPES'.
  206. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  207. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  208. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  209. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  210. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  211. resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
  212. If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
  213. `pretrain_weights` can be set simultaneously. Defaults to None.
  214. """
  215. if pretrain_weights is not None and resume_checkpoint is not None:
  216. logging.error(
  217. "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
  218. exit=True)
  219. self.labels = train_dataset.labels
  220. if self.losses is None:
  221. self.losses = self.default_loss()
  222. if optimizer is None:
  223. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  224. self.optimizer = self.default_optimizer(
  225. self.net.parameters(), learning_rate, num_epochs,
  226. num_steps_each_epoch, lr_decay_power)
  227. else:
  228. self.optimizer = optimizer
  229. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  230. if pretrain_weights not in seg_pretrain_weights_dict[
  231. self.model_name]:
  232. logging.warning(
  233. "Path of pretrain_weights('{}') does not exist!".format(
  234. pretrain_weights))
  235. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  236. "If don't want to use pretrain weights, "
  237. "set pretrain_weights to be None.".format(
  238. seg_pretrain_weights_dict[self.model_name][
  239. 0]))
  240. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  241. 0]
  242. elif pretrain_weights is not None and osp.exists(pretrain_weights):
  243. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  244. logging.error(
  245. "Invalid pretrain weights. Please specify a '.pdparams' file.",
  246. exit=True)
  247. pretrained_dir = osp.join(save_dir, 'pretrain')
  248. is_backbone_weights = pretrain_weights == 'IMAGENET'
  249. self.net_initialize(
  250. pretrain_weights=pretrain_weights,
  251. save_dir=pretrained_dir,
  252. resume_checkpoint=resume_checkpoint,
  253. is_backbone_weights=is_backbone_weights)
  254. self.train_loop(
  255. num_epochs=num_epochs,
  256. train_dataset=train_dataset,
  257. train_batch_size=train_batch_size,
  258. eval_dataset=eval_dataset,
  259. save_interval_epochs=save_interval_epochs,
  260. log_interval_steps=log_interval_steps,
  261. save_dir=save_dir,
  262. early_stop=early_stop,
  263. early_stop_patience=early_stop_patience,
  264. use_vdl=use_vdl)
  265. def quant_aware_train(self,
  266. num_epochs,
  267. train_dataset,
  268. train_batch_size=2,
  269. eval_dataset=None,
  270. optimizer=None,
  271. save_interval_epochs=1,
  272. log_interval_steps=2,
  273. save_dir='output',
  274. learning_rate=0.0001,
  275. lr_decay_power=0.9,
  276. early_stop=False,
  277. early_stop_patience=5,
  278. use_vdl=True,
  279. resume_checkpoint=None,
  280. quant_config=None):
  281. """
  282. Quantization-aware training.
  283. Args:
  284. num_epochs(int): The number of epochs.
  285. train_dataset(paddlex.dataset): Training dataset.
  286. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  287. eval_dataset(paddlex.dataset, optional):
  288. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  289. optimizer(paddle.optimizer.Optimizer or None, optional):
  290. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  291. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  292. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  293. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  294. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  295. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  296. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  297. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  298. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  299. quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
  300. configuration will be used. Defaults to None.
  301. resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware training
  302. from. If None, no training checkpoint will be resumed. Defaults to None.
  303. """
  304. self._prepare_qat(quant_config)
  305. self.train(
  306. num_epochs=num_epochs,
  307. train_dataset=train_dataset,
  308. train_batch_size=train_batch_size,
  309. eval_dataset=eval_dataset,
  310. optimizer=optimizer,
  311. save_interval_epochs=save_interval_epochs,
  312. log_interval_steps=log_interval_steps,
  313. save_dir=save_dir,
  314. pretrain_weights=None,
  315. learning_rate=learning_rate,
  316. lr_decay_power=lr_decay_power,
  317. early_stop=early_stop,
  318. early_stop_patience=early_stop_patience,
  319. use_vdl=use_vdl,
  320. resume_checkpoint=resume_checkpoint)
  321. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  322. """
  323. Evaluate the model.
  324. Args:
  325. eval_dataset(paddlex.dataset): Evaluation dataset.
  326. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  327. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  328. Returns:
  329. collections.OrderedDict with key-value pairs:
  330. {"miou": `mean intersection over union`,
  331. "category_iou": `category-wise mean intersection over union`,
  332. "oacc": `overall accuracy`,
  333. "category_acc": `category-wise accuracy`,
  334. "kappa": ` kappa coefficient`,
  335. "category_F1-score": `F1 score`}.
  336. """
  337. arrange_transforms(
  338. model_type=self.model_type,
  339. transforms=eval_dataset.transforms,
  340. mode='eval')
  341. self.net.eval()
  342. nranks = paddle.distributed.get_world_size()
  343. local_rank = paddle.distributed.get_rank()
  344. if nranks > 1:
  345. # Initialize parallel environment if not done.
  346. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  347. ):
  348. paddle.distributed.init_parallel_env()
  349. batch_size_each_card = get_single_card_bs(batch_size)
  350. if batch_size_each_card > 1:
  351. batch_size_each_card = 1
  352. batch_size = batch_size_each_card * paddlex.env_info['num']
  353. logging.warning(
  354. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  355. "during evaluation, so batch_size " \
  356. "is forcibly set to {}.".format(batch_size))
  357. self.eval_data_loader = self.build_data_loader(
  358. eval_dataset, batch_size=batch_size, mode='eval')
  359. intersect_area_all = 0
  360. pred_area_all = 0
  361. label_area_all = 0
  362. conf_mat_all = []
  363. logging.info(
  364. "Start to evaluate(total_samples={}, total_steps={})...".format(
  365. eval_dataset.num_samples,
  366. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  367. with paddle.no_grad():
  368. for step, data in enumerate(self.eval_data_loader):
  369. data.append(eval_dataset.transforms.transforms)
  370. outputs = self.run(self.net, data, 'eval')
  371. pred_area = outputs['pred_area']
  372. label_area = outputs['label_area']
  373. intersect_area = outputs['intersect_area']
  374. conf_mat = outputs['conf_mat']
  375. # Gather from all ranks
  376. if nranks > 1:
  377. intersect_area_list = []
  378. pred_area_list = []
  379. label_area_list = []
  380. conf_mat_list = []
  381. paddle.distributed.all_gather(intersect_area_list,
  382. intersect_area)
  383. paddle.distributed.all_gather(pred_area_list, pred_area)
  384. paddle.distributed.all_gather(label_area_list, label_area)
  385. paddle.distributed.all_gather(conf_mat_list, conf_mat)
  386. # Some image has been evaluated and should be eliminated in last iter
  387. if (step + 1) * nranks > len(eval_dataset):
  388. valid = len(eval_dataset) - step * nranks
  389. intersect_area_list = intersect_area_list[:valid]
  390. pred_area_list = pred_area_list[:valid]
  391. label_area_list = label_area_list[:valid]
  392. conf_mat_list = conf_mat_list[:valid]
  393. intersect_area_all += sum(intersect_area_list)
  394. pred_area_all += sum(pred_area_list)
  395. label_area_all += sum(label_area_list)
  396. conf_mat_all.extend(conf_mat_list)
  397. else:
  398. intersect_area_all = intersect_area_all + intersect_area
  399. pred_area_all = pred_area_all + pred_area
  400. label_area_all = label_area_all + label_area
  401. conf_mat_all.append(conf_mat)
  402. class_iou, miou = paddleseg.utils.metrics.mean_iou(
  403. intersect_area_all, pred_area_all, label_area_all)
  404. # TODO 确认是按oacc还是macc
  405. class_acc, oacc = paddleseg.utils.metrics.accuracy(intersect_area_all,
  406. pred_area_all)
  407. kappa = paddleseg.utils.metrics.kappa(intersect_area_all,
  408. pred_area_all, label_area_all)
  409. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  410. label_area_all)
  411. eval_metrics = OrderedDict(
  412. zip([
  413. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  414. 'category_F1-score'
  415. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  416. if return_details:
  417. conf_mat = sum(conf_mat_all)
  418. eval_details = {'confusion_matrix': conf_mat.tolist()}
  419. return eval_metrics, eval_details
  420. return eval_metrics
  421. def predict(self, img_file, transforms=None):
  422. """
  423. Do inference.
  424. Args:
  425. Args:
  426. img_file(List[np.ndarray or str], str or np.ndarray):
  427. Image path or decoded image data in a BGR format, which also could constitute a list,
  428. meaning all images to be predicted as a mini-batch.
  429. transforms(paddlex.transforms.Compose or None, optional):
  430. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  431. Returns:
  432. If img_file is a string or np.array, the result is a dict with key-value pairs:
  433. {"label map": `label map`, "score_map": `score map`}.
  434. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  435. label_map(np.ndarray): the predicted label map
  436. score_map(np.ndarray): the prediction score map
  437. """
  438. if transforms is None and not hasattr(self, 'test_transforms'):
  439. raise Exception("transforms need to be defined, now is None.")
  440. if transforms is None:
  441. transforms = self.test_transforms
  442. if isinstance(img_file, (str, np.ndarray)):
  443. images = [img_file]
  444. else:
  445. images = img_file
  446. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  447. self.model_type)
  448. self.net.eval()
  449. data = (batch_im, batch_origin_shape, transforms.transforms)
  450. outputs = self.run(self.net, data, 'test')
  451. label_map = outputs['label_map']
  452. label_map = label_map.numpy().astype('uint8')
  453. score_map = outputs['score_map']
  454. score_map = score_map.numpy().astype('float32')
  455. if isinstance(img_file, list) and len(img_file) > 1:
  456. prediction = [{
  457. 'label_map': l,
  458. 'score_map': s
  459. } for l, s in zip(label_map, score_map)]
  460. elif isinstance(img_file, list):
  461. prediction = [{'label_map': label_map, 'score_map': score_map}]
  462. else:
  463. prediction = {'label_map': label_map, 'score_map': score_map}
  464. return prediction
  465. def _preprocess(self, images, transforms, to_tensor=True):
  466. arrange_transforms(
  467. model_type=self.model_type, transforms=transforms, mode='test')
  468. batch_im = list()
  469. batch_ori_shape = list()
  470. for im in images:
  471. sample = {'image': im}
  472. if isinstance(sample['image'], str):
  473. sample = Decode(to_rgb=False)(sample)
  474. ori_shape = sample['image'].shape[:2]
  475. im = transforms(sample)[0]
  476. batch_im.append(im)
  477. batch_ori_shape.append(ori_shape)
  478. if to_tensor:
  479. batch_im = paddle.to_tensor(batch_im)
  480. return batch_im, batch_ori_shape
  481. @staticmethod
  482. def get_transforms_shape_info(batch_ori_shape, transforms):
  483. batch_restore_list = list()
  484. for ori_shape in batch_ori_shape:
  485. restore_list = list()
  486. h, w = ori_shape[0], ori_shape[1]
  487. for op in transforms:
  488. if op.__class__.__name__ == 'Resize':
  489. restore_list.append(('resize', (h, w)))
  490. h, w = op.target_size
  491. elif op.__class__.__name__ == 'ResizeByShort':
  492. restore_list.append(('resize', (h, w)))
  493. im_short_size = min(h, w)
  494. im_long_size = max(h, w)
  495. scale = float(op.short_size) / float(im_short_size)
  496. if 0 < op.max_size < np.round(scale * im_long_size):
  497. scale = float(op.max_size) / float(im_long_size)
  498. h = int(round(h * scale))
  499. w = int(round(w * scale))
  500. elif op.__class__.__name__ == 'ResizeByLong':
  501. restore_list.append(('resize', (h, w)))
  502. im_long_size = max(h, w)
  503. scale = float(op.long_size) / float(im_long_size)
  504. h = int(round(h * scale))
  505. w = int(round(w * scale))
  506. elif op.__class__.__name__ == 'Padding':
  507. if op.target_size:
  508. target_h, target_w = op.target_size
  509. else:
  510. target_h = int(
  511. (np.ceil(h / op.size_divisor) * op.size_divisor))
  512. target_w = int(
  513. (np.ceil(w / op.size_divisor) * op.size_divisor))
  514. if op.pad_mode == -1:
  515. offsets = op.offsets
  516. elif op.pad_mode == 0:
  517. offsets = [0, 0]
  518. elif op.pad_mode == 1:
  519. offsets = [(target_h - h) // 2, (target_w - w) // 2]
  520. else:
  521. offsets = [target_h - h, target_w - w]
  522. restore_list.append(('padding', (h, w), offsets))
  523. h, w = target_h, target_w
  524. batch_restore_list.append(restore_list)
  525. return batch_restore_list
  526. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  527. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  528. batch_origin_shape, transforms)
  529. results = list()
  530. for pred, restore_list in zip(batch_pred, batch_restore_list):
  531. pred = paddle.unsqueeze(pred, axis=0)
  532. for item in restore_list[::-1]:
  533. # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
  534. h, w = item[1][0], item[1][1]
  535. if item[0] == 'resize':
  536. pred = F.interpolate(pred, (h, w), mode='nearest')
  537. elif item[0] == 'padding':
  538. x, y = item[2]
  539. pred = pred[:, :, y:y + h, x:x + w]
  540. else:
  541. pass
  542. results.append(pred)
  543. batch_pred = paddle.concat(results, axis=0)
  544. return batch_pred
  545. class UNet(BaseSegmenter):
  546. def __init__(self,
  547. num_classes=2,
  548. use_mixed_loss=False,
  549. use_deconv=False,
  550. align_corners=False,
  551. **params):
  552. params.update({
  553. 'use_deconv': use_deconv,
  554. 'align_corners': align_corners
  555. })
  556. super(UNet, self).__init__(
  557. model_name='UNet',
  558. num_classes=num_classes,
  559. use_mixed_loss=use_mixed_loss,
  560. **params)
  561. class DeepLabV3P(BaseSegmenter):
  562. def __init__(self,
  563. num_classes=2,
  564. backbone='ResNet50_vd',
  565. use_mixed_loss=False,
  566. output_stride=8,
  567. backbone_indices=(0, 3),
  568. aspp_ratios=(1, 12, 24, 36),
  569. aspp_out_channels=256,
  570. align_corners=False,
  571. **params):
  572. self.backbone_name = backbone
  573. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  574. raise ValueError(
  575. "backbone: {} is not supported. Please choose one of "
  576. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  577. if params.get('with_net', True):
  578. with DisablePrint():
  579. backbone = getattr(paddleseg.models, backbone)(
  580. output_stride=output_stride)
  581. else:
  582. backbone = None
  583. params.update({
  584. 'backbone': backbone,
  585. 'backbone_indices': backbone_indices,
  586. 'aspp_ratios': aspp_ratios,
  587. 'aspp_out_channels': aspp_out_channels,
  588. 'align_corners': align_corners
  589. })
  590. super(DeepLabV3P, self).__init__(
  591. model_name='DeepLabV3P',
  592. num_classes=num_classes,
  593. use_mixed_loss=use_mixed_loss,
  594. **params)
  595. class FastSCNN(BaseSegmenter):
  596. def __init__(self,
  597. num_classes=2,
  598. use_mixed_loss=False,
  599. align_corners=False,
  600. **params):
  601. params.update({'align_corners': align_corners})
  602. super(FastSCNN, self).__init__(
  603. model_name='FastSCNN',
  604. num_classes=num_classes,
  605. use_mixed_loss=use_mixed_loss,
  606. **params)
  607. class HRNet(BaseSegmenter):
  608. def __init__(self,
  609. num_classes=2,
  610. width=48,
  611. use_mixed_loss=False,
  612. align_corners=False,
  613. **params):
  614. if width not in (18, 48):
  615. raise ValueError(
  616. "width={} is not supported, please choose from [18, 48]".
  617. format(width))
  618. self.backbone_name = 'HRNet_W{}'.format(width)
  619. if params.get('with_net', True):
  620. with DisablePrint():
  621. backbone = getattr(paddleseg.models, self.backbone_name)(
  622. align_corners=align_corners)
  623. else:
  624. backbone = None
  625. params.update({'backbone': backbone, 'align_corners': align_corners})
  626. super(HRNet, self).__init__(
  627. model_name='FCN',
  628. num_classes=num_classes,
  629. use_mixed_loss=use_mixed_loss,
  630. **params)
  631. self.model_name = 'HRNet'
  632. class BiSeNetV2(BaseSegmenter):
  633. def __init__(self,
  634. num_classes=2,
  635. use_mixed_loss=False,
  636. align_corners=False,
  637. **params):
  638. params.update({'align_corners': align_corners})
  639. super(BiSeNetV2, self).__init__(
  640. model_name='BiSeNetV2',
  641. num_classes=num_classes,
  642. use_mixed_loss=use_mixed_loss,
  643. **params)