seg.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import PIL
  16. from PIL import Image
  17. import copy
  18. import json
  19. from ...utils import logging
  20. from .base import BaseResult
  21. class SegResult(BaseResult):
  22. """Save Result Transform"""
  23. def __init__(self, data):
  24. super().__init__(data)
  25. self.data = data
  26. # We use pillow backend to save both numpy arrays and PIL Image objects
  27. self._img_writer.set_backend("pillow", format_="PNG")
  28. def _get_res_img(self):
  29. """apply"""
  30. seg_map = self.data["pred"]
  31. pc_map = self.get_pseudo_color_map(seg_map[0])
  32. return pc_map
  33. def get_pseudo_color_map(self, pred):
  34. """get_pseudo_color_map"""
  35. if pred.min() < 0 or pred.max() > 255:
  36. raise ValueError("`pred` cannot be cast to uint8.")
  37. pred = pred.astype(np.uint8)
  38. pred_mask = Image.fromarray(pred, mode="P")
  39. color_map = self._get_color_map_list(256)
  40. pred_mask.putpalette(color_map)
  41. return pred_mask
  42. @staticmethod
  43. def _get_color_map_list(num_classes, custom_color=None):
  44. """_get_color_map_list"""
  45. num_classes += 1
  46. color_map = num_classes * [0, 0, 0]
  47. for i in range(0, num_classes):
  48. j = 0
  49. lab = i
  50. while lab:
  51. color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
  52. color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
  53. color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
  54. j += 1
  55. lab >>= 3
  56. color_map = color_map[3:]
  57. if custom_color:
  58. color_map[: len(custom_color)] = custom_color
  59. return color_map
  60. def print(self, json_format=True, indent=4, ensure_ascii=False):
  61. str_ = copy.deepcopy(self)
  62. del str_["pred"]
  63. if json_format:
  64. str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
  65. logging.info(str_)