transforms.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import math
  17. from PIL import Image, ImageDraw, ImageFile
  18. from ....utils import logging
  19. from ...base import BaseTransform
  20. from .keys import DetKeys as K
  21. from ...base.predictor.io.writers import ImageWriter
  22. from ...base.predictor.transforms import image_functions as F
  23. from ...base.predictor.transforms.image_common import _BaseResize, _check_image_size
  24. __all__ = [
  25. 'SaveDetResults', 'PadStride', 'DetResize', 'PrintResult', 'LoadLabels'
  26. ]
  27. def get_color_map_list(num_classes):
  28. """
  29. Args:
  30. num_classes (int): number of class
  31. Returns:
  32. color_map (list): RGB color list
  33. """
  34. color_map = num_classes * [0, 0, 0]
  35. for i in range(0, num_classes):
  36. j = 0
  37. lab = i
  38. while lab:
  39. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  40. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  41. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  42. j += 1
  43. lab >>= 3
  44. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  45. return color_map
  46. def draw_box(img, np_boxes, labels, threshold=0.5):
  47. """
  48. Args:
  49. img (PIL.Image.Image): PIL image
  50. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  51. matix element:[class, score, x_min, y_min, x_max, y_max]
  52. labels (list): labels:['class1', ..., 'classn']
  53. threshold (float): threshold of box
  54. Returns:
  55. img (PIL.Image.Image): visualized image
  56. """
  57. draw_thickness = min(img.size) // 320
  58. draw = ImageDraw.Draw(img)
  59. clsid2color = {}
  60. color_list = get_color_map_list(len(labels))
  61. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  62. np_boxes = np_boxes[expect_boxes, :]
  63. for dt in np_boxes:
  64. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  65. if clsid not in clsid2color:
  66. clsid2color[clsid] = color_list[clsid]
  67. color = tuple(clsid2color[clsid])
  68. xmin, ymin, xmax, ymax = bbox
  69. # draw bbox
  70. draw.line(
  71. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  72. (xmin, ymin)],
  73. width=draw_thickness,
  74. fill=color)
  75. # draw label
  76. text = "{} {:.4f}".format(labels[clsid], score)
  77. tw, th = draw.textsize(text)
  78. draw.rectangle(
  79. [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
  80. draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
  81. return img
  82. def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
  83. """
  84. Args:
  85. im (PIL.Image.Image): PIL image
  86. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  87. matix element:[class, score, x_min, y_min, x_max, y_max]
  88. np_masks (np.ndarray): shape:[N, im_h, im_w]
  89. labels (list): labels:['class1', ..., 'classn']
  90. threshold (float): threshold of mask
  91. Returns:
  92. im (PIL.Image.Image): visualized image
  93. """
  94. color_list = get_color_map_list(len(labels))
  95. w_ratio = 0.4
  96. alpha = 0.7
  97. im = np.array(im).astype('float32')
  98. clsid2color = {}
  99. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  100. np_boxes = np_boxes[expect_boxes, :]
  101. np_masks = np_masks[expect_boxes, :, :]
  102. im_h, im_w = im.shape[:2]
  103. np_masks = np_masks[:, :im_h, :im_w]
  104. for i in range(len(np_masks)):
  105. clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
  106. mask = np_masks[i]
  107. if clsid not in clsid2color:
  108. clsid2color[clsid] = color_list[clsid]
  109. color_mask = clsid2color[clsid]
  110. for c in range(3):
  111. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  112. idx = np.nonzero(mask)
  113. color_mask = np.array(color_mask)
  114. im[idx[0], idx[1], :] *= 1.0 - alpha
  115. im[idx[0], idx[1], :] += alpha * color_mask
  116. return Image.fromarray(im.astype('uint8'))
  117. class SaveDetResults(BaseTransform):
  118. """ Save Result Transform """
  119. def __init__(self, save_dir, threshold=0.5, labels=None):
  120. super().__init__()
  121. self.save_dir = save_dir
  122. self.threshold = threshold
  123. self.labels = labels
  124. # We use pillow backend to save both numpy arrays and PIL Image objects
  125. self._writer = ImageWriter(backend='pillow')
  126. def apply(self, data):
  127. """ apply """
  128. ori_path = data[K.IMAGE_PATH]
  129. file_name = os.path.basename(ori_path)
  130. save_path = os.path.join(self.save_dir, file_name)
  131. labels = data[
  132. K.
  133. LABELS] if self.labels is None and K.LABELS in data else self.labels
  134. image = Image.open(ori_path)
  135. if K.MASKS in data:
  136. image = draw_mask(
  137. image,
  138. data[K.BOXES],
  139. data[K.MASKS],
  140. threshold=self.threshold,
  141. labels=labels)
  142. image = draw_box(
  143. image, data[K.BOXES], threshold=self.threshold, labels=labels)
  144. self._write_image(save_path, image)
  145. return data
  146. def _write_image(self, path, image):
  147. """ write image """
  148. if os.path.exists(path):
  149. logging.warning(f"{path} already exists. Overwriting it.")
  150. self._writer.write(path, image)
  151. @classmethod
  152. def get_input_keys(cls):
  153. """ get input keys """
  154. return [K.IMAGE_PATH, K.BOXES]
  155. @classmethod
  156. def get_output_keys(cls):
  157. """ get output keys """
  158. return []
  159. class PadStride(BaseTransform):
  160. """ padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  161. Args:
  162. stride (bool): model with FPN need image shape % stride == 0
  163. """
  164. def __init__(self, stride=0):
  165. self.coarsest_stride = stride
  166. def apply(self, data):
  167. """
  168. Args:
  169. im (np.ndarray): image (np.ndarray)
  170. Returns:
  171. im (np.ndarray): processed image (np.ndarray)
  172. """
  173. im = data[K.IMAGE]
  174. coarsest_stride = self.coarsest_stride
  175. if coarsest_stride <= 0:
  176. return data
  177. im_c, im_h, im_w = im.shape
  178. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  179. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  180. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  181. padding_im[:, :im_h, :im_w] = im
  182. data[K.IMAGE] = padding_im
  183. return data
  184. @classmethod
  185. def get_input_keys(cls):
  186. """ get input keys """
  187. return [K.IMAGE]
  188. @classmethod
  189. def get_output_keys(cls):
  190. """ get output keys """
  191. return [K.IMAGE]
  192. class DetResize(_BaseResize):
  193. """
  194. Resize the image.
  195. Args:
  196. target_size (list|tuple|int): Target height and width.
  197. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  198. image. Default: False.
  199. size_divisor (int|None, optional): Divisor of resized image size.
  200. Default: None.
  201. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  202. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  203. """
  204. def __init__(self,
  205. target_hw,
  206. keep_ratio=False,
  207. size_divisor=None,
  208. interp='LINEAR'):
  209. super().__init__(size_divisor=size_divisor, interp=interp)
  210. if isinstance(target_hw, int):
  211. target_hw = [target_hw, target_hw]
  212. _check_image_size(target_hw)
  213. self.target_hw = target_hw
  214. self.keep_ratio = keep_ratio
  215. def apply(self, data):
  216. """ apply """
  217. target_hw = self.target_hw
  218. im = data['image']
  219. original_size = im.shape[:2]
  220. if self.keep_ratio:
  221. h, w = im.shape[0:2]
  222. target_hw, _ = self._rescale_size((h, w), self.target_hw)
  223. if self.size_divisor:
  224. target_hw = [
  225. math.ceil(i / self.size_divisor) * self.size_divisor
  226. for i in target_hw
  227. ]
  228. im_scale_w, im_scale_h = [
  229. target_hw[1] / original_size[1], target_hw[0] / original_size[0]
  230. ]
  231. im = F.resize(im, target_hw[::-1], interp=self.interp)
  232. data['image'] = im
  233. data['image_size'] = [im.shape[1], im.shape[0]]
  234. data['scale_factors'] = [im_scale_w, im_scale_h]
  235. return data
  236. @classmethod
  237. def get_input_keys(cls):
  238. """ get input keys """
  239. # image: Image in hw or hwc format.
  240. return ['image']
  241. @classmethod
  242. def get_output_keys(cls):
  243. """ get output keys """
  244. # image: Image in hw or hwc format.
  245. # image_size: Width and height of the image.
  246. # scale_factors: Scale factors for image width and height.
  247. return ['image', 'image_size', 'scale_factors']
  248. class PrintResult(BaseTransform):
  249. """ Print Result Transform """
  250. def apply(self, data):
  251. """ apply """
  252. logging.info("The prediction result is:")
  253. logging.info(data[K.BOXES])
  254. return data
  255. @classmethod
  256. def get_input_keys(cls):
  257. """ get input keys """
  258. return [K.BOXES]
  259. @classmethod
  260. def get_output_keys(cls):
  261. """ get output keys """
  262. return []
  263. class LoadLabels(BaseTransform):
  264. def __init__(self, labels=None):
  265. super().__init__()
  266. self.labels = labels
  267. def apply(self, data):
  268. """ apply """
  269. if self.labels:
  270. data[K.LABELS] = self.labels
  271. return data
  272. @classmethod
  273. def get_input_keys(cls):
  274. """ get input keys """
  275. return []
  276. @classmethod
  277. def get_output_keys(cls):
  278. """ get output keys """
  279. return [K.LABELS]