result.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import math
  16. import numpy as np
  17. from PIL import Image
  18. from ....utils.deps import function_requires_deps, is_dep_available
  19. from ...common.result import BaseCVResult, JsonMixin
  20. if is_dep_available("opencv-contrib-python"):
  21. import cv2
  22. if is_dep_available("matplotlib"):
  23. import matplotlib.pyplot as plt
  24. def get_color(idx):
  25. idx = idx * 3
  26. color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
  27. return color
  28. @function_requires_deps("matplotlib", "opencv-contrib-python")
  29. def draw_keypoints(img, results, visual_thresh=0.1, ids=None):
  30. plt.switch_backend("agg")
  31. skeletons = results["keypoints"]
  32. skeletons = np.array(skeletons)
  33. if len(skeletons) > 0:
  34. kpt_nums = skeletons.shape[1]
  35. if kpt_nums == 17: # plot coco keypoint
  36. EDGES = [
  37. (0, 1),
  38. (0, 2),
  39. (1, 3),
  40. (2, 4),
  41. (3, 5),
  42. (4, 6),
  43. (5, 7),
  44. (6, 8),
  45. (7, 9),
  46. (8, 10),
  47. (5, 11),
  48. (6, 12),
  49. (11, 13),
  50. (12, 14),
  51. (13, 15),
  52. (14, 16),
  53. (11, 12),
  54. ]
  55. else: # plot mpii keypoint
  56. EDGES = [
  57. (0, 1),
  58. (1, 2),
  59. (3, 4),
  60. (4, 5),
  61. (2, 6),
  62. (3, 6),
  63. (6, 7),
  64. (7, 8),
  65. (8, 9),
  66. (10, 11),
  67. (11, 12),
  68. (13, 14),
  69. (14, 15),
  70. (8, 12),
  71. (8, 13),
  72. ]
  73. NUM_EDGES = len(EDGES)
  74. colors = [
  75. [255, 0, 0],
  76. [255, 85, 0],
  77. [255, 170, 0],
  78. [255, 255, 0],
  79. [170, 255, 0],
  80. [85, 255, 0],
  81. [0, 255, 0],
  82. [0, 255, 85],
  83. [0, 255, 170],
  84. [0, 255, 255],
  85. [0, 170, 255],
  86. [0, 85, 255],
  87. [0, 0, 255],
  88. [85, 0, 255],
  89. [170, 0, 255],
  90. [255, 0, 255],
  91. [255, 0, 170],
  92. [255, 0, 85],
  93. ]
  94. plt.figure()
  95. color_set = results["colors"] if "colors" in results else None
  96. if "bbox" in results and ids is None:
  97. bboxs = results["bbox"]
  98. for j, rect in enumerate(bboxs):
  99. xmin, ymin, xmax, ymax = rect
  100. color = (
  101. colors[0] if color_set is None else colors[color_set[j] % len(colors)]
  102. )
  103. cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
  104. canvas = img.copy()
  105. for i in range(kpt_nums):
  106. for j in range(len(skeletons)):
  107. if skeletons[j][i, 2] < visual_thresh:
  108. continue
  109. if ids is None:
  110. color = (
  111. colors[i]
  112. if color_set is None
  113. else colors[color_set[j] % len(colors)]
  114. )
  115. else:
  116. color = get_color(ids[j])
  117. cv2.circle(
  118. canvas,
  119. tuple(skeletons[j][i, 0:2].astype("int32")),
  120. 2,
  121. color,
  122. thickness=-1,
  123. )
  124. stickwidth = 1
  125. for i in range(NUM_EDGES):
  126. for j in range(len(skeletons)):
  127. edge = EDGES[i]
  128. if (
  129. skeletons[j][edge[0], 2] < visual_thresh
  130. or skeletons[j][edge[1], 2] < visual_thresh
  131. ):
  132. continue
  133. cur_canvas = canvas.copy()
  134. X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
  135. Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
  136. mX = np.mean(X)
  137. mY = np.mean(Y)
  138. length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
  139. angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
  140. polygon = cv2.ellipse2Poly(
  141. (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
  142. )
  143. if ids is None:
  144. color = (
  145. colors[i]
  146. if color_set is None
  147. else colors[color_set[j] % len(colors)]
  148. )
  149. else:
  150. color = get_color(ids[j])
  151. cv2.fillConvexPoly(cur_canvas, polygon, color)
  152. canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
  153. plt.close()
  154. return canvas
  155. class KptResult(BaseCVResult):
  156. """Save Result Transform"""
  157. def _to_img(self):
  158. """apply"""
  159. if "kpts" in self: # for single module result
  160. keypoints = [kpt["keypoints"] for kpt in self["kpts"]]
  161. else:
  162. keypoints = [
  163. obj["keypoints"] for obj in self["boxes"]
  164. ] # for top-down pipeline result
  165. image = self["input_img"]
  166. if keypoints:
  167. image = draw_keypoints(image, dict(keypoints=np.stack(keypoints)))
  168. image = Image.fromarray(image[..., ::-1])
  169. return {"res": image}
  170. def _to_str(self, *args, **kwargs):
  171. data = copy.deepcopy(self)
  172. data.pop("input_img")
  173. return JsonMixin._to_str(data, *args, **kwargs)
  174. def _to_json(self, *args, **kwargs):
  175. data = copy.deepcopy(self)
  176. data.pop("input_img")
  177. return JsonMixin._to_json(data, *args, **kwargs)