processors.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import Union
  16. import numpy as np
  17. from ....utils import logging
  18. from ....utils.deps import class_requires_deps, is_dep_available
  19. from ...utils.benchmark import benchmark
  20. if is_dep_available("opencv-contrib-python"):
  21. import cv2
  22. if is_dep_available("pyclipper"):
  23. import pyclipper
  24. @benchmark.timeit
  25. @class_requires_deps("opencv-contrib-python")
  26. class DetResizeForTest:
  27. """DetResizeForTest"""
  28. def __init__(self, input_shape=None, max_side_limit=4000, **kwargs):
  29. self.resize_type = 0
  30. self.keep_ratio = False
  31. if input_shape is not None:
  32. self.input_shape = input_shape
  33. self.resize_type = 3
  34. elif "image_shape" in kwargs:
  35. self.image_shape = kwargs["image_shape"]
  36. self.resize_type = 1
  37. if "keep_ratio" in kwargs:
  38. self.keep_ratio = kwargs["keep_ratio"]
  39. elif "limit_side_len" in kwargs:
  40. self.limit_side_len = kwargs["limit_side_len"]
  41. self.limit_type = kwargs.get("limit_type", "min")
  42. elif "resize_long" in kwargs:
  43. self.resize_type = 2
  44. self.resize_long = kwargs.get("resize_long", 960)
  45. else:
  46. self.limit_side_len = 736
  47. self.limit_type = "min"
  48. self.max_side_limit = max_side_limit
  49. def __call__(
  50. self,
  51. imgs,
  52. limit_side_len: Union[int, None] = None,
  53. limit_type: Union[str, None] = None,
  54. max_side_limit: Union[int, None] = None,
  55. ):
  56. """apply"""
  57. max_side_limit = (
  58. max_side_limit if max_side_limit is not None else self.max_side_limit
  59. )
  60. resize_imgs, img_shapes = [], []
  61. for ori_img in imgs:
  62. img, shape = self.resize(
  63. ori_img, limit_side_len, limit_type, max_side_limit
  64. )
  65. resize_imgs.append(img)
  66. img_shapes.append(shape)
  67. return resize_imgs, img_shapes
  68. def resize(
  69. self,
  70. img,
  71. limit_side_len: Union[int, None],
  72. limit_type: Union[str, None],
  73. max_side_limit: Union[int, None] = None,
  74. ):
  75. src_h, src_w, _ = img.shape
  76. if sum([src_h, src_w]) < 64:
  77. img = self.image_padding(img)
  78. if self.resize_type == 0:
  79. # img, shape = self.resize_image_type0(img)
  80. img, [ratio_h, ratio_w] = self.resize_image_type0(
  81. img, limit_side_len, limit_type, max_side_limit
  82. )
  83. elif self.resize_type == 2:
  84. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  85. elif self.resize_type == 3:
  86. img, [ratio_h, ratio_w] = self.resize_image_type3(img)
  87. else:
  88. # img, shape = self.resize_image_type1(img)
  89. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  90. return img, np.array([src_h, src_w, ratio_h, ratio_w])
  91. def image_padding(self, im, value=0):
  92. """padding image"""
  93. h, w, c = im.shape
  94. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  95. im_pad[:h, :w, :] = im
  96. return im_pad
  97. def resize_image_type1(self, img):
  98. """resize the image"""
  99. resize_h, resize_w = self.image_shape
  100. ori_h, ori_w = img.shape[:2] # (h, w, c)
  101. if self.keep_ratio is True:
  102. resize_w = ori_w * resize_h / ori_h
  103. N = math.ceil(resize_w / 32)
  104. resize_w = N * 32
  105. if resize_h == ori_h and resize_w == ori_w:
  106. return img, [1.0, 1.0]
  107. ratio_h = float(resize_h) / ori_h
  108. ratio_w = float(resize_w) / ori_w
  109. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  110. # return img, np.array([ori_h, ori_w])
  111. return img, [ratio_h, ratio_w]
  112. def resize_image_type0(
  113. self,
  114. img,
  115. limit_side_len: Union[int, None],
  116. limit_type: Union[str, None],
  117. max_side_limit: Union[int, None] = None,
  118. ):
  119. """
  120. resize image to a size multiple of 32 which is required by the network
  121. args:
  122. img(array): array with shape [h, w, c]
  123. return(tuple):
  124. img, (ratio_h, ratio_w)
  125. """
  126. limit_side_len = limit_side_len or self.limit_side_len
  127. limit_type = limit_type or self.limit_type
  128. h, w, c = img.shape
  129. # limit the max side
  130. if limit_type == "max":
  131. if max(h, w) > limit_side_len:
  132. if h > w:
  133. ratio = float(limit_side_len) / h
  134. else:
  135. ratio = float(limit_side_len) / w
  136. else:
  137. ratio = 1.0
  138. elif limit_type == "min":
  139. if min(h, w) < limit_side_len:
  140. if h < w:
  141. ratio = float(limit_side_len) / h
  142. else:
  143. ratio = float(limit_side_len) / w
  144. else:
  145. ratio = 1.0
  146. elif limit_type == "resize_long":
  147. ratio = float(limit_side_len) / max(h, w)
  148. else:
  149. raise Exception("not support limit type, image ")
  150. resize_h = int(h * ratio)
  151. resize_w = int(w * ratio)
  152. if max(resize_h, resize_w) > max_side_limit:
  153. logging.warning(
  154. f"Resized image size ({resize_h}x{resize_w}) exceeds max_side_limit of {max_side_limit}. "
  155. f"Resizing to fit within limit."
  156. )
  157. ratio = float(max_side_limit) / max(resize_h, resize_w)
  158. resize_h, resize_w = int(resize_h * ratio), int(resize_w * ratio)
  159. resize_h = max(int(round(resize_h / 32) * 32), 32)
  160. resize_w = max(int(round(resize_w / 32) * 32), 32)
  161. if resize_h == h and resize_w == w:
  162. return img, [1.0, 1.0]
  163. try:
  164. if int(resize_w) <= 0 or int(resize_h) <= 0:
  165. return None, (None, None)
  166. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  167. except:
  168. logging.info(img.shape, resize_w, resize_h)
  169. raise
  170. ratio_h = resize_h / float(h)
  171. ratio_w = resize_w / float(w)
  172. return img, [ratio_h, ratio_w]
  173. def resize_image_type2(self, img):
  174. """resize image size"""
  175. h, w, _ = img.shape
  176. resize_w = w
  177. resize_h = h
  178. if resize_h > resize_w:
  179. ratio = float(self.resize_long) / resize_h
  180. else:
  181. ratio = float(self.resize_long) / resize_w
  182. resize_h = int(resize_h * ratio)
  183. resize_w = int(resize_w * ratio)
  184. max_stride = 128
  185. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  186. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  187. if resize_h == h and resize_w == w:
  188. return img, [1.0, 1.0]
  189. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  190. ratio_h = resize_h / float(h)
  191. ratio_w = resize_w / float(w)
  192. return img, [ratio_h, ratio_w]
  193. def resize_image_type3(self, img):
  194. """resize the image"""
  195. resize_c, resize_h, resize_w = self.input_shape # (c, h, w)
  196. ori_h, ori_w = img.shape[:2] # (h, w, c)
  197. if resize_h == ori_h and resize_w == ori_w:
  198. return img, [1.0, 1.0]
  199. ratio_h = float(resize_h) / ori_h
  200. ratio_w = float(resize_w) / ori_w
  201. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  202. return img, [ratio_h, ratio_w]
  203. @benchmark.timeit
  204. @class_requires_deps("opencv-contrib-python")
  205. class NormalizeImage:
  206. """normalize image such as subtract mean, divide std"""
  207. def __init__(self, scale=None, mean=None, std=None, order="chw"):
  208. super().__init__()
  209. if isinstance(scale, str):
  210. scale = eval(scale)
  211. self.order = order
  212. scale = scale if scale is not None else 1.0 / 255.0
  213. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  214. std = std if std is not None else [0.229, 0.224, 0.225]
  215. self.alpha = [scale / std[i] for i in range(len(std))]
  216. self.beta = [-mean[i] / std[i] for i in range(len(std))]
  217. def __call__(self, imgs):
  218. """apply"""
  219. def _norm(img):
  220. if self.order == "chw":
  221. img = np.transpose(img, (2, 0, 1))
  222. split_im = list(cv2.split(img))
  223. for c in range(img.shape[2]):
  224. split_im[c] = split_im[c].astype(np.float32)
  225. split_im[c] *= self.alpha[c]
  226. split_im[c] += self.beta[c]
  227. res = cv2.merge(split_im)
  228. if self.order == "chw":
  229. res = np.transpose(res, (1, 2, 0))
  230. return res
  231. return [_norm(img) for img in imgs]
  232. @benchmark.timeit
  233. @class_requires_deps("opencv-contrib-python", "pyclipper")
  234. class DBPostProcess:
  235. """
  236. The post process for Differentiable Binarization (DB).
  237. """
  238. def __init__(
  239. self,
  240. thresh=0.3,
  241. box_thresh=0.7,
  242. max_candidates=1000,
  243. unclip_ratio=2.0,
  244. use_dilation=False,
  245. score_mode="fast",
  246. box_type="quad",
  247. **kwargs,
  248. ):
  249. super().__init__()
  250. self.thresh = thresh
  251. self.box_thresh = box_thresh
  252. self.max_candidates = max_candidates
  253. self.unclip_ratio = unclip_ratio
  254. self.min_size = 3
  255. self.score_mode = score_mode
  256. self.box_type = box_type
  257. assert score_mode in [
  258. "slow",
  259. "fast",
  260. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  261. self.use_dilation = use_dilation
  262. def polygons_from_bitmap(
  263. self,
  264. pred,
  265. _bitmap,
  266. dest_width,
  267. dest_height,
  268. box_thresh,
  269. unclip_ratio,
  270. ):
  271. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  272. bitmap = _bitmap
  273. height, width = bitmap.shape
  274. width_scale = dest_width / width
  275. height_scale = dest_height / height
  276. boxes = []
  277. scores = []
  278. contours, _ = cv2.findContours(
  279. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  280. )
  281. for contour in contours[: self.max_candidates]:
  282. epsilon = 0.002 * cv2.arcLength(contour, True)
  283. approx = cv2.approxPolyDP(contour, epsilon, True)
  284. points = approx.reshape((-1, 2))
  285. if points.shape[0] < 4:
  286. continue
  287. score = self.box_score_fast(pred, points.reshape(-1, 2))
  288. if box_thresh > score:
  289. continue
  290. if points.shape[0] > 2:
  291. box = self.unclip(points, unclip_ratio)
  292. if len(box) > 1:
  293. continue
  294. else:
  295. continue
  296. box = box.reshape(-1, 2)
  297. if len(box) > 0:
  298. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  299. if sside < self.min_size + 2:
  300. continue
  301. else:
  302. continue
  303. box = np.array(box)
  304. for i in range(box.shape[0]):
  305. box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
  306. box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
  307. boxes.append(box)
  308. scores.append(score)
  309. return boxes, scores
  310. def boxes_from_bitmap(
  311. self,
  312. pred,
  313. _bitmap,
  314. dest_width,
  315. dest_height,
  316. box_thresh,
  317. unclip_ratio,
  318. ):
  319. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  320. bitmap = _bitmap
  321. height, width = bitmap.shape
  322. width_scale = dest_width / width
  323. height_scale = dest_height / height
  324. outs = cv2.findContours(
  325. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  326. )
  327. if len(outs) == 3:
  328. img, contours, _ = outs[0], outs[1], outs[2]
  329. elif len(outs) == 2:
  330. contours, _ = outs[0], outs[1]
  331. num_contours = min(len(contours), self.max_candidates)
  332. boxes = []
  333. scores = []
  334. for index in range(num_contours):
  335. contour = contours[index]
  336. points, sside = self.get_mini_boxes(contour)
  337. if sside < self.min_size:
  338. continue
  339. points = np.array(points)
  340. if self.score_mode == "fast":
  341. score = self.box_score_fast(pred, points.reshape(-1, 2))
  342. else:
  343. score = self.box_score_slow(pred, contour)
  344. if box_thresh > score:
  345. continue
  346. box = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
  347. box, sside = self.get_mini_boxes(box)
  348. if sside < self.min_size + 2:
  349. continue
  350. box = np.array(box)
  351. for i in range(box.shape[0]):
  352. box[i, 0] = max(0, min(round(box[i, 0] * width_scale), dest_width))
  353. box[i, 1] = max(0, min(round(box[i, 1] * height_scale), dest_height))
  354. boxes.append(box.astype(np.int16))
  355. scores.append(score)
  356. return np.array(boxes, dtype=np.int16), scores
  357. def unclip(self, box, unclip_ratio):
  358. """unclip"""
  359. area = cv2.contourArea(box)
  360. length = cv2.arcLength(box, True)
  361. distance = area * unclip_ratio / length
  362. offset = pyclipper.PyclipperOffset()
  363. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  364. try:
  365. expanded = np.array(offset.Execute(distance))
  366. except ValueError:
  367. expanded = np.array(offset.Execute(distance)[0])
  368. return expanded
  369. def get_mini_boxes(self, contour):
  370. """get mini boxes"""
  371. bounding_box = cv2.minAreaRect(contour)
  372. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  373. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  374. if points[1][1] > points[0][1]:
  375. index_1 = 0
  376. index_4 = 1
  377. else:
  378. index_1 = 1
  379. index_4 = 0
  380. if points[3][1] > points[2][1]:
  381. index_2 = 2
  382. index_3 = 3
  383. else:
  384. index_2 = 3
  385. index_3 = 2
  386. box = [points[index_1], points[index_2], points[index_3], points[index_4]]
  387. return box, min(bounding_box[1])
  388. def box_score_fast(self, bitmap, _box):
  389. """box_score_fast: use bbox mean score as the mean score"""
  390. h, w = bitmap.shape[:2]
  391. box = _box.copy()
  392. xmin = max(0, min(math.floor(box[:, 0].min()), w - 1))
  393. xmax = max(0, min(math.ceil(box[:, 0].max()), w - 1))
  394. ymin = max(0, min(math.floor(box[:, 1].min()), h - 1))
  395. ymax = max(0, min(math.ceil(box[:, 1].max()), h - 1))
  396. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  397. box[:, 0] = box[:, 0] - xmin
  398. box[:, 1] = box[:, 1] - ymin
  399. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  400. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  401. def box_score_slow(self, bitmap, contour):
  402. """box_score_slow: use polygon mean score as the mean score"""
  403. h, w = bitmap.shape[:2]
  404. contour = contour.copy()
  405. contour = np.reshape(contour, (-1, 2))
  406. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  407. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  408. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  409. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  410. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  411. contour[:, 0] = contour[:, 0] - xmin
  412. contour[:, 1] = contour[:, 1] - ymin
  413. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  414. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  415. def __call__(
  416. self,
  417. preds,
  418. img_shapes,
  419. thresh: Union[float, None] = None,
  420. box_thresh: Union[float, None] = None,
  421. unclip_ratio: Union[float, None] = None,
  422. ):
  423. """apply"""
  424. boxes, scores = [], []
  425. for pred, img_shape in zip(preds[0], img_shapes):
  426. box, score = self.process(
  427. pred,
  428. img_shape,
  429. thresh or self.thresh,
  430. box_thresh or self.box_thresh,
  431. unclip_ratio or self.unclip_ratio,
  432. )
  433. boxes.append(box)
  434. scores.append(score)
  435. return boxes, scores
  436. def process(
  437. self,
  438. pred,
  439. img_shape,
  440. thresh,
  441. box_thresh,
  442. unclip_ratio,
  443. ):
  444. pred = pred[0, :, :]
  445. segmentation = pred > thresh
  446. dilation_kernel = None if not self.use_dilation else np.array([[1, 1], [1, 1]])
  447. src_h, src_w, ratio_h, ratio_w = img_shape
  448. if dilation_kernel is not None:
  449. mask = cv2.dilate(
  450. np.array(segmentation).astype(np.uint8),
  451. dilation_kernel,
  452. )
  453. else:
  454. mask = segmentation
  455. if self.box_type == "poly":
  456. boxes, scores = self.polygons_from_bitmap(
  457. pred, mask, src_w, src_h, box_thresh, unclip_ratio
  458. )
  459. elif self.box_type == "quad":
  460. boxes, scores = self.boxes_from_bitmap(
  461. pred, mask, src_w, src_h, box_thresh, unclip_ratio
  462. )
  463. else:
  464. raise ValueError("box_type can only be one of ['quad', 'poly']")
  465. return boxes, scores