result.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import copy
  18. import PIL
  19. from PIL import Image, ImageDraw, ImageFont
  20. from ...utils.color_map import get_colormap, font_colormap
  21. from ...common.result import BaseCVResult
  22. from ....utils.fonts import PINGFANG_FONT_FILE_PATH
  23. from ..object_detection.result import draw_box
  24. def draw_segm(im, masks, mask_info, alpha=0.7):
  25. """
  26. Draw segmentation on image
  27. """
  28. mask_color_id = 0
  29. w_ratio = 0.4
  30. color_list = get_colormap(rgb=True)
  31. im = np.array(im).astype("float32")
  32. clsid2color = {}
  33. masks = np.array(masks)
  34. masks = masks.astype(np.uint8)
  35. for i in range(masks.shape[0]):
  36. mask, score, clsid = masks[i], mask_info[i]["score"], mask_info[i]["class_id"]
  37. if clsid not in clsid2color:
  38. color_index = i % len(color_list)
  39. clsid2color[clsid] = color_list[color_index]
  40. color_mask = clsid2color[clsid]
  41. for c in range(3):
  42. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  43. idx = np.nonzero(mask)
  44. color_mask = np.array(color_mask)
  45. idx0 = np.minimum(idx[0], im.shape[0] - 1)
  46. idx1 = np.minimum(idx[1], im.shape[1] - 1)
  47. im[idx0, idx1, :] *= 1.0 - alpha
  48. im[idx0, idx1, :] += alpha * color_mask
  49. sum_x = np.sum(mask, axis=0)
  50. x = np.where(sum_x > 0.5)[0]
  51. sum_y = np.sum(mask, axis=1)
  52. y = np.where(sum_y > 0.5)[0]
  53. x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
  54. cv2.rectangle(
  55. im, (x0, y0), (x1, y1), tuple(color_mask.astype("int32").tolist()), 1
  56. )
  57. bbox_text = "%s %.2f" % (mask_info[i]["label"], score)
  58. t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
  59. cv2.rectangle(
  60. im,
  61. (x0, y0),
  62. (x0 + t_size[0], y0 - t_size[1] - 3),
  63. tuple(color_mask.astype("int32").tolist()),
  64. -1,
  65. )
  66. cv2.putText(
  67. im,
  68. bbox_text,
  69. (x0, y0 - 2),
  70. cv2.FONT_HERSHEY_SIMPLEX,
  71. 0.3,
  72. (0, 0, 0),
  73. 1,
  74. lineType=cv2.LINE_AA,
  75. )
  76. return Image.fromarray(im.astype("uint8"))
  77. def restore_to_draw_masks(img_size, boxes, masks):
  78. """
  79. Restores extracted masks to the original shape and draws them on a blank image.
  80. """
  81. restored_masks = []
  82. for i, (box, mask) in enumerate(zip(boxes, masks)):
  83. restored_mask = np.zeros(img_size, dtype=np.uint8)
  84. x_min, y_min, x_max, y_max = map(lambda x: int(round(x)), box["coordinate"])
  85. restored_mask[y_min:y_max, x_min:x_max] = mask
  86. restored_masks.append(restored_mask)
  87. return np.array(restored_masks)
  88. def draw_mask(im, boxes, np_masks, img_size):
  89. """
  90. Args:
  91. im (PIL.Image.Image): PIL image
  92. boxes (list): a list of dictionaries representing detection box information.
  93. np_masks (np.ndarray): shape:[N, im_h, im_w]
  94. Returns:
  95. im (PIL.Image.Image): visualized image
  96. """
  97. color_list = get_colormap(rgb=True)
  98. w_ratio = 0.4
  99. alpha = 0.7
  100. im = np.array(im).astype("float32")
  101. clsid2color = {}
  102. np_masks = restore_to_draw_masks(img_size, boxes, np_masks)
  103. im_h, im_w = im.shape[:2]
  104. np_masks = np_masks[:, :im_h, :im_w]
  105. for i in range(len(np_masks)):
  106. clsid, score = int(boxes[i]["cls_id"]), boxes[i]["score"]
  107. mask = np_masks[i]
  108. if clsid not in clsid2color:
  109. color_index = i % len(color_list)
  110. clsid2color[clsid] = color_list[color_index]
  111. color_mask = clsid2color[clsid]
  112. for c in range(3):
  113. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  114. idx = np.nonzero(mask)
  115. color_mask = np.array(color_mask)
  116. im[idx[0], idx[1], :] *= 1.0 - alpha
  117. im[idx[0], idx[1], :] += alpha * color_mask
  118. return Image.fromarray(im.astype("uint8"))
  119. class InstanceSegResult(BaseCVResult):
  120. """Save Result Transform"""
  121. def _to_img(self):
  122. """apply"""
  123. # image = self._img_reader.read(self["input_path"])
  124. image = Image.fromarray(self._input_img)
  125. ori_img_size = list(image.size)[::-1]
  126. boxes = self["boxes"]
  127. masks = self["masks"]
  128. if next((True for item in self["boxes"] if "coordinate" in item), False):
  129. image = draw_mask(image, boxes, masks, ori_img_size)
  130. image = draw_box(image, boxes)
  131. else:
  132. image = draw_segm(image, masks, boxes)
  133. return image
  134. def _to_str(self, _, *args, **kwargs):
  135. data = copy.deepcopy(self)
  136. data["masks"] = "..."
  137. return super()._to_str(data, *args, **kwargs)