instance_seg.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import copy
  18. from PIL import Image
  19. from ..utils.color_map import get_colormap
  20. from .base import CVResult
  21. from .det import draw_box
  22. def draw_segm(im, masks, mask_info, alpha=0.7):
  23. """
  24. Draw segmentation on image
  25. """
  26. mask_color_id = 0
  27. w_ratio = 0.4
  28. color_list = get_colormap(rgb=True)
  29. im = np.array(im).astype("float32")
  30. clsid2color = {}
  31. masks = np.array(masks)
  32. masks = masks.astype(np.uint8)
  33. for i in range(masks.shape[0]):
  34. mask, score, clsid = masks[i], mask_info[i]["score"], mask_info[i]["class_id"]
  35. if clsid not in clsid2color:
  36. color_index = i % len(color_list)
  37. clsid2color[clsid] = color_list[color_index]
  38. color_mask = clsid2color[clsid]
  39. for c in range(3):
  40. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  41. idx = np.nonzero(mask)
  42. color_mask = np.array(color_mask)
  43. idx0 = np.minimum(idx[0], im.shape[0] - 1)
  44. idx1 = np.minimum(idx[1], im.shape[1] - 1)
  45. im[idx0, idx1, :] *= 1.0 - alpha
  46. im[idx0, idx1, :] += alpha * color_mask
  47. sum_x = np.sum(mask, axis=0)
  48. x = np.where(sum_x > 0.5)[0]
  49. sum_y = np.sum(mask, axis=1)
  50. y = np.where(sum_y > 0.5)[0]
  51. x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
  52. cv2.rectangle(
  53. im, (x0, y0), (x1, y1), tuple(color_mask.astype("int32").tolist()), 1
  54. )
  55. bbox_text = "%s %.2f" % (mask_info[i]["label"], score)
  56. t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
  57. cv2.rectangle(
  58. im,
  59. (x0, y0),
  60. (x0 + t_size[0], y0 - t_size[1] - 3),
  61. tuple(color_mask.astype("int32").tolist()),
  62. -1,
  63. )
  64. cv2.putText(
  65. im,
  66. bbox_text,
  67. (x0, y0 - 2),
  68. cv2.FONT_HERSHEY_SIMPLEX,
  69. 0.3,
  70. (0, 0, 0),
  71. 1,
  72. lineType=cv2.LINE_AA,
  73. )
  74. return Image.fromarray(im.astype("uint8"))
  75. def restore_to_draw_masks(img_size, boxes, masks):
  76. """
  77. Restores extracted masks to the original shape and draws them on a blank image.
  78. """
  79. restored_masks = []
  80. for i, (box, mask) in enumerate(zip(boxes, masks)):
  81. restored_mask = np.zeros(img_size, dtype=np.uint8)
  82. x_min, y_min, x_max, y_max = map(lambda x: int(round(x)), box["coordinate"])
  83. restored_mask[y_min:y_max, x_min:x_max] = mask
  84. restored_masks.append(restored_mask)
  85. return np.array(restored_masks)
  86. def draw_mask(im, boxes, np_masks, img_size):
  87. """
  88. Args:
  89. im (PIL.Image.Image): PIL image
  90. boxes (list): a list of dictionaries representing detection box information.
  91. np_masks (np.ndarray): shape:[N, im_h, im_w]
  92. Returns:
  93. im (PIL.Image.Image): visualized image
  94. """
  95. color_list = get_colormap(rgb=True)
  96. w_ratio = 0.4
  97. alpha = 0.7
  98. im = np.array(im).astype("float32")
  99. clsid2color = {}
  100. np_masks = restore_to_draw_masks(img_size, boxes, np_masks)
  101. im_h, im_w = im.shape[:2]
  102. np_masks = np_masks[:, :im_h, :im_w]
  103. for i in range(len(np_masks)):
  104. clsid, score = int(boxes[i]["cls_id"]), boxes[i]["score"]
  105. mask = np_masks[i]
  106. if clsid not in clsid2color:
  107. color_index = i % len(color_list)
  108. clsid2color[clsid] = color_list[color_index]
  109. color_mask = clsid2color[clsid]
  110. for c in range(3):
  111. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  112. idx = np.nonzero(mask)
  113. color_mask = np.array(color_mask)
  114. im[idx[0], idx[1], :] *= 1.0 - alpha
  115. im[idx[0], idx[1], :] += alpha * color_mask
  116. return Image.fromarray(im.astype("uint8"))
  117. class InstanceSegResult(CVResult):
  118. """Save Result Transform"""
  119. def _to_img(self):
  120. """apply"""
  121. image = self._img_reader.read(self["input_path"])
  122. ori_img_size = list(image.size)[::-1]
  123. boxes = self["boxes"]
  124. masks = self["masks"]
  125. if next((True for item in self["boxes"] if "coordinate" in item), False):
  126. image = draw_mask(image, boxes, masks, ori_img_size)
  127. image = draw_box(image, boxes)
  128. else:
  129. image = draw_segm(image, masks, boxes)
  130. return image
  131. def _to_str(self):
  132. str_ = copy.deepcopy(self)
  133. str_["masks"] = "..."
  134. return str(str_)