segmenter.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. from collections import OrderedDict
  18. import paddle
  19. import paddle.nn.functional as F
  20. from paddle.static import InputSpec
  21. import paddleseg
  22. import paddlex
  23. from paddlex.cv.transforms import arrange_transforms
  24. from paddlex.utils import get_single_card_bs, DisablePrint
  25. import paddlex.utils.logging as logging
  26. from .base import BaseModel
  27. from .utils import seg_metrics as metrics
  28. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  29. from paddlex.cv.transforms import Decode, Resize
# Public segmentation model classes exported by this module.
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  31. class BaseSegmenter(BaseModel):
  32. def __init__(self,
  33. model_name,
  34. num_classes=2,
  35. use_mixed_loss=False,
  36. **params):
  37. self.init_params = locals()
  38. super(BaseSegmenter, self).__init__('segmenter')
  39. if not hasattr(paddleseg.models, model_name):
  40. raise Exception("ERROR: There's no model named {}.".format(
  41. model_name))
  42. self.model_name = model_name
  43. self.num_classes = num_classes
  44. self.use_mixed_loss = use_mixed_loss
  45. self.losses = None
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. def build_net(self, **params):
  49. # TODO: when using paddle.utils.unique_name.guard,
  50. # DeepLabv3p and HRNet will raise a error
  51. net = paddleseg.models.__dict__[self.model_name](
  52. num_classes=self.num_classes, **params)
  53. return net
  54. def _fix_transforms_shape(self, image_shape):
  55. if hasattr(self, 'test_transforms'):
  56. if self.test_transforms is not None:
  57. has_resize_op = False
  58. resize_op_idx = -1
  59. normalize_op_idx = len(self.test_transforms.transforms)
  60. for idx, op in enumerate(self.test_transforms.transforms):
  61. name = op.__class__.__name__
  62. if name == 'Normalize':
  63. normalize_op_idx = idx
  64. if 'Resize' in name:
  65. has_resize_op = True
  66. resize_op_idx = idx
  67. if not has_resize_op:
  68. self.test_transforms.transforms.insert(
  69. normalize_op_idx, Resize(target_size=image_shape))
  70. else:
  71. self.test_transforms.transforms[resize_op_idx] = Resize(
  72. target_size=image_shape)
  73. def _get_test_inputs(self, image_shape):
  74. if image_shape is not None:
  75. if len(image_shape) == 2:
  76. image_shape = [None, 3] + image_shape
  77. self._fix_transforms_shape(image_shape[-2:])
  78. else:
  79. image_shape = [None, 3, -1, -1]
  80. input_spec = [
  81. InputSpec(
  82. shape=image_shape, name='image', dtype='float32')
  83. ]
  84. return input_spec
  85. def run(self, net, inputs, mode):
  86. net_out = net(inputs[0])
  87. logit = net_out[0]
  88. outputs = OrderedDict()
  89. if mode == 'test':
  90. origin_shape = inputs[1]
  91. score_map = self._postprocess(
  92. logit, origin_shape, transforms=inputs[2])
  93. label_map = paddle.argmax(
  94. score_map, axis=1, keepdim=True, dtype='int32')
  95. score_map = paddle.max(score_map, axis=1, keepdim=True)
  96. score_map = paddle.squeeze(score_map)
  97. label_map = paddle.squeeze(label_map)
  98. outputs = {'label_map': label_map, 'score_map': score_map}
  99. if mode == 'eval':
  100. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  101. label = inputs[1]
  102. origin_shape = [label.shape[-2:]]
  103. # TODO: 替换cv2后postprocess移出run
  104. pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
  105. intersect_area, pred_area, label_area = metrics.calculate_area(
  106. pred, label, self.num_classes)
  107. outputs['intersect_area'] = intersect_area
  108. outputs['pred_area'] = pred_area
  109. outputs['label_area'] = label_area
  110. if mode == 'train':
  111. loss_list = metrics.loss_computation(
  112. logits_list=net_out, labels=inputs[1], losses=self.losses)
  113. loss = sum(loss_list)
  114. outputs['loss'] = loss
  115. return outputs
  116. def default_loss(self):
  117. if isinstance(self.use_mixed_loss, bool):
  118. if self.use_mixed_loss:
  119. losses = [
  120. paddleseg.models.CrossEntropyLoss(),
  121. paddleseg.models.LovaszSoftmaxLoss()
  122. ]
  123. coef = [.8, .2]
  124. loss_type = [
  125. paddleseg.models.MixedLoss(
  126. losses=losses, coef=coef),
  127. ]
  128. else:
  129. loss_type = [paddleseg.models.CrossEntropyLoss()]
  130. else:
  131. losses, coef = list(zip(*self.use_mixed_loss))
  132. if not set(losses).issubset(
  133. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  134. raise ValueError(
  135. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  136. )
  137. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  138. loss_type = [
  139. paddleseg.models.MixedLoss(
  140. losses=losses, coef=list(coef))
  141. ]
  142. if self.model_name == 'FastSCNN':
  143. loss_type *= 2
  144. loss_coef = [1.0, 0.4]
  145. elif self.model_name == 'BiSeNetV2':
  146. loss_type *= 5
  147. loss_coef = [1.0] * 5
  148. else:
  149. loss_coef = [1.0]
  150. losses = {'types': loss_type, 'coef': loss_coef}
  151. return losses
  152. def default_optimizer(self,
  153. parameters,
  154. learning_rate,
  155. num_epochs,
  156. num_steps_each_epoch,
  157. lr_decay_power=0.9):
  158. decay_step = num_epochs * num_steps_each_epoch
  159. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  160. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  161. optimizer = paddle.optimizer.Momentum(
  162. learning_rate=lr_scheduler,
  163. parameters=parameters,
  164. momentum=0.9,
  165. weight_decay=4e-5)
  166. return optimizer
  167. def train(self,
  168. num_epochs,
  169. train_dataset,
  170. train_batch_size=2,
  171. eval_dataset=None,
  172. optimizer=None,
  173. save_interval_epochs=1,
  174. log_interval_steps=2,
  175. save_dir='output',
  176. pretrain_weights='CITYSCAPES',
  177. learning_rate=0.01,
  178. lr_decay_power=0.9,
  179. early_stop=False,
  180. early_stop_patience=5,
  181. use_vdl=True):
  182. """
  183. Train the model.
  184. Args:
  185. num_epochs(int): The number of epochs.
  186. train_dataset(paddlex.dataset): Training dataset.
  187. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  188. eval_dataset(paddlex.dataset, optional):
  189. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  190. optimizer(paddle.optimizer.Optimizer or None, optional):
  191. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  192. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  193. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  194. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  195. pretrain_weights(str or None, optional):
  196. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  197. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  198. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  199. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  200. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  201. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  202. """
  203. self.labels = train_dataset.labels
  204. if self.losses is None:
  205. self.losses = self.default_loss()
  206. if optimizer is None:
  207. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  208. self.optimizer = self.default_optimizer(
  209. self.net.parameters(), learning_rate, num_epochs,
  210. num_steps_each_epoch, lr_decay_power)
  211. else:
  212. self.optimizer = optimizer
  213. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  214. if pretrain_weights not in seg_pretrain_weights_dict[
  215. self.model_name]:
  216. logging.warning(
  217. "Path of pretrain_weights('{}') does not exist!".format(
  218. pretrain_weights))
  219. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  220. "If don't want to use pretrain weights, "
  221. "set pretrain_weights to be None.".format(
  222. seg_pretrain_weights_dict[self.model_name][
  223. 0]))
  224. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  225. 0]
  226. pretrained_dir = osp.join(save_dir, 'pretrain')
  227. self.net_initialize(
  228. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  229. self.train_loop(
  230. num_epochs=num_epochs,
  231. train_dataset=train_dataset,
  232. train_batch_size=train_batch_size,
  233. eval_dataset=eval_dataset,
  234. save_interval_epochs=save_interval_epochs,
  235. log_interval_steps=log_interval_steps,
  236. save_dir=save_dir,
  237. early_stop=early_stop,
  238. early_stop_patience=early_stop_patience,
  239. use_vdl=use_vdl)
  240. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  241. """
  242. Evaluate the model.
  243. Args:
  244. eval_dataset(paddlex.dataset): Evaluation dataset.
  245. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  246. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  247. Returns:
  248. collections.OrderedDict with key-value pairs:
  249. {"miou": `mean intersection over union`,
  250. "category_iou": `category-wise mean intersection over union`,
  251. "oacc": `overall accuracy`,
  252. "category_acc": `category-wise accuracy`,
  253. "kappa": ` kappa coefficient`,
  254. "category_F1-score": `F1 score`}.
  255. """
  256. arrange_transforms(
  257. model_type=self.model_type,
  258. transforms=eval_dataset.transforms,
  259. mode='eval')
  260. self.net.eval()
  261. nranks = paddle.distributed.get_world_size()
  262. local_rank = paddle.distributed.get_rank()
  263. if nranks > 1:
  264. # Initialize parallel environment if not done.
  265. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  266. ):
  267. paddle.distributed.init_parallel_env()
  268. batch_size_each_card = get_single_card_bs(batch_size)
  269. if batch_size_each_card > 1:
  270. batch_size_each_card = 1
  271. batch_size = batch_size_each_card * paddlex.env_info['num']
  272. logging.warning(
  273. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  274. "during evaluation, so batch_size " \
  275. "is forcibly set to {}.".format(batch_size))
  276. self.eval_data_loader = self.build_data_loader(
  277. eval_dataset, batch_size=batch_size, mode='eval')
  278. intersect_area_all = 0
  279. pred_area_all = 0
  280. label_area_all = 0
  281. logging.info(
  282. "Start to evaluate(total_samples={}, total_steps={})...".format(
  283. eval_dataset.num_samples,
  284. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  285. with paddle.no_grad():
  286. for step, data in enumerate(self.eval_data_loader):
  287. data.append(eval_dataset.transforms.transforms)
  288. outputs = self.run(self.net, data, 'eval')
  289. pred_area = outputs['pred_area']
  290. label_area = outputs['label_area']
  291. intersect_area = outputs['intersect_area']
  292. # Gather from all ranks
  293. if nranks > 1:
  294. intersect_area_list = []
  295. pred_area_list = []
  296. label_area_list = []
  297. paddle.distributed.all_gather(intersect_area_list,
  298. intersect_area)
  299. paddle.distributed.all_gather(pred_area_list, pred_area)
  300. paddle.distributed.all_gather(label_area_list, label_area)
  301. # Some image has been evaluated and should be eliminated in last iter
  302. if (step + 1) * nranks > len(eval_dataset):
  303. valid = len(eval_dataset) - step * nranks
  304. intersect_area_list = intersect_area_list[:valid]
  305. pred_area_list = pred_area_list[:valid]
  306. label_area_list = label_area_list[:valid]
  307. for i in range(len(intersect_area_list)):
  308. intersect_area_all = intersect_area_all + intersect_area_list[
  309. i]
  310. pred_area_all = pred_area_all + pred_area_list[i]
  311. label_area_all = label_area_all + label_area_list[i]
  312. else:
  313. intersect_area_all = intersect_area_all + intersect_area
  314. pred_area_all = pred_area_all + pred_area
  315. label_area_all = label_area_all + label_area
  316. class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
  317. label_area_all)
  318. # TODO 确认是按oacc还是macc
  319. class_acc, oacc = metrics.accuracy(intersect_area_all, pred_area_all)
  320. kappa = metrics.kappa(intersect_area_all, pred_area_all,
  321. label_area_all)
  322. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  323. label_area_all)
  324. eval_metrics = OrderedDict(
  325. zip([
  326. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  327. 'category_F1-score'
  328. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  329. return eval_metrics
  330. def predict(self, img_file, transforms=None):
  331. """
  332. Do inference.
  333. Args:
  334. Args:
  335. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  336. Image path or decoded image data in a BGR format, which also could constitute a list,
  337. meaning all images to be predicted as a mini-batch.
  338. transforms(paddlex.transforms.Compose or None, optional):
  339. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  340. Returns:
  341. If img_file is a string or np.array, the result is a dict with key-value pairs:
  342. {"label map": `label map`, "score_map": `score map`}.
  343. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  344. label_map(np.ndarray): the predicted label map
  345. score_map(np.ndarray): the prediction score map
  346. """
  347. if transforms is None and not hasattr(self, 'test_transforms'):
  348. raise Exception("transforms need to be defined, now is None.")
  349. if transforms is None:
  350. transforms = self.test_transforms
  351. if isinstance(img_file, (str, np.ndarray)):
  352. images = [img_file]
  353. else:
  354. images = img_file
  355. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  356. self.model_type)
  357. self.net.eval()
  358. data = (batch_im, batch_origin_shape, transforms.transforms)
  359. outputs = self.run(self.net, data, 'test')
  360. label_map = outputs['label_map']
  361. label_map = label_map.numpy().astype('uint8')
  362. score_map = outputs['score_map']
  363. score_map = score_map.numpy().astype('float32')
  364. return {'label_map': label_map, 'score_map': score_map}
  365. def _preprocess(self, images, transforms, model_type):
  366. arrange_transforms(
  367. model_type=model_type, transforms=transforms, mode='test')
  368. batch_im = list()
  369. batch_ori_shape = list()
  370. for im in images:
  371. sample = {'image': im}
  372. if isinstance(sample['image'], str):
  373. sample = Decode(to_rgb=False)(sample)
  374. ori_shape = sample['image'].shape[:2]
  375. im = transforms(sample)[0]
  376. batch_im.append(im)
  377. batch_ori_shape.append(ori_shape)
  378. batch_im = paddle.to_tensor(batch_im)
  379. return batch_im, batch_ori_shape
  380. @staticmethod
  381. def get_transforms_shape_info(batch_ori_shape, transforms):
  382. batch_restore_list = list()
  383. for ori_shape in batch_ori_shape:
  384. restore_list = list()
  385. h, w = ori_shape[0], ori_shape[1]
  386. for op in transforms:
  387. if op.__class__.__name__ in ['Resize', 'ResizeByShort']:
  388. restore_list.append(('resize', (h, w)))
  389. h, w = op.target_size
  390. if op.__class__.__name__ in ['Padding']:
  391. restore_list.append(('padding', (h, w)))
  392. h, w = op.target_size
  393. batch_restore_list.append(restore_list)
  394. return batch_restore_list
  395. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  396. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  397. batch_origin_shape, transforms)
  398. results = list()
  399. for pred, restore_list in zip(batch_pred, batch_restore_list):
  400. pred = paddle.unsqueeze(pred, axis=0)
  401. for item in restore_list[::-1]:
  402. # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
  403. h, w = item[1][0], item[1][1]
  404. if item[0] == 'resize':
  405. pred = F.interpolate(pred, (h, w), mode='nearest')
  406. elif item[0] == 'padding':
  407. pred = pred[:, :, 0:h, 0:w]
  408. else:
  409. pass
  410. results.append(pred)
  411. batch_pred = paddle.concat(results, axis=0)
  412. return batch_pred
  413. class UNet(BaseSegmenter):
  414. def __init__(self,
  415. num_classes=2,
  416. use_mixed_loss=False,
  417. use_deconv=False,
  418. align_corners=False):
  419. params = {'use_deconv': use_deconv, 'align_corners': align_corners}
  420. super(UNet, self).__init__(
  421. model_name='UNet',
  422. num_classes=num_classes,
  423. use_mixed_loss=use_mixed_loss,
  424. **params)
  425. class DeepLabV3P(BaseSegmenter):
  426. def __init__(self,
  427. num_classes=2,
  428. backbone='ResNet50_vd',
  429. use_mixed_loss=False,
  430. output_stride=8,
  431. backbone_indices=(0, 3),
  432. aspp_ratios=(1, 12, 24, 36),
  433. aspp_out_channels=256,
  434. align_corners=False):
  435. self.backbone_name = backbone
  436. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  437. raise ValueError(
  438. "backbone: {} is not supported. Please choose one of "
  439. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  440. with DisablePrint():
  441. backbone = getattr(paddleseg.models, backbone)(
  442. output_stride=output_stride)
  443. params = {
  444. 'backbone': backbone,
  445. 'backbone_indices': backbone_indices,
  446. 'aspp_ratios': aspp_ratios,
  447. 'aspp_out_channels': aspp_out_channels,
  448. 'align_corners': align_corners
  449. }
  450. super(DeepLabV3P, self).__init__(
  451. model_name='DeepLabV3P',
  452. num_classes=num_classes,
  453. use_mixed_loss=use_mixed_loss,
  454. **params)
  455. class FastSCNN(BaseSegmenter):
  456. def __init__(self,
  457. num_classes=2,
  458. use_mixed_loss=False,
  459. align_corners=False):
  460. params = {'align_corners': align_corners}
  461. super(FastSCNN, self).__init__(
  462. model_name='FastSCNN',
  463. num_classes=num_classes,
  464. use_mixed_loss=use_mixed_loss,
  465. **params)
  466. class HRNet(BaseSegmenter):
  467. def __init__(self,
  468. num_classes=2,
  469. width=48,
  470. use_mixed_loss=False,
  471. align_corners=False):
  472. if width not in (18, 48):
  473. raise ValueError(
  474. "width={} is not supported, please choose from [18, 48]".
  475. format(width))
  476. self.backbone_name = 'HRNet_W{}'.format(width)
  477. with DisablePrint():
  478. backbone = getattr(paddleseg.models, self.backbone_name)(
  479. align_corners=align_corners)
  480. params = {'backbone': backbone, 'align_corners': align_corners}
  481. super(HRNet, self).__init__(
  482. model_name='FCN',
  483. num_classes=num_classes,
  484. use_mixed_loss=use_mixed_loss,
  485. **params)
  486. self.model_name = 'HRNet'
  487. class BiSeNetV2(BaseSegmenter):
  488. def __init__(self,
  489. num_classes=2,
  490. use_mixed_loss=False,
  491. align_corners=False):
  492. params = {'align_corners': align_corners}
  493. super(BiSeNetV2, self).__init__(
  494. model_name='BiSeNetV2',
  495. num_classes=num_classes,
  496. use_mixed_loss=use_mixed_loss,
  497. **params)