transforms.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import math
  17. import PIL
  18. from PIL import Image, ImageDraw, ImageFont
  19. from .keys import DetKeys as K
  20. from ...base import BaseTransform
  21. from ...base.predictor.io import ImageWriter, ImageReader
  22. from ...base.predictor.transforms import image_functions as F
  23. from ...base.predictor.transforms.image_common import _BaseResize, _check_image_size
  24. from ....utils.fonts import PINGFANG_FONT_FILE_PATH
  25. from ....utils import logging
  26. __all__ = ["SaveDetResults", "PadStride", "DetResize", "PrintResult"]
  27. def get_color_map_list(num_classes):
  28. """
  29. Args:
  30. num_classes (int): number of class
  31. Returns:
  32. color_map (list): RGB color list
  33. """
  34. color_map = num_classes * [0, 0, 0]
  35. for i in range(0, num_classes):
  36. j = 0
  37. lab = i
  38. while lab:
  39. color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
  40. color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
  41. color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
  42. j += 1
  43. lab >>= 3
  44. color_map = [color_map[i : i + 3] for i in range(0, len(color_map), 3)]
  45. return color_map
  46. def colormap(rgb=False):
  47. """
  48. Get colormap
  49. The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
  50. utils/colormap.py
  51. """
  52. color_list = np.array(
  53. [
  54. 0xFF,
  55. 0x00,
  56. 0x00,
  57. 0xCC,
  58. 0xFF,
  59. 0x00,
  60. 0x00,
  61. 0xFF,
  62. 0x66,
  63. 0x00,
  64. 0x66,
  65. 0xFF,
  66. 0xCC,
  67. 0x00,
  68. 0xFF,
  69. 0xFF,
  70. 0x4D,
  71. 0x00,
  72. 0x80,
  73. 0xFF,
  74. 0x00,
  75. 0x00,
  76. 0xFF,
  77. 0xB2,
  78. 0x00,
  79. 0x1A,
  80. 0xFF,
  81. 0xFF,
  82. 0x00,
  83. 0xE5,
  84. 0xFF,
  85. 0x99,
  86. 0x00,
  87. 0x33,
  88. 0xFF,
  89. 0x00,
  90. 0x00,
  91. 0xFF,
  92. 0xFF,
  93. 0x33,
  94. 0x00,
  95. 0xFF,
  96. 0xFF,
  97. 0x00,
  98. 0x99,
  99. 0xFF,
  100. 0xE5,
  101. 0x00,
  102. 0x00,
  103. 0xFF,
  104. 0x1A,
  105. 0x00,
  106. 0xB2,
  107. 0xFF,
  108. 0x80,
  109. 0x00,
  110. 0xFF,
  111. 0xFF,
  112. 0x00,
  113. 0x4D,
  114. ]
  115. ).astype(np.float32)
  116. color_list = color_list.reshape((-1, 3))
  117. if not rgb:
  118. color_list = color_list[:, ::-1]
  119. return color_list.astype("int32")
  120. def font_colormap(color_index):
  121. """
  122. Get font color according to the index of colormap
  123. """
  124. dark = np.array([0x14, 0x0E, 0x35])
  125. light = np.array([0xFF, 0xFF, 0xFF])
  126. light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
  127. if color_index in light_indexs:
  128. return light.astype("int32")
  129. else:
  130. return dark.astype("int32")
  131. def draw_box(img, np_boxes, labels, threshold=0.5):
  132. """
  133. Args:
  134. img (PIL.Image.Image): PIL image
  135. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  136. matix element:[class, score, x_min, y_min, x_max, y_max]
  137. labels (list): labels:['class1', ..., 'classn']
  138. threshold (float): threshold of box
  139. Returns:
  140. img (PIL.Image.Image): visualized image
  141. """
  142. font_size = int(0.024 * int(img.width)) + 2
  143. font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
  144. draw_thickness = int(max(img.size) * 0.005)
  145. draw = ImageDraw.Draw(img)
  146. clsid2color = {}
  147. catid2fontcolor = {}
  148. color_list = colormap(rgb=True)
  149. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  150. np_boxes = np_boxes[expect_boxes, :]
  151. for i, dt in enumerate(np_boxes):
  152. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  153. if clsid not in clsid2color:
  154. color_index = i % len(color_list)
  155. clsid2color[clsid] = color_list[color_index]
  156. catid2fontcolor[clsid] = font_colormap(color_index)
  157. color = tuple(clsid2color[clsid])
  158. font_color = tuple(catid2fontcolor[clsid])
  159. xmin, ymin, xmax, ymax = bbox
  160. # draw bbox
  161. draw.line(
  162. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
  163. width=draw_thickness,
  164. fill=color,
  165. )
  166. # draw label
  167. text = "{} {:.2f}".format(labels[clsid], score)
  168. if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
  169. tw, th = draw.textsize(text, font=font)
  170. else:
  171. left, top, right, bottom = draw.textbbox((0, 0), text, font)
  172. tw, th = right - left, bottom - top
  173. if ymin < th:
  174. draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
  175. draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
  176. else:
  177. draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
  178. draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
  179. return img
  180. def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
  181. """
  182. Args:
  183. im (PIL.Image.Image): PIL image
  184. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  185. matix element:[class, score, x_min, y_min, x_max, y_max]
  186. np_masks (np.ndarray): shape:[N, im_h, im_w]
  187. labels (list): labels:['class1', ..., 'classn']
  188. threshold (float): threshold of mask
  189. Returns:
  190. im (PIL.Image.Image): visualized image
  191. """
  192. color_list = get_color_map_list(len(labels))
  193. w_ratio = 0.4
  194. alpha = 0.7
  195. im = np.array(im).astype("float32")
  196. clsid2color = {}
  197. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  198. np_boxes = np_boxes[expect_boxes, :]
  199. np_masks = np_masks[expect_boxes, :, :]
  200. im_h, im_w = im.shape[:2]
  201. np_masks = np_masks[:, :im_h, :im_w]
  202. for i in range(len(np_masks)):
  203. clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
  204. mask = np_masks[i]
  205. if clsid not in clsid2color:
  206. clsid2color[clsid] = color_list[clsid]
  207. color_mask = clsid2color[clsid]
  208. for c in range(3):
  209. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  210. idx = np.nonzero(mask)
  211. color_mask = np.array(color_mask)
  212. im[idx[0], idx[1], :] *= 1.0 - alpha
  213. im[idx[0], idx[1], :] += alpha * color_mask
  214. return Image.fromarray(im.astype("uint8"))
  215. class SaveDetResults(BaseTransform):
  216. """Save Result Transform"""
  217. def __init__(self, save_dir, threshold=0.5, labels=None):
  218. super().__init__()
  219. self.save_dir = save_dir
  220. self.threshold = threshold
  221. self.labels = labels
  222. # We use pillow backend to save both numpy arrays and PIL Image objects
  223. self._writer = ImageWriter(backend="pillow")
  224. def apply(self, data):
  225. """apply"""
  226. ori_path = data[K.IM_PATH]
  227. file_name = os.path.basename(ori_path)
  228. save_path = os.path.join(self.save_dir, file_name)
  229. labels = self.labels
  230. image = ImageReader(backend="pil").read(ori_path)
  231. if K.MASKS in data:
  232. image = draw_mask(
  233. image,
  234. data[K.BOXES],
  235. data[K.MASKS],
  236. threshold=self.threshold,
  237. labels=labels,
  238. )
  239. image = draw_box(image, data[K.BOXES], threshold=self.threshold, labels=labels)
  240. self._write_image(save_path, image)
  241. return data
  242. def _write_image(self, path, image):
  243. """write image"""
  244. if os.path.exists(path):
  245. logging.warning(f"{path} already exists. Overwriting it.")
  246. self._writer.write(path, image)
  247. @classmethod
  248. def get_input_keys(cls):
  249. """get input keys"""
  250. return [K.IM_PATH, K.BOXES]
  251. @classmethod
  252. def get_output_keys(cls):
  253. """get output keys"""
  254. return []
  255. class PadStride(BaseTransform):
  256. """padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  257. Args:
  258. stride (bool): model with FPN need image shape % stride == 0
  259. """
  260. def __init__(self, stride=0):
  261. self.coarsest_stride = stride
  262. def apply(self, data):
  263. """
  264. Args:
  265. im (np.ndarray): image (np.ndarray)
  266. Returns:
  267. im (np.ndarray): processed image (np.ndarray)
  268. """
  269. im = data[K.IMAGE]
  270. coarsest_stride = self.coarsest_stride
  271. if coarsest_stride <= 0:
  272. return data
  273. im_c, im_h, im_w = im.shape
  274. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  275. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  276. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  277. padding_im[:, :im_h, :im_w] = im
  278. data[K.IMAGE] = padding_im
  279. return data
  280. @classmethod
  281. def get_input_keys(cls):
  282. """get input keys"""
  283. return [K.IMAGE]
  284. @classmethod
  285. def get_output_keys(cls):
  286. """get output keys"""
  287. return [K.IMAGE]
  288. class Pad(BaseTransform):
  289. def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
  290. """
  291. Pad image to a specified size.
  292. Args:
  293. size (list[int]): image target size
  294. fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
  295. """
  296. super(Pad, self).__init__()
  297. if isinstance(size, int):
  298. size = [size, size]
  299. self.size = size
  300. self.fill_value = fill_value
  301. def apply(self, data):
  302. im = data[K.IMAGE]
  303. im_h, im_w = im.shape[:2]
  304. h, w = self.size
  305. if h == im_h and w == im_w:
  306. # im = im.astype(np.float32)
  307. return data
  308. canvas = np.ones((h, w, 3), dtype=np.float32)
  309. canvas *= np.array(self.fill_value, dtype=np.float32)
  310. canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
  311. data[K.IMAGE] = canvas
  312. return data
  313. @classmethod
  314. def get_input_keys(cls):
  315. """get input keys"""
  316. return [K.IMAGE]
  317. @classmethod
  318. def get_output_keys(cls):
  319. """get output keys"""
  320. return [K.IMAGE]
  321. class DetResize(_BaseResize):
  322. """
  323. Resize the image.
  324. Args:
  325. target_size (list|tuple|int): Target height and width.
  326. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  327. image. Default: False.
  328. size_divisor (int|None, optional): Divisor of resized image size.
  329. Default: None.
  330. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  331. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  332. """
  333. def __init__(self, target_hw, keep_ratio=False, size_divisor=None, interp="LINEAR"):
  334. super().__init__(size_divisor=size_divisor, interp=interp)
  335. if isinstance(target_hw, int):
  336. target_hw = [target_hw, target_hw]
  337. _check_image_size(target_hw)
  338. self.target_hw = target_hw
  339. self.keep_ratio = keep_ratio
  340. def apply(self, data):
  341. """apply"""
  342. target_hw = self.target_hw
  343. im = data["image"]
  344. original_size = im.shape[:2]
  345. if self.keep_ratio:
  346. h, w = im.shape[0:2]
  347. target_hw, _ = self._rescale_size((h, w), self.target_hw)
  348. if self.size_divisor:
  349. target_hw = [
  350. math.ceil(i / self.size_divisor) * self.size_divisor for i in target_hw
  351. ]
  352. im_scale_w, im_scale_h = [
  353. target_hw[1] / original_size[1],
  354. target_hw[0] / original_size[0],
  355. ]
  356. im = F.resize(im, target_hw[::-1], interp=self.interp)
  357. data["image"] = im
  358. data["image_size"] = [im.shape[1], im.shape[0]]
  359. data["scale_factors"] = [im_scale_w, im_scale_h]
  360. return data
  361. @classmethod
  362. def get_input_keys(cls):
  363. """get input keys"""
  364. # image: Image in hw or hwc format.
  365. return ["image"]
  366. @classmethod
  367. def get_output_keys(cls):
  368. """get output keys"""
  369. # image: Image in hw or hwc format.
  370. # image_size: Width and height of the image.
  371. # scale_factors: Scale factors for image width and height.
  372. return ["image", "image_size", "scale_factors"]
  373. class PrintResult(BaseTransform):
  374. """Print Result Transform"""
  375. def apply(self, data):
  376. """apply"""
  377. logging.info("The prediction result is:")
  378. logging.info(data[K.BOXES])
  379. return data
  380. @classmethod
  381. def get_input_keys(cls):
  382. """get input keys"""
  383. return [K.BOXES]
  384. @classmethod
  385. def get_output_keys(cls):
  386. """get output keys"""
  387. return []