transforms.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import math
  18. import PIL
  19. from PIL import Image, ImageDraw, ImageFont
  20. from .keys import DetKeys as K
  21. from ...base import BaseTransform
  22. from ...base.predictor.io import ImageWriter, ImageReader
  23. from ...base.predictor.transforms import image_functions as F
  24. from ...base.predictor.transforms.image_common import _BaseResize, _check_image_size
  25. from ....utils.fonts import PINGFANG_FONT_FILE_PATH
  26. from ....utils import logging
  27. __all__ = ["SaveDetResults", "PadStride", "DetResize", "PrintResult"]
  28. def get_color_map_list(num_classes):
  29. """
  30. Args:
  31. num_classes (int): number of class
  32. Returns:
  33. color_map (list): RGB color list
  34. """
  35. color_map = num_classes * [0, 0, 0]
  36. for i in range(0, num_classes):
  37. j = 0
  38. lab = i
  39. while lab:
  40. color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
  41. color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
  42. color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
  43. j += 1
  44. lab >>= 3
  45. color_map = [color_map[i : i + 3] for i in range(0, len(color_map), 3)]
  46. return color_map
  47. def colormap(rgb=False):
  48. """
  49. Get colormap
  50. The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
  51. utils/colormap.py
  52. """
  53. color_list = np.array(
  54. [
  55. 0xFF,
  56. 0x00,
  57. 0x00,
  58. 0xCC,
  59. 0xFF,
  60. 0x00,
  61. 0x00,
  62. 0xFF,
  63. 0x66,
  64. 0x00,
  65. 0x66,
  66. 0xFF,
  67. 0xCC,
  68. 0x00,
  69. 0xFF,
  70. 0xFF,
  71. 0x4D,
  72. 0x00,
  73. 0x80,
  74. 0xFF,
  75. 0x00,
  76. 0x00,
  77. 0xFF,
  78. 0xB2,
  79. 0x00,
  80. 0x1A,
  81. 0xFF,
  82. 0xFF,
  83. 0x00,
  84. 0xE5,
  85. 0xFF,
  86. 0x99,
  87. 0x00,
  88. 0x33,
  89. 0xFF,
  90. 0x00,
  91. 0x00,
  92. 0xFF,
  93. 0xFF,
  94. 0x33,
  95. 0x00,
  96. 0xFF,
  97. 0xFF,
  98. 0x00,
  99. 0x99,
  100. 0xFF,
  101. 0xE5,
  102. 0x00,
  103. 0x00,
  104. 0xFF,
  105. 0x1A,
  106. 0x00,
  107. 0xB2,
  108. 0xFF,
  109. 0x80,
  110. 0x00,
  111. 0xFF,
  112. 0xFF,
  113. 0x00,
  114. 0x4D,
  115. ]
  116. ).astype(np.float32)
  117. color_list = color_list.reshape((-1, 3))
  118. if not rgb:
  119. color_list = color_list[:, ::-1]
  120. return color_list.astype("int32")
  121. def font_colormap(color_index):
  122. """
  123. Get font color according to the index of colormap
  124. """
  125. dark = np.array([0x14, 0x0E, 0x35])
  126. light = np.array([0xFF, 0xFF, 0xFF])
  127. light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
  128. if color_index in light_indexs:
  129. return light.astype("int32")
  130. else:
  131. return dark.astype("int32")
  132. def draw_box(img, np_boxes, labels, threshold=0.5):
  133. """
  134. Args:
  135. img (PIL.Image.Image): PIL image
  136. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  137. matix element:[class, score, x_min, y_min, x_max, y_max]
  138. labels (list): labels:['class1', ..., 'classn']
  139. threshold (float): threshold of box
  140. Returns:
  141. img (PIL.Image.Image): visualized image
  142. """
  143. font_size = int(0.024 * int(img.width)) + 2
  144. font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
  145. draw_thickness = int(max(img.size) * 0.005)
  146. draw = ImageDraw.Draw(img)
  147. clsid2color = {}
  148. catid2fontcolor = {}
  149. color_list = colormap(rgb=True)
  150. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  151. np_boxes = np_boxes[expect_boxes, :]
  152. for i, dt in enumerate(np_boxes):
  153. clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
  154. if clsid not in clsid2color:
  155. color_index = i % len(color_list)
  156. clsid2color[clsid] = color_list[color_index]
  157. catid2fontcolor[clsid] = font_colormap(color_index)
  158. color = tuple(clsid2color[clsid])
  159. font_color = tuple(catid2fontcolor[clsid])
  160. xmin, ymin, xmax, ymax = bbox
  161. # draw bbox
  162. draw.line(
  163. [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
  164. width=draw_thickness,
  165. fill=color,
  166. )
  167. # draw label
  168. text = "{} {:.2f}".format(labels[clsid], score)
  169. if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
  170. tw, th = draw.textsize(text, font=font)
  171. else:
  172. left, top, right, bottom = draw.textbbox((0, 0), text, font)
  173. tw, th = right - left, bottom - top
  174. if ymin < th:
  175. draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
  176. draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
  177. else:
  178. draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
  179. draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
  180. return img
  181. def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
  182. """
  183. Args:
  184. im (PIL.Image.Image): PIL image
  185. np_boxes (np.ndarray): shape:[N,6], N: number of box,
  186. matix element:[class, score, x_min, y_min, x_max, y_max]
  187. np_masks (np.ndarray): shape:[N, im_h, im_w]
  188. labels (list): labels:['class1', ..., 'classn']
  189. threshold (float): threshold of mask
  190. Returns:
  191. im (PIL.Image.Image): visualized image
  192. """
  193. color_list = get_color_map_list(len(labels))
  194. w_ratio = 0.4
  195. alpha = 0.7
  196. im = np.array(im).astype("float32")
  197. clsid2color = {}
  198. expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
  199. np_boxes = np_boxes[expect_boxes, :]
  200. np_masks = np_masks[expect_boxes, :, :]
  201. im_h, im_w = im.shape[:2]
  202. np_masks = np_masks[:, :im_h, :im_w]
  203. for i in range(len(np_masks)):
  204. clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
  205. mask = np_masks[i]
  206. if clsid not in clsid2color:
  207. clsid2color[clsid] = color_list[clsid]
  208. color_mask = clsid2color[clsid]
  209. for c in range(3):
  210. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  211. idx = np.nonzero(mask)
  212. color_mask = np.array(color_mask)
  213. im[idx[0], idx[1], :] *= 1.0 - alpha
  214. im[idx[0], idx[1], :] += alpha * color_mask
  215. return Image.fromarray(im.astype("uint8"))
  216. def draw_segm(im, np_segms, np_label, np_score, labels, threshold=0.5, alpha=0.7):
  217. """
  218. Draw segmentation on image
  219. """
  220. mask_color_id = 0
  221. w_ratio = 0.4
  222. color_list = get_color_map_list(len(labels))
  223. im = np.array(im).astype("float32")
  224. clsid2color = {}
  225. np_segms = np_segms.astype(np.uint8)
  226. for i in range(np_segms.shape[0]):
  227. mask, score, clsid = np_segms[i], np_score[i], np_label[i]
  228. if score < threshold:
  229. continue
  230. if clsid not in clsid2color:
  231. clsid2color[clsid] = color_list[clsid]
  232. color_mask = clsid2color[clsid]
  233. for c in range(3):
  234. color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
  235. idx = np.nonzero(mask)
  236. color_mask = np.array(color_mask)
  237. idx0 = np.minimum(idx[0], im.shape[0] - 1)
  238. idx1 = np.minimum(idx[1], im.shape[1] - 1)
  239. im[idx0, idx1, :] *= 1.0 - alpha
  240. im[idx0, idx1, :] += alpha * color_mask
  241. sum_x = np.sum(mask, axis=0)
  242. x = np.where(sum_x > 0.5)[0]
  243. sum_y = np.sum(mask, axis=1)
  244. y = np.where(sum_y > 0.5)[0]
  245. x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
  246. cv2.rectangle(
  247. im, (x0, y0), (x1, y1), tuple(color_mask.astype("int32").tolist()), 1
  248. )
  249. bbox_text = "%s %.2f" % (labels[clsid], score)
  250. t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
  251. cv2.rectangle(
  252. im,
  253. (x0, y0),
  254. (x0 + t_size[0], y0 - t_size[1] - 3),
  255. tuple(color_mask.astype("int32").tolist()),
  256. -1,
  257. )
  258. cv2.putText(
  259. im,
  260. bbox_text,
  261. (x0, y0 - 2),
  262. cv2.FONT_HERSHEY_SIMPLEX,
  263. 0.3,
  264. (0, 0, 0),
  265. 1,
  266. lineType=cv2.LINE_AA,
  267. )
  268. return Image.fromarray(im.astype("uint8"))
  269. class SaveDetResults(BaseTransform):
  270. """Save Result Transform"""
  271. def __init__(self, save_dir, threshold=0.5, labels=None):
  272. super().__init__()
  273. self.save_dir = save_dir
  274. self.threshold = threshold
  275. self.labels = labels
  276. # We use pillow backend to save both numpy arrays and PIL Image objects
  277. self._writer = ImageWriter(backend="pillow")
  278. def apply(self, data):
  279. """apply"""
  280. ori_path = data[K.IM_PATH]
  281. file_name = os.path.basename(ori_path)
  282. save_path = os.path.join(self.save_dir, file_name)
  283. labels = self.labels
  284. image = ImageReader(backend="pil").read(ori_path)
  285. if K.MASKS in data:
  286. image = draw_mask(
  287. image,
  288. data[K.BOXES],
  289. data[K.MASKS],
  290. threshold=self.threshold,
  291. labels=labels,
  292. )
  293. if K.SEGM in data:
  294. image = draw_segm(
  295. image,
  296. data[K.SEGM],
  297. data[K.LABEL],
  298. data[K.SCORE],
  299. labels=labels,
  300. threshold=self.threshold,
  301. )
  302. if K.SEGM not in data:
  303. image = draw_box(
  304. image, data[K.BOXES], threshold=self.threshold, labels=labels
  305. )
  306. self._write_image(save_path, image)
  307. return data
  308. def _write_image(self, path, image):
  309. """write image"""
  310. if os.path.exists(path):
  311. logging.warning(f"{path} already exists. Overwriting it.")
  312. self._writer.write(path, image)
  313. @classmethod
  314. def get_input_keys(cls):
  315. """get input keys"""
  316. return [[K.IM_PATH, K.BOXES], [K.IM_PATH]]
  317. @classmethod
  318. def get_output_keys(cls):
  319. """get output keys"""
  320. return []
  321. class PadStride(BaseTransform):
  322. """padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
  323. Args:
  324. stride (bool): model with FPN need image shape % stride == 0
  325. """
  326. def __init__(self, stride=0):
  327. self.coarsest_stride = stride
  328. def apply(self, data):
  329. """
  330. Args:
  331. im (np.ndarray): image (np.ndarray)
  332. Returns:
  333. im (np.ndarray): processed image (np.ndarray)
  334. """
  335. im = data[K.IMAGE]
  336. coarsest_stride = self.coarsest_stride
  337. if coarsest_stride <= 0:
  338. return data
  339. im_c, im_h, im_w = im.shape
  340. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  341. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  342. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  343. padding_im[:, :im_h, :im_w] = im
  344. data[K.IMAGE] = padding_im
  345. return data
  346. @classmethod
  347. def get_input_keys(cls):
  348. """get input keys"""
  349. return [K.IMAGE]
  350. @classmethod
  351. def get_output_keys(cls):
  352. """get output keys"""
  353. return [K.IMAGE]
  354. class Pad(BaseTransform):
  355. def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
  356. """
  357. Pad image to a specified size.
  358. Args:
  359. size (list[int]): image target size
  360. fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
  361. """
  362. super(Pad, self).__init__()
  363. if isinstance(size, int):
  364. size = [size, size]
  365. self.size = size
  366. self.fill_value = fill_value
  367. def apply(self, data):
  368. im = data[K.IMAGE]
  369. im_h, im_w = im.shape[:2]
  370. h, w = self.size
  371. if h == im_h and w == im_w:
  372. # im = im.astype(np.float32)
  373. return data
  374. canvas = np.ones((h, w, 3), dtype=np.float32)
  375. canvas *= np.array(self.fill_value, dtype=np.float32)
  376. canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
  377. data[K.IMAGE] = canvas
  378. return data
  379. @classmethod
  380. def get_input_keys(cls):
  381. """get input keys"""
  382. return [K.IMAGE]
  383. @classmethod
  384. def get_output_keys(cls):
  385. """get output keys"""
  386. return [K.IMAGE]
  387. class DetResize(_BaseResize):
  388. """
  389. Resize the image.
  390. Args:
  391. target_size (list|tuple|int): Target height and width.
  392. keep_ratio (bool, optional): Whether to keep the aspect ratio of resized
  393. image. Default: False.
  394. size_divisor (int|None, optional): Divisor of resized image size.
  395. Default: None.
  396. interp (str, optional): Interpolation method. Choices are 'NEAREST',
  397. 'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
  398. """
  399. def __init__(self, target_hw, keep_ratio=False, size_divisor=None, interp="LINEAR"):
  400. super().__init__(size_divisor=size_divisor, interp=interp)
  401. if isinstance(target_hw, int):
  402. target_hw = [target_hw, target_hw]
  403. _check_image_size(target_hw)
  404. self.target_hw = target_hw
  405. self.keep_ratio = keep_ratio
  406. def apply(self, data):
  407. """apply"""
  408. target_hw = self.target_hw
  409. im = data["image"]
  410. original_size = im.shape[:2]
  411. if self.keep_ratio:
  412. h, w = im.shape[0:2]
  413. target_hw, _ = self._rescale_size((h, w), self.target_hw)
  414. if self.size_divisor:
  415. target_hw = [
  416. math.ceil(i / self.size_divisor) * self.size_divisor for i in target_hw
  417. ]
  418. im_scale_w, im_scale_h = [
  419. target_hw[1] / original_size[1],
  420. target_hw[0] / original_size[0],
  421. ]
  422. im = F.resize(im, target_hw[::-1], interp=self.interp)
  423. data["image"] = im
  424. data["image_size"] = [im.shape[1], im.shape[0]]
  425. data["scale_factors"] = [im_scale_w, im_scale_h]
  426. return data
  427. @classmethod
  428. def get_input_keys(cls):
  429. """get input keys"""
  430. # image: Image in hw or hwc format.
  431. return ["image"]
  432. @classmethod
  433. def get_output_keys(cls):
  434. """get output keys"""
  435. # image: Image in hw or hwc format.
  436. # image_size: Width and height of the image.
  437. # scale_factors: Scale factors for image width and height.
  438. return ["image", "image_size", "scale_factors"]
  439. class PrintResult(BaseTransform):
  440. """Print Result Transform"""
  441. def apply(self, data):
  442. """apply"""
  443. logging.info("The prediction result is:")
  444. logging.info(data[K.BOXES] if K.BOXES in data else data[K.SEGM])
  445. return data
  446. @classmethod
  447. def get_input_keys(cls):
  448. """get input keys"""
  449. return [[], [K.BOXES]]
  450. @classmethod
  451. def get_output_keys(cls):
  452. """get output keys"""
  453. return []