transforms.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import math
  17. from PIL import Image, ImageDraw, ImageFont
  18. from .keys import DetKeys as K
  19. from ...base import BaseTransform
  20. from ...base.predictor.io import ImageWriter, ImageReader
  21. from ...base.predictor.transforms import image_functions as F
  22. from ...base.predictor.transforms.image_common import _BaseResize, _check_image_size
  23. from ....utils.fonts import PINGFANG_FONT_FILE_PATH
  24. from ....utils import logging
# Names exported by ``from <this module> import *``.
__all__ = ['SaveDetResults', 'PadStride', 'DetResize', 'PrintResult']
  26. def get_color_map_list(num_classes):
  27. """
  28. Args:
  29. num_classes (int): number of class
  30. Returns:
  31. color_map (list): RGB color list
  32. """
  33. color_map = num_classes * [0, 0, 0]
  34. for i in range(0, num_classes):
  35. j = 0
  36. lab = i
  37. while lab:
  38. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  39. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  40. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  41. j += 1
  42. lab >>= 3
  43. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  44. return color_map
  45. def colormap(rgb=False):
  46. """
  47. Get colormap
  48. The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
  49. utils/colormap.py
  50. """
  51. color_list = np.array([
  52. 0xFF, 0x00, 0x00, 0xCC, 0xFF, 0x00, 0x00, 0xFF, 0x66, 0x00, 0x66, 0xFF,
  53. 0xCC, 0x00, 0xFF, 0xFF, 0x4D, 0x00, 0x80, 0xff, 0x00, 0x00, 0xFF, 0xB2,
  54. 0x00, 0x1A, 0xFF, 0xFF, 0x00, 0xE5, 0xFF, 0x99, 0x00, 0x33, 0xFF, 0x00,
  55. 0x00, 0xFF, 0xFF, 0x33, 0x00, 0xFF, 0xff, 0x00, 0x99, 0xFF, 0xE5, 0x00,
  56. 0x00, 0xFF, 0x1A, 0x00, 0xB2, 0xFF, 0x80, 0x00, 0xFF, 0xFF, 0x00, 0x4D
  57. ]).astype(np.float32)
  58. color_list = (color_list.reshape((-1, 3)))
  59. if not rgb:
  60. color_list = color_list[:, ::-1]
  61. return color_list.astype('int32')
  62. def font_colormap(color_index):
  63. """
  64. Get font color according to the index of colormap
  65. """
  66. dark = np.array([0x14, 0x0E, 0x35])
  67. light = np.array([0xFF, 0xFF, 0xFF])
  68. light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
  69. if color_index in light_indexs:
  70. return light.astype('int32')
  71. else:
  72. return dark.astype('int32')
  73. def draw_box(img, np_boxes, labels, threshold=0.5):
  74. """
  75. Args:
  76. img (PIL.Image.Image): PIL image
  77. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  78. matix element:[class, score, x_min, y_min, x_max, y_max]
  79. labels (list): labels:['class1', ..., 'classn']
  80. threshold (float): threshold of box
  81. Returns:
  82. img (PIL.Image.Image): visualized image
  83. """
  84. font_size = int(0.024 * int(img.width)) + 2
  85. font = ImageFont.truetype(
  86. PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
  87. draw_thickness = int(max(img.size) * 0.005)
  88. draw = ImageDraw.Draw(img)
  89. clsid2color = {}
  90. catid2fontcolor = {}
  91. color_list = colormap(rgb=True)
  92. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  93. np_boxes = np_boxes[expect_boxes, :]
  94. for i, dt in enumerate(np_boxes):
  95. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  96. if clsid not in clsid2color:
  97. color_index = i % len(color_list)
  98. clsid2color[clsid] = color_list[color_index]
  99. catid2fontcolor[clsid] = font_colormap(color_index)
  100. color = tuple(clsid2color[clsid])
  101. font_color = tuple(catid2fontcolor[clsid])
  102. xmin, ymin, xmax, ymax = bbox
  103. # draw bbox
  104. draw.line(
  105. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  106. (xmin, ymin)],
  107. width=draw_thickness,
  108. fill=color)
  109. # draw label
  110. text = "{} {:.2f}".format(labels[clsid], score)
  111. tw, th = draw.textsize(text, font=font)
  112. if ymin < th:
  113. draw.rectangle(
  114. [(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
  115. draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
  116. else:
  117. draw.rectangle(
  118. [(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
  119. draw.text(
  120. (xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
  121. return img
  122. def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
  123. """
  124. Args:
  125. im (PIL.Image.Image): PIL image
  126. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  127. matix element:[class, score, x_min, y_min, x_max, y_max]
  128. np_masks (np.ndarray): shape:[N, im_h, im_w]
  129. labels (list): labels:['class1', ..., 'classn']
  130. threshold (float): threshold of mask
  131. Returns:
  132. im (PIL.Image.Image): visualized image
  133. """
  134. color_list = get_color_map_list(len(labels))
  135. w_ratio = 0.4
  136. alpha = 0.7
  137. im = np.array(im).astype('float32')
  138. clsid2color = {}
  139. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  140. np_boxes = np_boxes[expect_boxes, :]
  141. np_masks = np_masks[expect_boxes, :, :]
  142. im_h, im_w = im.shape[:2]
  143. np_masks = np_masks[:, :im_h, :im_w]
  144. for i in range(len(np_masks)):
  145. clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
  146. mask = np_masks[i]
  147. if clsid not in clsid2color:
  148. clsid2color[clsid] = color_list[clsid]
  149. color_mask = clsid2color[clsid]
  150. for c in range(3):
  151. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  152. idx = np.nonzero(mask)
  153. color_mask = np.array(color_mask)
  154. im[idx[0], idx[1], :] *= 1.0 - alpha
  155. im[idx[0], idx[1], :] += alpha * color_mask
  156. return Image.fromarray(im.astype('uint8'))
  157. class SaveDetResults(BaseTransform):
  158. """ Save Result Transform """
  159. def __init__(self, save_dir, threshold=0.5, labels=None):
  160. super().__init__()
  161. self.save_dir = save_dir
  162. self.threshold = threshold
  163. self.labels = labels
  164. # We use pillow backend to save both numpy arrays and PIL Image objects
  165. self._writer = ImageWriter(backend='pillow')
  166. def apply(self, data):
  167. """ apply """
  168. ori_path = data[K.IM_PATH]
  169. file_name = os.path.basename(ori_path)
  170. save_path = os.path.join(self.save_dir, file_name)
  171. labels = self.labels
  172. image = ImageReader(backend='pil').read(ori_path)
  173. if K.MASKS in data:
  174. image = draw_mask(
  175. image,
  176. data[K.BOXES],
  177. data[K.MASKS],
  178. threshold=self.threshold,
  179. labels=labels)
  180. image = draw_box(
  181. image, data[K.BOXES], threshold=self.threshold, labels=labels)
  182. self._write_image(save_path, image)
  183. return data
  184. def _write_image(self, path, image):
  185. """ write image """
  186. if os.path.exists(path):
  187. logging.warning(f"{path} already exists. Overwriting it.")
  188. self._writer.write(path, image)
  189. @classmethod
  190. def get_input_keys(cls):
  191. """ get input keys """
  192. return [K.IM_PATH, K.BOXES]
  193. @classmethod
  194. def get_output_keys(cls):
  195. """ get output keys """
  196. return []
  197. class PadStride(BaseTransform):
  198. """ padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  199. Args:
  200. stride (bool): model with FPN need image shape % stride == 0
  201. """
  202. def __init__(self, stride=0):
  203. self.coarsest_stride = stride
  204. def apply(self, data):
  205. """
  206. Args:
  207. im (np.ndarray): image (np.ndarray)
  208. Returns:
  209. im (np.ndarray): processed image (np.ndarray)
  210. """
  211. im = data[K.IMAGE]
  212. coarsest_stride = self.coarsest_stride
  213. if coarsest_stride <= 0:
  214. return data
  215. im_c, im_h, im_w = im.shape
  216. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  217. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  218. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  219. padding_im[:, :im_h, :im_w] = im
  220. data[K.IMAGE] = padding_im
  221. return data
  222. @classmethod
  223. def get_input_keys(cls):
  224. """ get input keys """
  225. return [K.IMAGE]
  226. @classmethod
  227. def get_output_keys(cls):
  228. """ get output keys """
  229. return [K.IMAGE]
  230. class DetResize(_BaseResize):
  231. """
  232. Resize the image.
  233. Args:
  234. target_size (list|tuple|int): Target height and width.
  235. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  236. image. Default: False.
  237. size_divisor (int|None, optional): Divisor of resized image size.
  238. Default: None.
  239. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  240. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  241. """
  242. def __init__(self,
  243. target_hw,
  244. keep_ratio=False,
  245. size_divisor=None,
  246. interp='LINEAR'):
  247. super().__init__(size_divisor=size_divisor, interp=interp)
  248. if isinstance(target_hw, int):
  249. target_hw = [target_hw, target_hw]
  250. _check_image_size(target_hw)
  251. self.target_hw = target_hw
  252. self.keep_ratio = keep_ratio
  253. def apply(self, data):
  254. """ apply """
  255. target_hw = self.target_hw
  256. im = data['image']
  257. original_size = im.shape[:2]
  258. if self.keep_ratio:
  259. h, w = im.shape[0:2]
  260. target_hw, _ = self._rescale_size((h, w), self.target_hw)
  261. if self.size_divisor:
  262. target_hw = [
  263. math.ceil(i / self.size_divisor) * self.size_divisor
  264. for i in target_hw
  265. ]
  266. im_scale_w, im_scale_h = [
  267. target_hw[1] / original_size[1], target_hw[0] / original_size[0]
  268. ]
  269. im = F.resize(im, target_hw[::-1], interp=self.interp)
  270. data['image'] = im
  271. data['image_size'] = [im.shape[1], im.shape[0]]
  272. data['scale_factors'] = [im_scale_w, im_scale_h]
  273. return data
  274. @classmethod
  275. def get_input_keys(cls):
  276. """ get input keys """
  277. # image: Image in hw or hwc format.
  278. return ['image']
  279. @classmethod
  280. def get_output_keys(cls):
  281. """ get output keys """
  282. # image: Image in hw or hwc format.
  283. # image_size: Width and height of the image.
  284. # scale_factors: Scale factors for image width and height.
  285. return ['image', 'image_size', 'scale_factors']
  286. class PrintResult(BaseTransform):
  287. """ Print Result Transform """
  288. def apply(self, data):
  289. """ apply """
  290. logging.info("The prediction result is:")
  291. logging.info(data[K.BOXES])
  292. return data
  293. @classmethod
  294. def get_input_keys(cls):
  295. """ get input keys """
  296. return [K.BOXES]
  297. @classmethod
  298. def get_output_keys(cls):
  299. """ get output keys """
  300. return []