sam_result.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import random
  16. import numpy as np
  17. from PIL import Image
  18. from .....utils.deps import function_requires_deps, is_dep_available
  19. from ....common.result import BaseCVResult, JsonMixin
  20. from ....utils.color_map import get_colormap
  21. if is_dep_available("opencv-contrib-python"):
  22. import cv2
  23. @function_requires_deps("opencv-contrib-python")
  24. def draw_segm(im, masks, mask_info, alpha=0.7):
  25. """
  26. Draw segmentation on image
  27. """
  28. w_ratio = 0.4
  29. color_list = get_colormap(rgb=True)
  30. im = np.array(im).astype("float32")
  31. clsid2color = {}
  32. masks = np.array(masks)
  33. masks = masks.astype(np.uint8)
  34. for i in range(masks.shape[0]):
  35. mask = masks[i]
  36. clsid = random.randint(0, len(get_colormap(rgb=True)) - 1)
  37. if clsid not in clsid2color:
  38. color_index = i % len(color_list)
  39. clsid2color[clsid] = color_list[color_index]
  40. color_mask = clsid2color[clsid]
  41. for c in range(3):
  42. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  43. idx = np.nonzero(mask)
  44. color_mask = np.array(color_mask)
  45. idx0 = np.minimum(idx[0], im.shape[0] - 1)
  46. idx1 = np.minimum(idx[1], im.shape[1] - 1)
  47. im[idx0, idx1, :] *= 1.0 - alpha
  48. im[idx0, idx1, :] += alpha * color_mask
  49. # draw box prompt
  50. if mask_info[i]["label"] == "box_prompt":
  51. x0, y0, x1, y1 = mask_info[i]["prompt"]
  52. x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
  53. cv2.rectangle(
  54. im, (x0, y0), (x1, y1), tuple(color_mask.astype("int32").tolist()), 1
  55. )
  56. bbox_text = "%s" % mask_info[i]["label"]
  57. t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
  58. cv2.rectangle(
  59. im,
  60. (x0, y0),
  61. (x0 + t_size[0], y0 - t_size[1] - 3),
  62. tuple(color_mask.astype("int32").tolist()),
  63. -1,
  64. )
  65. cv2.putText(
  66. im,
  67. bbox_text,
  68. (x0, y0 - 2),
  69. cv2.FONT_HERSHEY_SIMPLEX,
  70. 0.3,
  71. (0, 0, 0),
  72. 1,
  73. lineType=cv2.LINE_AA,
  74. )
  75. elif mask_info[i]["label"] == "point_prompt":
  76. x, y = mask_info[i]["prompt"]
  77. bbox_text = "%s" % mask_info[i]["label"]
  78. t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
  79. cv2.circle(
  80. im,
  81. (x, y),
  82. 1,
  83. (255, 255, 255),
  84. 4,
  85. )
  86. cv2.putText(
  87. im,
  88. bbox_text,
  89. (x - t_size[0] // 2, y - t_size[1] - 1),
  90. cv2.FONT_HERSHEY_SIMPLEX,
  91. 0.3,
  92. (255, 255, 255),
  93. 1,
  94. lineType=cv2.LINE_AA,
  95. )
  96. else:
  97. raise NotImplementedError(
  98. f"Prompt type {mask_info[i]['label']} not implemented."
  99. )
  100. return Image.fromarray(im.astype("uint8"))
  101. class SAMSegResult(BaseCVResult):
  102. """Save Result Transform for SAM"""
  103. def __init__(self, data: dict) -> None:
  104. data["masks"] = [mask.squeeze(0) for mask in list(data["masks"])]
  105. prompts = data["prompts"]
  106. assert isinstance(prompts, dict) and len(prompts) == 1
  107. prompt_type, prompts = list(prompts.items())[0]
  108. mask_infos = [
  109. {
  110. "label": prompt_type,
  111. "prompt": p,
  112. }
  113. for p in prompts
  114. ]
  115. data["mask_infos"] = mask_infos
  116. assert len(data["masks"]) == len(mask_infos)
  117. super().__init__(data)
  118. def _to_img(self):
  119. """apply"""
  120. image = Image.fromarray(self["input_img"])
  121. mask_infos = self["mask_infos"]
  122. masks = self["masks"]
  123. image = draw_segm(image, masks, mask_infos)
  124. return {"res": image}
  125. def _to_str(self, *args, **kwargs):
  126. data = copy.deepcopy(self)
  127. data.pop("input_img")
  128. data["masks"] = "..."
  129. return JsonMixin._to_str(data, *args, **kwargs)
  130. def _to_json(self, *args, **kwargs):
  131. data = copy.deepcopy(self)
  132. data.pop("input_img")
  133. return JsonMixin._to_json(data, *args, **kwargs)