visualizer.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ......utils.deps import function_requires_deps
  16. def get_color_map_list(length):
  17. """Returns the color map for visualizing the segmentation mask"""
  18. length += 1
  19. color_map = length * [0, 0, 0]
  20. for i in range(0, length):
  21. j = 0
  22. lab = i
  23. while lab:
  24. color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
  25. color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
  26. color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
  27. j += 1
  28. lab >>= 3
  29. color_map = color_map[3:]
  30. return color_map
  31. @function_requires_deps("opencv-contrib-python")
  32. def visualize(image, result, weight=0.6, use_multilabel=False):
  33. """Convert predict result to color image, and save added image."""
  34. import cv2
  35. color_map = get_color_map_list(256)
  36. color_map = [color_map[i : i + 3] for i in range(0, len(color_map), 3)]
  37. color_map = np.array(color_map).astype("uint8")
  38. if not use_multilabel:
  39. # Use OpenCV LUT for color mapping
  40. c1 = cv2.LUT(result, color_map[:, 0])
  41. c2 = cv2.LUT(result, color_map[:, 1])
  42. c3 = cv2.LUT(result, color_map[:, 2])
  43. pseudo_img = np.dstack((c3, c2, c1))
  44. vis_result = cv2.addWeighted(image, weight, pseudo_img, 1 - weight, 0)
  45. else:
  46. vis_result = image.copy()
  47. for i in range(result.shape[0]):
  48. mask = result[i]
  49. c1 = np.where(mask, color_map[i, 0], vis_result[..., 0])
  50. c2 = np.where(mask, color_map[i, 1], vis_result[..., 1])
  51. c3 = np.where(mask, color_map[i, 2], vis_result[..., 2])
  52. pseudo_img = np.dstack((c3, c2, c1)).astype("uint8")
  53. contour, _ = cv2.findContours(
  54. mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
  55. )
  56. vis_result = cv2.addWeighted(vis_result, weight, pseudo_img, 1 - weight, 0)
  57. contour_color = (
  58. int(color_map[i, 0]),
  59. int(color_map[i, 1]),
  60. int(color_map[i, 2]),
  61. )
  62. vis_result = cv2.drawContours(vis_result, contour, -1, contour_color, 1)
  63. return vis_result