processors.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import Union
  16. import numpy as np
  17. from ....utils import logging
  18. from ....utils.deps import class_requires_deps, is_dep_available
  19. from ...utils.benchmark import benchmark
  20. if is_dep_available("opencv-contrib-python"):
  21. import cv2
  22. if is_dep_available("pyclipper"):
  23. import pyclipper
  24. @benchmark.timeit
  25. @class_requires_deps("opencv-contrib-python")
  26. class DetResizeForTest:
  27. """DetResizeForTest"""
  28. def __init__(self, input_shape=None, **kwargs):
  29. super().__init__()
  30. self.resize_type = 0
  31. self.keep_ratio = False
  32. if input_shape is not None:
  33. self.input_shape = input_shape
  34. self.resize_type = 3
  35. elif "image_shape" in kwargs:
  36. self.image_shape = kwargs["image_shape"]
  37. self.resize_type = 1
  38. if "keep_ratio" in kwargs:
  39. self.keep_ratio = kwargs["keep_ratio"]
  40. elif "limit_side_len" in kwargs:
  41. self.limit_side_len = kwargs["limit_side_len"]
  42. self.limit_type = kwargs.get("limit_type", "min")
  43. elif "resize_long" in kwargs:
  44. self.resize_type = 2
  45. self.resize_long = kwargs.get("resize_long", 960)
  46. else:
  47. self.limit_side_len = 736
  48. self.limit_type = "min"
  49. def __call__(
  50. self,
  51. imgs,
  52. limit_side_len: Union[int, None] = None,
  53. limit_type: Union[str, None] = None,
  54. ):
  55. """apply"""
  56. resize_imgs, img_shapes = [], []
  57. for ori_img in imgs:
  58. img, shape = self.resize(ori_img, limit_side_len, limit_type)
  59. resize_imgs.append(img)
  60. img_shapes.append(shape)
  61. return resize_imgs, img_shapes
  62. def resize(
  63. self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
  64. ):
  65. src_h, src_w, _ = img.shape
  66. if sum([src_h, src_w]) < 64:
  67. img = self.image_padding(img)
  68. if self.resize_type == 0:
  69. # img, shape = self.resize_image_type0(img)
  70. img, [ratio_h, ratio_w] = self.resize_image_type0(
  71. img, limit_side_len, limit_type
  72. )
  73. elif self.resize_type == 2:
  74. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  75. elif self.resize_type == 3:
  76. img, [ratio_h, ratio_w] = self.resize_image_type3(img)
  77. else:
  78. # img, shape = self.resize_image_type1(img)
  79. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  80. return img, np.array([src_h, src_w, ratio_h, ratio_w])
  81. def image_padding(self, im, value=0):
  82. """padding image"""
  83. h, w, c = im.shape
  84. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  85. im_pad[:h, :w, :] = im
  86. return im_pad
  87. def resize_image_type1(self, img):
  88. """resize the image"""
  89. resize_h, resize_w = self.image_shape
  90. ori_h, ori_w = img.shape[:2] # (h, w, c)
  91. if self.keep_ratio is True:
  92. resize_w = ori_w * resize_h / ori_h
  93. N = math.ceil(resize_w / 32)
  94. resize_w = N * 32
  95. if resize_h == ori_h and resize_w == ori_w:
  96. return img, [1.0, 1.0]
  97. ratio_h = float(resize_h) / ori_h
  98. ratio_w = float(resize_w) / ori_w
  99. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  100. # return img, np.array([ori_h, ori_w])
  101. return img, [ratio_h, ratio_w]
  102. def resize_image_type0(
  103. self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
  104. ):
  105. """
  106. resize image to a size multiple of 32 which is required by the network
  107. args:
  108. img(array): array with shape [h, w, c]
  109. return(tuple):
  110. img, (ratio_h, ratio_w)
  111. """
  112. limit_side_len = limit_side_len or self.limit_side_len
  113. limit_type = limit_type or self.limit_type
  114. h, w, c = img.shape
  115. # limit the max side
  116. if limit_type == "max":
  117. if max(h, w) > limit_side_len:
  118. if h > w:
  119. ratio = float(limit_side_len) / h
  120. else:
  121. ratio = float(limit_side_len) / w
  122. else:
  123. ratio = 1.0
  124. elif limit_type == "min":
  125. if min(h, w) < limit_side_len:
  126. if h < w:
  127. ratio = float(limit_side_len) / h
  128. else:
  129. ratio = float(limit_side_len) / w
  130. else:
  131. ratio = 1.0
  132. elif limit_type == "resize_long":
  133. ratio = float(limit_side_len) / max(h, w)
  134. else:
  135. raise Exception("not support limit type, image ")
  136. resize_h = int(h * ratio)
  137. resize_w = int(w * ratio)
  138. resize_h = max(int(round(resize_h / 32) * 32), 32)
  139. resize_w = max(int(round(resize_w / 32) * 32), 32)
  140. if resize_h == h and resize_w == w:
  141. return img, [1.0, 1.0]
  142. try:
  143. if int(resize_w) <= 0 or int(resize_h) <= 0:
  144. return None, (None, None)
  145. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  146. except:
  147. logging.info(img.shape, resize_w, resize_h)
  148. raise
  149. ratio_h = resize_h / float(h)
  150. ratio_w = resize_w / float(w)
  151. return img, [ratio_h, ratio_w]
  152. def resize_image_type2(self, img):
  153. """resize image size"""
  154. h, w, _ = img.shape
  155. resize_w = w
  156. resize_h = h
  157. if resize_h > resize_w:
  158. ratio = float(self.resize_long) / resize_h
  159. else:
  160. ratio = float(self.resize_long) / resize_w
  161. resize_h = int(resize_h * ratio)
  162. resize_w = int(resize_w * ratio)
  163. max_stride = 128
  164. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  165. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  166. if resize_h == h and resize_w == w:
  167. return img, [1.0, 1.0]
  168. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  169. ratio_h = resize_h / float(h)
  170. ratio_w = resize_w / float(w)
  171. return img, [ratio_h, ratio_w]
  172. def resize_image_type3(self, img):
  173. """resize the image"""
  174. resize_c, resize_h, resize_w = self.input_shape # (c, h, w)
  175. ori_h, ori_w = img.shape[:2] # (h, w, c)
  176. if resize_h == ori_h and resize_w == ori_w:
  177. return img, [1.0, 1.0]
  178. ratio_h = float(resize_h) / ori_h
  179. ratio_w = float(resize_w) / ori_w
  180. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  181. return img, [ratio_h, ratio_w]
  182. @benchmark.timeit
  183. @class_requires_deps("opencv-contrib-python")
  184. class NormalizeImage:
  185. """normalize image such as substract mean, divide std"""
  186. def __init__(self, scale=None, mean=None, std=None, order="chw"):
  187. super().__init__()
  188. if isinstance(scale, str):
  189. scale = eval(scale)
  190. self.order = order
  191. scale = scale if scale is not None else 1.0 / 255.0
  192. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  193. std = std if std is not None else [0.229, 0.224, 0.225]
  194. self.alpha = [scale / std[i] for i in range(len(std))]
  195. self.beta = [-mean[i] / std[i] for i in range(len(std))]
  196. def __call__(self, imgs):
  197. """apply"""
  198. def _norm(img):
  199. if self.order == "chw":
  200. img = np.transpose(img, (2, 0, 1))
  201. split_im = list(cv2.split(img))
  202. for c in range(img.shape[2]):
  203. split_im[c] = split_im[c].astype(np.float32)
  204. split_im[c] *= self.alpha[c]
  205. split_im[c] += self.beta[c]
  206. res = cv2.merge(split_im)
  207. if self.order == "chw":
  208. res = np.transpose(res, (1, 2, 0))
  209. return res
  210. return [_norm(img) for img in imgs]
  211. @benchmark.timeit
  212. @class_requires_deps("opencv-contrib-python", "pyclipper")
  213. class DBPostProcess:
  214. """
  215. The post process for Differentiable Binarization (DB).
  216. """
  217. def __init__(
  218. self,
  219. thresh=0.3,
  220. box_thresh=0.7,
  221. max_candidates=1000,
  222. unclip_ratio=2.0,
  223. use_dilation=False,
  224. score_mode="fast",
  225. box_type="quad",
  226. **kwargs
  227. ):
  228. super().__init__()
  229. self.thresh = thresh
  230. self.box_thresh = box_thresh
  231. self.max_candidates = max_candidates
  232. self.unclip_ratio = unclip_ratio
  233. self.min_size = 3
  234. self.score_mode = score_mode
  235. self.box_type = box_type
  236. assert score_mode in [
  237. "slow",
  238. "fast",
  239. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  240. self.use_dilation = use_dilation
  241. def polygons_from_bitmap(
  242. self,
  243. pred,
  244. _bitmap,
  245. dest_width,
  246. dest_height,
  247. box_thresh,
  248. unclip_ratio,
  249. ):
  250. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  251. bitmap = _bitmap
  252. height, width = bitmap.shape
  253. width_scale = dest_width / width
  254. height_scale = dest_height / height
  255. boxes = []
  256. scores = []
  257. contours, _ = cv2.findContours(
  258. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  259. )
  260. for contour in contours[: self.max_candidates]:
  261. epsilon = 0.002 * cv2.arcLength(contour, True)
  262. approx = cv2.approxPolyDP(contour, epsilon, True)
  263. points = approx.reshape((-1, 2))
  264. if points.shape[0] < 4:
  265. continue
  266. score = self.box_score_fast(pred, points.reshape(-1, 2))
  267. if box_thresh > score:
  268. continue
  269. if points.shape[0] > 2:
  270. box = self.unclip(points, unclip_ratio)
  271. if len(box) > 1:
  272. continue
  273. else:
  274. continue
  275. box = box.reshape(-1, 2)
  276. if len(box) > 0:
  277. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  278. if sside < self.min_size + 2:
  279. continue
  280. else:
  281. continue
  282. box = np.array(box)
  283. for i in range(box.shape[0]):
  284. box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
  285. box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
  286. boxes.append(box)
  287. scores.append(score)
  288. return boxes, scores
  289. def boxes_from_bitmap(
  290. self,
  291. pred,
  292. _bitmap,
  293. dest_width,
  294. dest_height,
  295. box_thresh,
  296. unclip_ratio,
  297. ):
  298. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  299. bitmap = _bitmap
  300. height, width = bitmap.shape
  301. width_scale = dest_width / width
  302. height_scale = dest_height / height
  303. outs = cv2.findContours(
  304. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  305. )
  306. if len(outs) == 3:
  307. img, contours, _ = outs[0], outs[1], outs[2]
  308. elif len(outs) == 2:
  309. contours, _ = outs[0], outs[1]
  310. num_contours = min(len(contours), self.max_candidates)
  311. boxes = []
  312. scores = []
  313. for index in range(num_contours):
  314. contour = contours[index]
  315. points, sside = self.get_mini_boxes(contour)
  316. if sside < self.min_size:
  317. continue
  318. points = np.array(points)
  319. if self.score_mode == "fast":
  320. score = self.box_score_fast(pred, points.reshape(-1, 2))
  321. else:
  322. score = self.box_score_slow(pred, contour)
  323. if box_thresh > score:
  324. continue
  325. box = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
  326. box, sside = self.get_mini_boxes(box)
  327. if sside < self.min_size + 2:
  328. continue
  329. box = np.array(box)
  330. for i in range(box.shape[0]):
  331. box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
  332. box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
  333. boxes.append(box.astype(np.int16))
  334. scores.append(score)
  335. return np.array(boxes, dtype=np.int16), scores
  336. def unclip(self, box, unclip_ratio):
  337. """unclip"""
  338. area = cv2.contourArea(box)
  339. length = cv2.arcLength(box, True)
  340. distance = area * unclip_ratio / length
  341. offset = pyclipper.PyclipperOffset()
  342. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  343. try:
  344. expanded = np.array(offset.Execute(distance))
  345. except ValueError:
  346. expanded = np.array(offset.Execute(distance)[0])
  347. return expanded
  348. def get_mini_boxes(self, contour):
  349. """get mini boxes"""
  350. bounding_box = cv2.minAreaRect(contour)
  351. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  352. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  353. if points[1][1] > points[0][1]:
  354. index_1 = 0
  355. index_4 = 1
  356. else:
  357. index_1 = 1
  358. index_4 = 0
  359. if points[3][1] > points[2][1]:
  360. index_2 = 2
  361. index_3 = 3
  362. else:
  363. index_2 = 3
  364. index_3 = 2
  365. box = [points[index_1], points[index_2], points[index_3], points[index_4]]
  366. return box, min(bounding_box[1])
  367. def box_score_fast(self, bitmap, _box):
  368. """box_score_fast: use bbox mean score as the mean score"""
  369. h, w = bitmap.shape[:2]
  370. box = _box.copy()
  371. xmin = max(0, min(math.floor(box[:, 0].min()), w - 1))
  372. xmax = max(0, min(math.ceil(box[:, 0].max()), w - 1))
  373. ymin = max(0, min(math.floor(box[:, 1].min()), h - 1))
  374. ymax = max(0, min(math.ceil(box[:, 1].max()), h - 1))
  375. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  376. box[:, 0] = box[:, 0] - xmin
  377. box[:, 1] = box[:, 1] - ymin
  378. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  379. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  380. def box_score_slow(self, bitmap, contour):
  381. """box_score_slow: use polygon mean score as the mean score"""
  382. h, w = bitmap.shape[:2]
  383. contour = contour.copy()
  384. contour = np.reshape(contour, (-1, 2))
  385. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  386. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  387. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  388. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  389. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  390. contour[:, 0] = contour[:, 0] - xmin
  391. contour[:, 1] = contour[:, 1] - ymin
  392. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  393. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  394. def __call__(
  395. self,
  396. preds,
  397. img_shapes,
  398. thresh: Union[float, None] = None,
  399. box_thresh: Union[float, None] = None,
  400. unclip_ratio: Union[float, None] = None,
  401. ):
  402. """apply"""
  403. boxes, scores = [], []
  404. for pred, img_shape in zip(preds[0], img_shapes):
  405. box, score = self.process(
  406. pred,
  407. img_shape,
  408. thresh or self.thresh,
  409. box_thresh or self.box_thresh,
  410. unclip_ratio or self.unclip_ratio,
  411. )
  412. boxes.append(box)
  413. scores.append(score)
  414. return boxes, scores
  415. def process(
  416. self,
  417. pred,
  418. img_shape,
  419. thresh,
  420. box_thresh,
  421. unclip_ratio,
  422. ):
  423. pred = pred[0, :, :]
  424. segmentation = pred > thresh
  425. dilation_kernel = None if not self.use_dilation else np.array([[1, 1], [1, 1]])
  426. src_h, src_w, ratio_h, ratio_w = img_shape
  427. if dilation_kernel is not None:
  428. mask = cv2.dilate(
  429. np.array(segmentation).astype(np.uint8),
  430. dilation_kernel,
  431. )
  432. else:
  433. mask = segmentation
  434. if self.box_type == "poly":
  435. boxes, scores = self.polygons_from_bitmap(
  436. pred, mask, src_w, src_h, box_thresh, unclip_ratio
  437. )
  438. elif self.box_type == "quad":
  439. boxes, scores = self.boxes_from_bitmap(
  440. pred, mask, src_w, src_h, box_thresh, unclip_ratio
  441. )
  442. else:
  443. raise ValueError("box_type can only be one of ['quad', 'poly']")
  444. return boxes, scores