transforms.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import math
  17. import PIL
  18. from PIL import Image, ImageDraw, ImageFont
  19. from .keys import DetKeys as K
  20. from ...base import BaseTransform
  21. from ...base.predictor.io import ImageWriter, ImageReader
  22. from ...base.predictor.transforms import image_functions as F
  23. from ...base.predictor.transforms.image_common import _BaseResize, _check_image_size
  24. from ....utils.fonts import PINGFANG_FONT_FILE_PATH
  25. from ....utils import logging
  26. __all__ = ['SaveDetResults', 'PadStride', 'DetResize', 'PrintResult']
  27. def get_color_map_list(num_classes):
  28. """
  29. Args:
  30. num_classes (int): number of class
  31. Returns:
  32. color_map (list): RGB color list
  33. """
  34. color_map = num_classes * [0, 0, 0]
  35. for i in range(0, num_classes):
  36. j = 0
  37. lab = i
  38. while lab:
  39. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  40. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  41. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  42. j += 1
  43. lab >>= 3
  44. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  45. return color_map
  46. def colormap(rgb=False):
  47. """
  48. Get colormap
  49. The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
  50. utils/colormap.py
  51. """
  52. color_list = np.array([
  53. 0xFF, 0x00, 0x00, 0xCC, 0xFF, 0x00, 0x00, 0xFF, 0x66, 0x00, 0x66, 0xFF,
  54. 0xCC, 0x00, 0xFF, 0xFF, 0x4D, 0x00, 0x80, 0xff, 0x00, 0x00, 0xFF, 0xB2,
  55. 0x00, 0x1A, 0xFF, 0xFF, 0x00, 0xE5, 0xFF, 0x99, 0x00, 0x33, 0xFF, 0x00,
  56. 0x00, 0xFF, 0xFF, 0x33, 0x00, 0xFF, 0xff, 0x00, 0x99, 0xFF, 0xE5, 0x00,
  57. 0x00, 0xFF, 0x1A, 0x00, 0xB2, 0xFF, 0x80, 0x00, 0xFF, 0xFF, 0x00, 0x4D
  58. ]).astype(np.float32)
  59. color_list = (color_list.reshape((-1, 3)))
  60. if not rgb:
  61. color_list = color_list[:, ::-1]
  62. return color_list.astype('int32')
  63. def font_colormap(color_index):
  64. """
  65. Get font color according to the index of colormap
  66. """
  67. dark = np.array([0x14, 0x0E, 0x35])
  68. light = np.array([0xFF, 0xFF, 0xFF])
  69. light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
  70. if color_index in light_indexs:
  71. return light.astype('int32')
  72. else:
  73. return dark.astype('int32')
  74. def draw_box(img, np_boxes, labels, threshold=0.5):
  75. """
  76. Args:
  77. img (PIL.Image.Image): PIL image
  78. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  79. matix element:[class, score, x_min, y_min, x_max, y_max]
  80. labels (list): labels:['class1', ..., 'classn']
  81. threshold (float): threshold of box
  82. Returns:
  83. img (PIL.Image.Image): visualized image
  84. """
  85. font_size = int(0.024 * int(img.width)) + 2
  86. font = ImageFont.truetype(
  87. PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
  88. draw_thickness = int(max(img.size) * 0.005)
  89. draw = ImageDraw.Draw(img)
  90. clsid2color = {}
  91. catid2fontcolor = {}
  92. color_list = colormap(rgb=True)
  93. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  94. np_boxes = np_boxes[expect_boxes, :]
  95. for i, dt in enumerate(np_boxes):
  96. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  97. if clsid not in clsid2color:
  98. color_index = i % len(color_list)
  99. clsid2color[clsid] = color_list[color_index]
  100. catid2fontcolor[clsid] = font_colormap(color_index)
  101. color = tuple(clsid2color[clsid])
  102. font_color = tuple(catid2fontcolor[clsid])
  103. xmin, ymin, xmax, ymax = bbox
  104. # draw bbox
  105. draw.line(
  106. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  107. (xmin, ymin)],
  108. width=draw_thickness,
  109. fill=color)
  110. # draw label
  111. text = "{} {:.2f}".format(labels[clsid], score)
  112. if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
  113. tw, th = draw.textsize(text, font=font)
  114. else:
  115. left, top, right, bottom = draw.textbbox((0, 0), text, font)
  116. tw, th = right - left, bottom - top
  117. if ymin < th:
  118. draw.rectangle(
  119. [(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
  120. draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
  121. else:
  122. draw.rectangle(
  123. [(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
  124. draw.text(
  125. (xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
  126. return img
  127. def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
  128. """
  129. Args:
  130. im (PIL.Image.Image): PIL image
  131. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  132. matix element:[class, score, x_min, y_min, x_max, y_max]
  133. np_masks (np.ndarray): shape:[N, im_h, im_w]
  134. labels (list): labels:['class1', ..., 'classn']
  135. threshold (float): threshold of mask
  136. Returns:
  137. im (PIL.Image.Image): visualized image
  138. """
  139. color_list = get_color_map_list(len(labels))
  140. w_ratio = 0.4
  141. alpha = 0.7
  142. im = np.array(im).astype('float32')
  143. clsid2color = {}
  144. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  145. np_boxes = np_boxes[expect_boxes, :]
  146. np_masks = np_masks[expect_boxes, :, :]
  147. im_h, im_w = im.shape[:2]
  148. np_masks = np_masks[:, :im_h, :im_w]
  149. for i in range(len(np_masks)):
  150. clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
  151. mask = np_masks[i]
  152. if clsid not in clsid2color:
  153. clsid2color[clsid] = color_list[clsid]
  154. color_mask = clsid2color[clsid]
  155. for c in range(3):
  156. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  157. idx = np.nonzero(mask)
  158. color_mask = np.array(color_mask)
  159. im[idx[0], idx[1], :] *= 1.0 - alpha
  160. im[idx[0], idx[1], :] += alpha * color_mask
  161. return Image.fromarray(im.astype('uint8'))
  162. class SaveDetResults(BaseTransform):
  163. """ Save Result Transform """
  164. def __init__(self, save_dir, threshold=0.5, labels=None):
  165. super().__init__()
  166. self.save_dir = save_dir
  167. self.threshold = threshold
  168. self.labels = labels
  169. # We use pillow backend to save both numpy arrays and PIL Image objects
  170. self._writer = ImageWriter(backend='pillow')
  171. def apply(self, data):
  172. """ apply """
  173. ori_path = data[K.IM_PATH]
  174. file_name = os.path.basename(ori_path)
  175. save_path = os.path.join(self.save_dir, file_name)
  176. labels = self.labels
  177. image = ImageReader(backend='pil').read(ori_path)
  178. if K.MASKS in data:
  179. image = draw_mask(
  180. image,
  181. data[K.BOXES],
  182. data[K.MASKS],
  183. threshold=self.threshold,
  184. labels=labels)
  185. image = draw_box(
  186. image, data[K.BOXES], threshold=self.threshold, labels=labels)
  187. self._write_image(save_path, image)
  188. return data
  189. def _write_image(self, path, image):
  190. """ write image """
  191. if os.path.exists(path):
  192. logging.warning(f"{path} already exists. Overwriting it.")
  193. self._writer.write(path, image)
  194. @classmethod
  195. def get_input_keys(cls):
  196. """ get input keys """
  197. return [K.IM_PATH, K.BOXES]
  198. @classmethod
  199. def get_output_keys(cls):
  200. """ get output keys """
  201. return []
  202. class PadStride(BaseTransform):
  203. """ padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  204. Args:
  205. stride (bool): model with FPN need image shape % stride == 0
  206. """
  207. def __init__(self, stride=0):
  208. self.coarsest_stride = stride
  209. def apply(self, data):
  210. """
  211. Args:
  212. im (np.ndarray): image (np.ndarray)
  213. Returns:
  214. im (np.ndarray): processed image (np.ndarray)
  215. """
  216. im = data[K.IMAGE]
  217. coarsest_stride = self.coarsest_stride
  218. if coarsest_stride <= 0:
  219. return data
  220. im_c, im_h, im_w = im.shape
  221. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  222. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  223. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  224. padding_im[:, :im_h, :im_w] = im
  225. data[K.IMAGE] = padding_im
  226. return data
  227. @classmethod
  228. def get_input_keys(cls):
  229. """ get input keys """
  230. return [K.IMAGE]
  231. @classmethod
  232. def get_output_keys(cls):
  233. """ get output keys """
  234. return [K.IMAGE]
  235. class DetResize(_BaseResize):
  236. """
  237. Resize the image.
  238. Args:
  239. target_size (list|tuple|int): Target height and width.
  240. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  241. image. Default: False.
  242. size_divisor (int|None, optional): Divisor of resized image size.
  243. Default: None.
  244. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  245. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  246. """
  247. def __init__(self,
  248. target_hw,
  249. keep_ratio=False,
  250. size_divisor=None,
  251. interp='LINEAR'):
  252. super().__init__(size_divisor=size_divisor, interp=interp)
  253. if isinstance(target_hw, int):
  254. target_hw = [target_hw, target_hw]
  255. _check_image_size(target_hw)
  256. self.target_hw = target_hw
  257. self.keep_ratio = keep_ratio
  258. def apply(self, data):
  259. """ apply """
  260. target_hw = self.target_hw
  261. im = data['image']
  262. original_size = im.shape[:2]
  263. if self.keep_ratio:
  264. h, w = im.shape[0:2]
  265. target_hw, _ = self._rescale_size((h, w), self.target_hw)
  266. if self.size_divisor:
  267. target_hw = [
  268. math.ceil(i / self.size_divisor) * self.size_divisor
  269. for i in target_hw
  270. ]
  271. im_scale_w, im_scale_h = [
  272. target_hw[1] / original_size[1], target_hw[0] / original_size[0]
  273. ]
  274. im = F.resize(im, target_hw[::-1], interp=self.interp)
  275. data['image'] = im
  276. data['image_size'] = [im.shape[1], im.shape[0]]
  277. data['scale_factors'] = [im_scale_w, im_scale_h]
  278. return data
  279. @classmethod
  280. def get_input_keys(cls):
  281. """ get input keys """
  282. # image: Image in hw or hwc format.
  283. return ['image']
  284. @classmethod
  285. def get_output_keys(cls):
  286. """ get output keys """
  287. # image: Image in hw or hwc format.
  288. # image_size: Width and height of the image.
  289. # scale_factors: Scale factors for image width and height.
  290. return ['image', 'image_size', 'scale_factors']
  291. class PrintResult(BaseTransform):
  292. """ Print Result Transform """
  293. def apply(self, data):
  294. """ apply """
  295. logging.info("The prediction result is:")
  296. logging.info(data[K.BOXES])
  297. return data
  298. @classmethod
  299. def get_input_keys(cls):
  300. """ get input keys """
  301. return [K.BOXES]
  302. @classmethod
  303. def get_output_keys(cls):
  304. """ get output keys """
  305. return []