segmenter.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. import numpy as np
  17. from collections import OrderedDict
  18. import paddle
  19. import paddle.nn.functional as F
  20. from paddle.static import InputSpec
  21. import paddleseg
  22. import paddlex
  23. from paddlex.cv.transforms import arrange_transforms
  24. from paddlex.utils import get_single_card_bs, DisablePrint
  25. import paddlex.utils.logging as logging
  26. from .base import BaseModel
  27. from .utils import seg_metrics as metrics
  28. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
  29. from paddlex.cv.transforms import Decode
  30. __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2"]
  31. class BaseSegmenter(BaseModel):
  32. def __init__(self,
  33. model_name,
  34. num_classes=2,
  35. use_mixed_loss=False,
  36. **params):
  37. self.init_params = locals()
  38. super(BaseSegmenter, self).__init__('segmenter')
  39. if not hasattr(paddleseg.models, model_name):
  40. raise Exception("ERROR: There's no model named {}.".format(
  41. model_name))
  42. self.model_name = model_name
  43. self.num_classes = num_classes
  44. self.use_mixed_loss = use_mixed_loss
  45. self.losses = None
  46. self.labels = None
  47. self.net = self.build_net(**params)
  48. self.find_unused_parameters = True
  49. def build_net(self, **params):
  50. # TODO: when using paddle.utils.unique_name.guard,
  51. # DeepLabv3p and HRNet will raise a error
  52. net = paddleseg.models.__dict__[self.model_name](
  53. num_classes=self.num_classes, **params)
  54. return net
  55. def get_test_inputs(self, image_shape):
  56. input_spec = [
  57. InputSpec(
  58. shape=[None, 3] + image_shape, name='image', dtype='float32')
  59. ]
  60. return input_spec
  61. def run(self, net, inputs, mode):
  62. net_out = net(inputs[0])
  63. logit = net_out[0]
  64. outputs = OrderedDict()
  65. if mode == 'test':
  66. origin_shape = inputs[1]
  67. score_map = self._postprocess(
  68. logit, origin_shape, transforms=inputs[2])
  69. label_map = paddle.argmax(
  70. score_map, axis=1, keepdim=True, dtype='int32')
  71. score_map = paddle.max(score_map, axis=1, keepdim=True)
  72. score_map = paddle.squeeze(score_map)
  73. label_map = paddle.squeeze(label_map)
  74. outputs = {'label_map': label_map, 'score_map': score_map}
  75. if mode == 'eval':
  76. pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
  77. label = inputs[1]
  78. origin_shape = [label.shape[-2:]]
  79. # TODO: 替换cv2后postprocess移出run
  80. pred = self._postprocess(pred, origin_shape, transforms=inputs[2])
  81. intersect_area, pred_area, label_area = metrics.calculate_area(
  82. pred, label, self.num_classes)
  83. outputs['intersect_area'] = intersect_area
  84. outputs['pred_area'] = pred_area
  85. outputs['label_area'] = label_area
  86. if mode == 'train':
  87. loss_list = metrics.loss_computation(
  88. logits_list=net_out, labels=inputs[1], losses=self.losses)
  89. loss = sum(loss_list)
  90. outputs['loss'] = loss
  91. return outputs
  92. def default_loss(self):
  93. if isinstance(self.use_mixed_loss, bool):
  94. if self.use_mixed_loss:
  95. losses = [
  96. paddleseg.models.CrossEntropyLoss(),
  97. paddleseg.models.LovaszSoftmaxLoss()
  98. ]
  99. coef = [.8, .2]
  100. loss_type = [
  101. paddleseg.models.MixedLoss(
  102. losses=losses, coef=coef),
  103. ]
  104. else:
  105. loss_type = [paddleseg.models.CrossEntropyLoss()]
  106. else:
  107. losses, coef = list(zip(*self.use_mixed_loss))
  108. if not set(losses).issubset(
  109. ['CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss']):
  110. raise ValueError(
  111. "Only 'CrossEntropyLoss', 'DiceLoss', 'LovaszSoftmaxLoss' are supported."
  112. )
  113. losses = [getattr(paddleseg.models, loss)() for loss in losses]
  114. loss_type = [
  115. paddleseg.models.MixedLoss(
  116. losses=losses, coef=list(coef))
  117. ]
  118. if self.model_name == 'FastSCNN':
  119. loss_type *= 2
  120. loss_coef = [1.0, 0.4]
  121. elif self.model_name == 'BiSeNetV2':
  122. loss_type *= 5
  123. loss_coef = [1.0] * 5
  124. else:
  125. loss_coef = [1.0]
  126. losses = {'types': loss_type, 'coef': loss_coef}
  127. return losses
  128. def default_optimizer(self,
  129. parameters,
  130. learning_rate,
  131. num_epochs,
  132. num_steps_each_epoch,
  133. lr_decay_power=0.9):
  134. decay_step = num_epochs * num_steps_each_epoch
  135. lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
  136. learning_rate, decay_step, end_lr=0, power=lr_decay_power)
  137. optimizer = paddle.optimizer.Momentum(
  138. learning_rate=lr_scheduler,
  139. parameters=parameters,
  140. momentum=0.9,
  141. weight_decay=4e-5)
  142. return optimizer
  143. def train(self,
  144. num_epochs,
  145. train_dataset,
  146. train_batch_size=2,
  147. eval_dataset=None,
  148. optimizer=None,
  149. save_interval_epochs=1,
  150. log_interval_steps=2,
  151. save_dir='output',
  152. pretrain_weights='CITYSCAPES',
  153. learning_rate=0.01,
  154. lr_decay_power=0.9,
  155. early_stop=False,
  156. early_stop_patience=5,
  157. use_vdl=True):
  158. """
  159. Train the model.
  160. Args:
  161. num_epochs(int): The number of epochs.
  162. train_dataset(paddlex.dataset): Training dataset.
  163. train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 2.
  164. eval_dataset(paddlex.dataset, optional):
  165. Evaluation dataset. If None, the model will not be evaluated furing training process. Defaults to None.
  166. optimizer(paddle.optimizer.Optimizer or None, optional):
  167. Optimizer used in training. If None, a default optimizer is used. Defaults to None.
  168. save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
  169. log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
  170. save_dir(str, optional): Directory to save the model. Defaults to 'output'.
  171. pretrain_weights(str or None, optional):
  172. None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
  173. learning_rate(float, optional): Learning rate for training. Defaults to .025.
  174. lr_decay_power(float, optional): Learning decay power. Defaults to .9.
  175. early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
  176. early_stop_patience(int, optional): Early stop patience. Defaults to 5.
  177. use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
  178. """
  179. self.labels = train_dataset.labels
  180. if self.losses is None:
  181. self.losses = self.default_loss()
  182. if optimizer is None:
  183. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  184. self.optimizer = self.default_optimizer(
  185. self.net.parameters(), learning_rate, num_epochs,
  186. num_steps_each_epoch, lr_decay_power)
  187. else:
  188. self.optimizer = optimizer
  189. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  190. if pretrain_weights not in seg_pretrain_weights_dict[
  191. self.model_name]:
  192. logging.warning(
  193. "Path of pretrain_weights('{}') does not exist!".format(
  194. pretrain_weights))
  195. logging.warning("Pretrain_weights is forcibly set to '{}'. "
  196. "If don't want to use pretrain weights, "
  197. "set pretrain_weights to be None.".format(
  198. seg_pretrain_weights_dict[self.model_name][
  199. 0]))
  200. pretrain_weights = seg_pretrain_weights_dict[self.model_name][
  201. 0]
  202. pretrained_dir = osp.join(save_dir, 'pretrain')
  203. self.net_initialize(
  204. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  205. self.train_loop(
  206. num_epochs=num_epochs,
  207. train_dataset=train_dataset,
  208. train_batch_size=train_batch_size,
  209. eval_dataset=eval_dataset,
  210. save_interval_epochs=save_interval_epochs,
  211. log_interval_steps=log_interval_steps,
  212. save_dir=save_dir,
  213. early_stop=early_stop,
  214. early_stop_patience=early_stop_patience,
  215. use_vdl=use_vdl)
  216. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  217. """
  218. Evaluate the model.
  219. Args:
  220. eval_dataset(paddlex.dataset): Evaluation dataset.
  221. batch_size(int, optional): Total batch size among all cards used for evaluation. Defaults to 1.
  222. return_details(bool, optional): Whether to return evaluation details. Defaults to False.
  223. Returns:
  224. collections.OrderedDict with key-value pairs:
  225. {"miou": `mean intersection over union`,
  226. "category_iou": `category-wise mean intersection over union`,
  227. "oacc": `overall accuracy`,
  228. "category_acc": `category-wise accuracy`,
  229. "kappa": ` kappa coefficient`,
  230. "category_F1-score": `F1 score`}.
  231. """
  232. arrange_transforms(
  233. model_type=self.model_type,
  234. transforms=eval_dataset.transforms,
  235. mode='eval')
  236. self.net.eval()
  237. nranks = paddle.distributed.get_world_size()
  238. local_rank = paddle.distributed.get_rank()
  239. if nranks > 1:
  240. # Initialize parallel environment if not done.
  241. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  242. ):
  243. paddle.distributed.init_parallel_env()
  244. batch_size_each_card = get_single_card_bs(batch_size)
  245. if batch_size_each_card > 1:
  246. batch_size_each_card = 1
  247. batch_size = batch_size_each_card * paddlex.env_info['num']
  248. logging.warning(
  249. "Segmenter only supports batch_size=1 for each gpu/cpu card " \
  250. "during evaluation, so batch_size " \
  251. "is forcibly set to {}.".format(batch_size))
  252. self.eval_data_loader = self.build_data_loader(
  253. eval_dataset, batch_size=batch_size, mode='eval')
  254. intersect_area_all = 0
  255. pred_area_all = 0
  256. label_area_all = 0
  257. logging.info(
  258. "Start to evaluate(total_samples={}, total_steps={})...".format(
  259. eval_dataset.num_samples,
  260. math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
  261. with paddle.no_grad():
  262. for step, data in enumerate(self.eval_data_loader):
  263. data.append(eval_dataset.transforms.transforms)
  264. outputs = self.run(self.net, data, 'eval')
  265. pred_area = outputs['pred_area']
  266. label_area = outputs['label_area']
  267. intersect_area = outputs['intersect_area']
  268. # Gather from all ranks
  269. if nranks > 1:
  270. intersect_area_list = []
  271. pred_area_list = []
  272. label_area_list = []
  273. paddle.distributed.all_gather(intersect_area_list,
  274. intersect_area)
  275. paddle.distributed.all_gather(pred_area_list, pred_area)
  276. paddle.distributed.all_gather(label_area_list, label_area)
  277. # Some image has been evaluated and should be eliminated in last iter
  278. if (step + 1) * nranks > len(eval_dataset):
  279. valid = len(eval_dataset) - step * nranks
  280. intersect_area_list = intersect_area_list[:valid]
  281. pred_area_list = pred_area_list[:valid]
  282. label_area_list = label_area_list[:valid]
  283. for i in range(len(intersect_area_list)):
  284. intersect_area_all = intersect_area_all + intersect_area_list[
  285. i]
  286. pred_area_all = pred_area_all + pred_area_list[i]
  287. label_area_all = label_area_all + label_area_list[i]
  288. else:
  289. intersect_area_all = intersect_area_all + intersect_area
  290. pred_area_all = pred_area_all + pred_area
  291. label_area_all = label_area_all + label_area
  292. class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
  293. label_area_all)
  294. # TODO 确认是按oacc还是macc
  295. class_acc, oacc = metrics.accuracy(intersect_area_all, pred_area_all)
  296. kappa = metrics.kappa(intersect_area_all, pred_area_all,
  297. label_area_all)
  298. category_f1score = metrics.f1_score(intersect_area_all, pred_area_all,
  299. label_area_all)
  300. eval_metrics = OrderedDict(
  301. zip([
  302. 'miou', 'category_iou', 'oacc', 'category_acc', 'kappa',
  303. 'category_F1-score'
  304. ], [miou, class_iou, oacc, class_acc, kappa, category_f1score]))
  305. return eval_metrics
  306. def predict(self, img_file, transforms=None):
  307. """
  308. Do inference.
  309. Args:
  310. Args:
  311. img_file(List[np.ndarray or str], str or np.ndarray): img_file(list or str or np.array):
  312. Image path or decoded image data in a BGR format, which also could constitute a list,
  313. meaning all images to be predicted as a mini-batch.
  314. transforms(paddlex.transforms.Compose or None, optional):
  315. Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
  316. Returns:
  317. If img_file is a string or np.array, the result is a dict with key-value pairs:
  318. {"label map": `label map`, "score_map": `score map`}.
  319. If img_file is a list, the result is a list composed of dicts with the corresponding fields:
  320. label_map(np.ndarray): the predicted label map
  321. score_map(np.ndarray): the prediction score map
  322. """
  323. if transforms is None and not hasattr(self, 'test_transforms'):
  324. raise Exception("transforms need to be defined, now is None.")
  325. if transforms is None:
  326. transforms = self.test_transforms
  327. if isinstance(img_file, (str, np.ndarray)):
  328. images = [img_file]
  329. else:
  330. images = img_file
  331. batch_im, batch_origin_shape = self._preprocess(images, transforms,
  332. self.model_type)
  333. self.net.eval()
  334. data = (batch_im, batch_origin_shape, transforms.transforms)
  335. outputs = self.run(self.net, data, 'test')
  336. label_map = outputs['label_map']
  337. label_map = label_map.numpy().astype('uint8')
  338. score_map = outputs['score_map']
  339. score_map = score_map.numpy().astype('float32')
  340. return {'label_map': label_map, 'score_map': score_map}
  341. def _preprocess(self, images, transforms, model_type):
  342. arrange_transforms(
  343. model_type=model_type, transforms=transforms, mode='test')
  344. batch_im = list()
  345. batch_ori_shape = list()
  346. for im in images:
  347. sample = {'image': im}
  348. if isinstance(sample['image'], str):
  349. sample = Decode(to_rgb=False)(sample)
  350. ori_shape = sample['image'].shape[:2]
  351. im = transforms(sample)[0]
  352. batch_im.append(im)
  353. batch_ori_shape.append(ori_shape)
  354. batch_im = paddle.to_tensor(batch_im)
  355. return batch_im, batch_ori_shape
  356. @staticmethod
  357. def get_transforms_shape_info(batch_ori_shape, transforms):
  358. batch_restore_list = list()
  359. for ori_shape in batch_ori_shape:
  360. restore_list = list()
  361. h, w = ori_shape[0], ori_shape[1]
  362. for op in transforms:
  363. if op.__class__.__name__ in ['Resize', 'ResizeByShort']:
  364. restore_list.append(('resize', (h, w)))
  365. h, w = op.target_size
  366. if op.__class__.__name__ in ['Padding']:
  367. restore_list.append(('padding', (h, w)))
  368. h, w = op.target_size
  369. batch_restore_list.append(restore_list)
  370. return batch_restore_list
  371. def _postprocess(self, batch_pred, batch_origin_shape, transforms):
  372. batch_restore_list = BaseSegmenter.get_transforms_shape_info(
  373. batch_origin_shape, transforms)
  374. results = list()
  375. for pred, restore_list in zip(batch_pred, batch_restore_list):
  376. pred = paddle.unsqueeze(pred, axis=0)
  377. for item in restore_list[::-1]:
  378. # TODO: 替换成cv2的interpolate(部署阶段无法使用paddle op)
  379. h, w = item[1][0], item[1][1]
  380. if item[0] == 'resize':
  381. pred = F.interpolate(pred, (h, w), mode='nearest')
  382. elif item[0] == 'padding':
  383. pred = pred[:, :, 0:h, 0:w]
  384. else:
  385. pass
  386. results.append(pred)
  387. batch_pred = paddle.concat(results, axis=0)
  388. return batch_pred
  389. class UNet(BaseSegmenter):
  390. def __init__(self,
  391. num_classes=2,
  392. use_mixed_loss=False,
  393. use_deconv=False,
  394. align_corners=False):
  395. params = {'use_deconv': use_deconv, 'align_corners': align_corners}
  396. super(UNet, self).__init__(
  397. model_name='UNet',
  398. num_classes=num_classes,
  399. use_mixed_loss=use_mixed_loss,
  400. **params)
  401. class DeepLabV3P(BaseSegmenter):
  402. def __init__(self,
  403. num_classes=2,
  404. backbone='ResNet50_vd',
  405. use_mixed_loss=False,
  406. output_stride=8,
  407. backbone_indices=(0, 3),
  408. aspp_ratios=(1, 12, 24, 36),
  409. aspp_out_channels=256,
  410. align_corners=False):
  411. self.backbone_name = backbone
  412. if backbone not in ['ResNet50_vd', 'ResNet101_vd']:
  413. raise ValueError(
  414. "backbone: {} is not supported. Please choose one of "
  415. "('ResNet50_vd', 'ResNet101_vd')".format(backbone))
  416. with DisablePrint():
  417. backbone = getattr(paddleseg.models, backbone)(
  418. output_stride=output_stride)
  419. params = {
  420. 'backbone': backbone,
  421. 'backbone_indices': backbone_indices,
  422. 'aspp_ratios': aspp_ratios,
  423. 'aspp_out_channels': aspp_out_channels,
  424. 'align_corners': align_corners
  425. }
  426. super(DeepLabV3P, self).__init__(
  427. model_name='DeepLabV3P',
  428. num_classes=num_classes,
  429. use_mixed_loss=use_mixed_loss,
  430. **params)
  431. class FastSCNN(BaseSegmenter):
  432. def __init__(self,
  433. num_classes=2,
  434. use_mixed_loss=False,
  435. align_corners=False):
  436. params = {'align_corners': align_corners}
  437. super(FastSCNN, self).__init__(
  438. model_name='FastSCNN',
  439. num_classes=num_classes,
  440. use_mixed_loss=use_mixed_loss,
  441. **params)
  442. class HRNet(BaseSegmenter):
  443. def __init__(self,
  444. num_classes=2,
  445. width=48,
  446. use_mixed_loss=False,
  447. align_corners=False):
  448. if width not in (18, 48):
  449. raise ValueError(
  450. "width={} is not supported, please choose from [18, 48]".
  451. format(width))
  452. self.backbone_name = 'HRNet_W{}'.format(width)
  453. with DisablePrint():
  454. backbone = getattr(paddleseg.models, self.backbone_name)(
  455. align_corners=align_corners)
  456. params = {'backbone': backbone, 'align_corners': align_corners}
  457. super(HRNet, self).__init__(
  458. model_name='FCN',
  459. num_classes=num_classes,
  460. use_mixed_loss=use_mixed_loss,
  461. **params)
  462. self.model_name = 'HRNet'
  463. class BiSeNetV2(BaseSegmenter):
  464. def __init__(self,
  465. num_classes=2,
  466. use_mixed_loss=False,
  467. align_corners=False):
  468. params = {'align_corners': align_corners}
  469. super(BiSeNetV2, self).__init__(
  470. model_name='BiSeNetV2',
  471. num_classes=num_classes,
  472. use_mixed_loss=use_mixed_loss,
  473. **params)