processors.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import cv2
  17. import copy
  18. import math
  19. import pyclipper
  20. import numpy as np
  21. from numpy.linalg import norm
  22. from PIL import Image
  23. from shapely.geometry import Polygon
  24. from ...utils.io import ImageReader
  25. from ....utils import logging
  26. class DetResizeForTest:
  27. """DetResizeForTest"""
  28. def __init__(self, **kwargs):
  29. super().__init__()
  30. self.resize_type = 0
  31. self.keep_ratio = False
  32. if "image_shape" in kwargs:
  33. self.image_shape = kwargs["image_shape"]
  34. self.resize_type = 1
  35. if "keep_ratio" in kwargs:
  36. self.keep_ratio = kwargs["keep_ratio"]
  37. elif "limit_side_len" in kwargs:
  38. self.limit_side_len = kwargs["limit_side_len"]
  39. self.limit_type = kwargs.get("limit_type", "min")
  40. elif "resize_long" in kwargs:
  41. self.resize_type = 2
  42. self.resize_long = kwargs.get("resize_long", 960)
  43. else:
  44. self.limit_side_len = 736
  45. self.limit_type = "min"
  46. def __call__(self, imgs):
  47. """apply"""
  48. resize_imgs, img_shapes = [], []
  49. for ori_img in imgs:
  50. img, shape = self.resize(ori_img)
  51. resize_imgs.append(img)
  52. img_shapes.append(shape)
  53. return resize_imgs, img_shapes
  54. def resize(self, img):
  55. src_h, src_w, _ = img.shape
  56. if sum([src_h, src_w]) < 64:
  57. img = self.image_padding(img)
  58. if self.resize_type == 0:
  59. # img, shape = self.resize_image_type0(img)
  60. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  61. elif self.resize_type == 2:
  62. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  63. else:
  64. # img, shape = self.resize_image_type1(img)
  65. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  66. return img, np.array([src_h, src_w, ratio_h, ratio_w])
  67. def image_padding(self, im, value=0):
  68. """padding image"""
  69. h, w, c = im.shape
  70. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  71. im_pad[:h, :w, :] = im
  72. return im_pad
  73. def resize_image_type1(self, img):
  74. """resize the image"""
  75. resize_h, resize_w = self.image_shape
  76. ori_h, ori_w = img.shape[:2] # (h, w, c)
  77. if self.keep_ratio is True:
  78. resize_w = ori_w * resize_h / ori_h
  79. N = math.ceil(resize_w / 32)
  80. resize_w = N * 32
  81. ratio_h = float(resize_h) / ori_h
  82. ratio_w = float(resize_w) / ori_w
  83. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  84. # return img, np.array([ori_h, ori_w])
  85. return img, [ratio_h, ratio_w]
  86. def resize_image_type0(self, img):
  87. """
  88. resize image to a size multiple of 32 which is required by the network
  89. args:
  90. img(array): array with shape [h, w, c]
  91. return(tuple):
  92. img, (ratio_h, ratio_w)
  93. """
  94. limit_side_len = self.limit_side_len
  95. h, w, c = img.shape
  96. # limit the max side
  97. if self.limit_type == "max":
  98. if max(h, w) > limit_side_len:
  99. if h > w:
  100. ratio = float(limit_side_len) / h
  101. else:
  102. ratio = float(limit_side_len) / w
  103. else:
  104. ratio = 1.0
  105. elif self.limit_type == "min":
  106. if min(h, w) < limit_side_len:
  107. if h < w:
  108. ratio = float(limit_side_len) / h
  109. else:
  110. ratio = float(limit_side_len) / w
  111. else:
  112. ratio = 1.0
  113. elif self.limit_type == "resize_long":
  114. ratio = float(limit_side_len) / max(h, w)
  115. else:
  116. raise Exception("not support limit type, image ")
  117. resize_h = int(h * ratio)
  118. resize_w = int(w * ratio)
  119. resize_h = max(int(round(resize_h / 32) * 32), 32)
  120. resize_w = max(int(round(resize_w / 32) * 32), 32)
  121. try:
  122. if int(resize_w) <= 0 or int(resize_h) <= 0:
  123. return None, (None, None)
  124. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  125. except:
  126. logging.info(img.shape, resize_w, resize_h)
  127. sys.exit(0)
  128. ratio_h = resize_h / float(h)
  129. ratio_w = resize_w / float(w)
  130. return img, [ratio_h, ratio_w]
  131. def resize_image_type2(self, img):
  132. """resize image size"""
  133. h, w, _ = img.shape
  134. resize_w = w
  135. resize_h = h
  136. if resize_h > resize_w:
  137. ratio = float(self.resize_long) / resize_h
  138. else:
  139. ratio = float(self.resize_long) / resize_w
  140. resize_h = int(resize_h * ratio)
  141. resize_w = int(resize_w * ratio)
  142. max_stride = 128
  143. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  144. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  145. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  146. ratio_h = resize_h / float(h)
  147. ratio_w = resize_w / float(w)
  148. return img, [ratio_h, ratio_w]
  149. class NormalizeImage:
  150. """normalize image such as substract mean, divide std"""
  151. def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
  152. super().__init__()
  153. if isinstance(scale, str):
  154. scale = eval(scale)
  155. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  156. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  157. std = std if std is not None else [0.229, 0.224, 0.225]
  158. shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
  159. self.mean = np.array(mean).reshape(shape).astype("float32")
  160. self.std = np.array(std).reshape(shape).astype("float32")
  161. def __call__(self, imgs):
  162. """apply"""
  163. def norm(img):
  164. return (img.astype("float32") * self.scale - self.mean) / self.std
  165. return [norm(img) for img in imgs]
  166. class DBPostProcess:
  167. """
  168. The post process for Differentiable Binarization (DB).
  169. """
  170. def __init__(
  171. self,
  172. thresh=0.3,
  173. box_thresh=0.7,
  174. max_candidates=1000,
  175. unclip_ratio=2.0,
  176. use_dilation=False,
  177. score_mode="fast",
  178. box_type="quad",
  179. **kwargs
  180. ):
  181. super().__init__()
  182. self.thresh = thresh
  183. self.box_thresh = box_thresh
  184. self.max_candidates = max_candidates
  185. self.unclip_ratio = unclip_ratio
  186. self.min_size = 3
  187. self.score_mode = score_mode
  188. self.box_type = box_type
  189. assert score_mode in [
  190. "slow",
  191. "fast",
  192. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  193. self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
  194. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  195. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  196. bitmap = _bitmap
  197. height, width = bitmap.shape
  198. boxes = []
  199. scores = []
  200. contours, _ = cv2.findContours(
  201. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  202. )
  203. for contour in contours[: self.max_candidates]:
  204. epsilon = 0.002 * cv2.arcLength(contour, True)
  205. approx = cv2.approxPolyDP(contour, epsilon, True)
  206. points = approx.reshape((-1, 2))
  207. if points.shape[0] < 4:
  208. continue
  209. score = self.box_score_fast(pred, points.reshape(-1, 2))
  210. if self.box_thresh > score:
  211. continue
  212. if points.shape[0] > 2:
  213. box = self.unclip(points, self.unclip_ratio)
  214. if len(box) > 1:
  215. continue
  216. else:
  217. continue
  218. box = box.reshape(-1, 2)
  219. if len(box) > 0:
  220. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  221. if sside < self.min_size + 2:
  222. continue
  223. else:
  224. continue
  225. box = np.array(box)
  226. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  227. box[:, 1] = np.clip(
  228. np.round(box[:, 1] / height * dest_height), 0, dest_height
  229. )
  230. boxes.append(box)
  231. scores.append(score)
  232. return boxes, scores
  233. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  234. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  235. bitmap = _bitmap
  236. height, width = bitmap.shape
  237. outs = cv2.findContours(
  238. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  239. )
  240. if len(outs) == 3:
  241. img, contours, _ = outs[0], outs[1], outs[2]
  242. elif len(outs) == 2:
  243. contours, _ = outs[0], outs[1]
  244. num_contours = min(len(contours), self.max_candidates)
  245. boxes = []
  246. scores = []
  247. for index in range(num_contours):
  248. contour = contours[index]
  249. points, sside = self.get_mini_boxes(contour)
  250. if sside < self.min_size:
  251. continue
  252. points = np.array(points)
  253. if self.score_mode == "fast":
  254. score = self.box_score_fast(pred, points.reshape(-1, 2))
  255. else:
  256. score = self.box_score_slow(pred, contour)
  257. if self.box_thresh > score:
  258. continue
  259. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  260. box, sside = self.get_mini_boxes(box)
  261. if sside < self.min_size + 2:
  262. continue
  263. box = np.array(box)
  264. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  265. box[:, 1] = np.clip(
  266. np.round(box[:, 1] / height * dest_height), 0, dest_height
  267. )
  268. boxes.append(box.astype(np.int16))
  269. scores.append(score)
  270. return np.array(boxes, dtype=np.int16), scores
  271. def unclip(self, box, unclip_ratio):
  272. """unclip"""
  273. poly = Polygon(box)
  274. distance = poly.area * unclip_ratio / poly.length
  275. offset = pyclipper.PyclipperOffset()
  276. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  277. try:
  278. expanded = np.array(offset.Execute(distance))
  279. except ValueError:
  280. expanded = np.array(offset.Execute(distance)[0])
  281. return expanded
  282. def get_mini_boxes(self, contour):
  283. """get mini boxes"""
  284. bounding_box = cv2.minAreaRect(contour)
  285. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  286. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  287. if points[1][1] > points[0][1]:
  288. index_1 = 0
  289. index_4 = 1
  290. else:
  291. index_1 = 1
  292. index_4 = 0
  293. if points[3][1] > points[2][1]:
  294. index_2 = 2
  295. index_3 = 3
  296. else:
  297. index_2 = 3
  298. index_3 = 2
  299. box = [points[index_1], points[index_2], points[index_3], points[index_4]]
  300. return box, min(bounding_box[1])
  301. def box_score_fast(self, bitmap, _box):
  302. """box_score_fast: use bbox mean score as the mean score"""
  303. h, w = bitmap.shape[:2]
  304. box = _box.copy()
  305. xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
  306. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
  307. ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
  308. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
  309. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  310. box[:, 0] = box[:, 0] - xmin
  311. box[:, 1] = box[:, 1] - ymin
  312. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  313. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  314. def box_score_slow(self, bitmap, contour):
  315. """box_score_slow: use polyon mean score as the mean score"""
  316. h, w = bitmap.shape[:2]
  317. contour = contour.copy()
  318. contour = np.reshape(contour, (-1, 2))
  319. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  320. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  321. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  322. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  323. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  324. contour[:, 0] = contour[:, 0] - xmin
  325. contour[:, 1] = contour[:, 1] - ymin
  326. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  327. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  328. def __call__(self, preds, img_shapes):
  329. """apply"""
  330. boxes, scores = [], []
  331. for pred, img_shape in zip(preds[0], img_shapes):
  332. box, score = self.process(pred, img_shape)
  333. boxes.append(box)
  334. scores.append(score)
  335. return boxes, scores
  336. def process(self, pred, img_shape):
  337. pred = pred[0, :, :]
  338. segmentation = pred > self.thresh
  339. src_h, src_w, ratio_h, ratio_w = img_shape
  340. if self.dilation_kernel is not None:
  341. mask = cv2.dilate(
  342. np.array(segmentation).astype(np.uint8),
  343. self.dilation_kernel,
  344. )
  345. else:
  346. mask = segmentation
  347. if self.box_type == "poly":
  348. boxes, scores = self.polygons_from_bitmap(pred, mask, src_w, src_h)
  349. elif self.box_type == "quad":
  350. boxes, scores = self.boxes_from_bitmap(pred, mask, src_w, src_h)
  351. else:
  352. raise ValueError("box_type can only be one of ['quad', 'poly']")
  353. return boxes, scores