# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import sys
from typing import Union

import numpy as np

from ....utils import logging
from ....utils.deps import class_requires_deps, is_dep_available
from ...utils.benchmark import benchmark

if is_dep_available("opencv-contrib-python"):
    import cv2
if is_dep_available("pyclipper"):
    import pyclipper


@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class DetResizeForTest:
    """DetResizeForTest"""

    def __init__(self, input_shape=None, **kwargs):
        super().__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if input_shape is not None:
            self.input_shape = input_shape
            self.resize_type = 3
        elif "image_shape" in kwargs:
            self.image_shape = kwargs["image_shape"]
            self.resize_type = 1
            if "keep_ratio" in kwargs:
                self.keep_ratio = kwargs["keep_ratio"]
        elif "limit_side_len" in kwargs:
            self.limit_side_len = kwargs["limit_side_len"]
            self.limit_type = kwargs.get("limit_type", "min")
        elif "resize_long" in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get("resize_long", 960)
        else:
            self.limit_side_len = 736
            self.limit_type = "min"

    def __call__(
        self,
        imgs,
        limit_side_len: Union[int, None] = None,
        limit_type: Union[str, None] = None,
    ):
        """apply"""
        resize_imgs, img_shapes = [], []
        for ori_img in imgs:
            img, shape = self.resize(ori_img, limit_side_len, limit_type)
            resize_imgs.append(img)
            img_shapes.append(shape)
        return resize_imgs, img_shapes

    def resize(
        self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
    ):
        src_h, src_w, _ = img.shape
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            # img, shape = self.resize_image_type0(img)
            img, [ratio_h, ratio_w] = self.resize_image_type0(
                img, limit_side_len, limit_type
            )
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        elif self.resize_type == 3:
            img, [ratio_h, ratio_w] = self.resize_image_type3(img)
        else:
            # img, shape = self.resize_image_type1(img)
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        return img, np.array([src_h, src_w, ratio_h, ratio_w])

    def image_padding(self, im, value=0):
        """padding image"""
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """resize the image"""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            N = math.ceil(resize_w / 32)
            resize_w = N * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        # return img, np.array([ori_h, ori_w])
        return img, [ratio_h, ratio_w]

    def resize_image_type0(
        self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
    ):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = limit_side_len or self.limit_side_len
        limit_type = limit_type or self.limit_type
        h, w, c = img.shape

        # limit the max side
        if limit_type == "max":
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif limit_type == "min":
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif limit_type == "resize_long":
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise ValueError("Unsupported limit type: {}".format(limit_type))
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            logging.info(
                "Failed to resize image: shape={}, resize_w={}, resize_h={}".format(
                    img.shape, resize_w, resize_h
                )
            )
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """resize image size"""
        h, w, _ = img.shape
        resize_w = w
        resize_h = h
        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type3(self, img):
        """resize the image"""
        resize_c, resize_h, resize_w = self.input_shape  # (c, h, w)
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]
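
# Usage sketch (illustrative only; `image` below is assumed to be an HWC uint8
# numpy array and is not defined in this module). Constructing the class with
# limit_side_len/limit_type keeps resize_type 0, which scales the image and
# rounds both sides to multiples of 32 before detection:
#
#     resizer = DetResizeForTest(limit_side_len=960, limit_type="max")
#     resized_imgs, img_shapes = resizer([image])
#     # img_shapes[0] == np.array([src_h, src_w, ratio_h, ratio_w])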


@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class NormalizeImage:
    """normalize image, e.g. subtract mean and divide by std"""

    def __init__(self, scale=None, mean=None, std=None, order="chw"):
        super().__init__()
        if isinstance(scale, str):
            scale = eval(scale)
        self.order = order
        scale = scale if scale is not None else 1.0 / 255.0
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]
        self.alpha = [scale / std[i] for i in range(len(std))]
        self.beta = [-mean[i] / std[i] for i in range(len(std))]

    def __call__(self, imgs):
        """apply"""

        def _norm(img):
            if self.order == "chw":
                img = np.transpose(img, (2, 0, 1))

            split_im = list(cv2.split(img))
            for c in range(img.shape[2]):
                split_im[c] = split_im[c].astype(np.float32)
                split_im[c] *= self.alpha[c]
                split_im[c] += self.beta[c]

            res = cv2.merge(split_im)
            if self.order == "chw":
                res = np.transpose(res, (1, 2, 0))
            return res

        return [_norm(img) for img in imgs]
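
# Usage sketch (illustrative only; `resized_imgs` refers to the sketch above and
# is not defined in this module). NormalizeImage folds scale/mean/std into
# per-channel affine coefficients so that alpha * pixel + beta equals
# (pixel * scale - mean) / std; order="hwc" keeps the HWC layout through the
# per-channel pass:
#
#     normalizer = NormalizeImage(
#         scale=1.0 / 255.0,
#         mean=[0.485, 0.456, 0.406],
#         std=[0.229, 0.224, 0.225],
#         order="hwc",
#     )
#     normed_imgs = normalizer(resized_imgs)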


@benchmark.timeit
@class_requires_deps("opencv-contrib-python", "pyclipper")
class DBPostProcess:
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(
        self,
        thresh=0.3,
        box_thresh=0.7,
        max_candidates=1000,
        unclip_ratio=2.0,
        use_dilation=False,
        score_mode="fast",
        box_type="quad",
        **kwargs
    ):
        super().__init__()
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow",
            "fast",
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
        self.use_dilation = use_dilation

    def polygons_from_bitmap(
        self,
        pred,
        _bitmap,
        dest_width,
        dest_height,
        box_thresh,
        unclip_ratio,
    ):
  243. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
        bitmap = _bitmap
        height, width = bitmap.shape
        width_scale = dest_width / width
        height_scale = dest_height / height

        boxes = []
        scores = []

        contours, _ = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours[: self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if box_thresh > score:
                continue

            if points.shape[0] > 2:
                box = self.unclip(points, unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)

            if len(box) > 0:
                _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
                if sside < self.min_size + 2:
                    continue
            else:
                continue

            box = np.array(box)
            for i in range(box.shape[0]):
                box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
                box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
            boxes.append(box)
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(
        self,
        pred,
        _bitmap,
        dest_width,
        dest_height,
        box_thresh,
        unclip_ratio,
    ):
  291. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
        bitmap = _bitmap
        height, width = bitmap.shape
        width_scale = dest_width / width
        height_scale = dest_height / height

        outs = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(outs) == 3:
            # OpenCV 3.x returns (image, contours, hierarchy)
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            # OpenCV 4.x returns (contours, hierarchy)
            contours, _ = outs[0], outs[1]
        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if box_thresh > score:
                continue

            box = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            for i in range(box.shape[0]):
                box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
                box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box, unclip_ratio):
        """unclip"""
        # DB expansion: offset distance D = A * r / L, where A is the box area,
        # L its perimeter, and r the unclip ratio.
        area = cv2.contourArea(box)
        length = cv2.arcLength(box, True)
        distance = area * unclip_ratio / length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        try:
            expanded = np.array(offset.Execute(distance))
        except ValueError:
            expanded = np.array(offset.Execute(distance)[0])
        return expanded

    def get_mini_boxes(self, contour):
        """get mini boxes"""
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """box_score_fast: use the mean score within the bounding box as the box score"""
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = max(0, min(math.floor(box[:, 0].min()), w - 1))
        xmax = max(0, min(math.ceil(box[:, 0].max()), w - 1))
        ymin = max(0, min(math.floor(box[:, 1].min()), h - 1))
        ymax = max(0, min(math.ceil(box[:, 1].max()), h - 1))

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        """box_score_slow: use the mean score within the polygon as the box score"""
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin
        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def __call__(
        self,
        preds,
        img_shapes,
        thresh: Union[float, None] = None,
        box_thresh: Union[float, None] = None,
        unclip_ratio: Union[float, None] = None,
    ):
        """apply"""
        boxes, scores = [], []
        for pred, img_shape in zip(preds[0], img_shapes):
            box, score = self.process(
                pred,
                img_shape,
                thresh or self.thresh,
                box_thresh or self.box_thresh,
                unclip_ratio or self.unclip_ratio,
            )
            boxes.append(box)
            scores.append(score)
        return boxes, scores

    def process(
        self,
        pred,
        img_shape,
        thresh,
        box_thresh,
        unclip_ratio,
    ):
        pred = pred[0, :, :]
        segmentation = pred > thresh

        dilation_kernel = None if not self.use_dilation else np.array([[1, 1], [1, 1]])
        src_h, src_w, ratio_h, ratio_w = img_shape
        if dilation_kernel is not None:
            mask = cv2.dilate(
                np.array(segmentation).astype(np.uint8),
                dilation_kernel,
            )
        else:
            mask = segmentation
        if self.box_type == "poly":
            boxes, scores = self.polygons_from_bitmap(
                pred, mask, src_w, src_h, box_thresh, unclip_ratio
            )
        elif self.box_type == "quad":
            boxes, scores = self.boxes_from_bitmap(
                pred, mask, src_w, src_h, box_thresh, unclip_ratio
            )
        else:
            raise ValueError("box_type can only be one of ['quad', 'poly']")
        return boxes, scores
  437. return boxes, scores