processors.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Tuple, Union
  15. import os
  16. import sys
  17. import cv2
  18. import copy
  19. import math
  20. import pyclipper
  21. import numpy as np
  22. from numpy.linalg import norm
  23. from PIL import Image
  24. from shapely.geometry import Polygon
  25. from ...utils.io import ImageReader
  26. from ....utils import logging
  class DetResizeForTest:
      """Resize images before they are fed to a text-detection network.

      The resize strategy is chosen from the keyword arguments given at
      construction time (checked in this order):

      * ``image_shape`` (optionally ``keep_ratio``): resize to a fixed
        ``(height, width)``; stored as ``resize_type == 1``.
      * ``limit_side_len`` (optionally ``limit_type``, default ``"min"``):
        bound one side and round both sides to a multiple of 32;
        ``resize_type == 0``.
      * ``resize_long``: scale the longer side to that length and round both
        sides up to a multiple of 128; ``resize_type == 2``.
      * none of the above: ``resize_type == 0`` with ``limit_side_len=736`` and
        ``limit_type="min"``.
      """

      def __init__(self, **kwargs):
          super().__init__()
          self.resize_type = 0
          self.keep_ratio = False
          if "image_shape" in kwargs:
              self.image_shape = kwargs["image_shape"]
              self.resize_type = 1
              if "keep_ratio" in kwargs:
                  self.keep_ratio = kwargs["keep_ratio"]
          elif "limit_side_len" in kwargs:
              self.limit_side_len = kwargs["limit_side_len"]
              self.limit_type = kwargs.get("limit_type", "min")
          elif "resize_long" in kwargs:
              self.resize_type = 2
              self.resize_long = kwargs.get("resize_long", 960)
          else:
              self.limit_side_len = 736
              self.limit_type = "min"

      def __call__(
          self,
          imgs,
          limit_side_len: Union[int, None] = None,
          limit_type: Union[str, None] = None,
      ):
          """Resize every image in ``imgs``.

          Args:
              imgs: iterable of images, each an array of shape ``[h, w, c]``.
              limit_side_len: optional per-call override of the configured
                  side limit (only used by the ``resize_type == 0`` strategy).
              limit_type: optional per-call override of the configured limit
                  type (only used by the ``resize_type == 0`` strategy).

          Returns:
              ``(resize_imgs, img_shapes)``: two lists of equal length holding
              the resized images and, for each, the array
              ``[src_h, src_w, ratio_h, ratio_w]`` produced by :meth:`resize`.
          """
          resize_imgs, img_shapes = [], []
          for ori_img in imgs:
              img, shape = self.resize(ori_img, limit_side_len, limit_type)
              resize_imgs.append(img)
              img_shapes.append(shape)
          return resize_imgs, img_shapes

      def resize(
          self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
      ):
          """Resize one image with the strategy selected at construction.

          Returns:
              ``(img, shape_info)`` where ``shape_info`` is
              ``np.array([src_h, src_w, ratio_h, ratio_w])``; ``src_h``/``src_w``
              are the input sizes measured before any padding is applied.
          """
          src_h, src_w, _ = img.shape
          # Very small inputs are zero-padded up to at least 32x32 first.
          if sum([src_h, src_w]) < 64:
              img = self.image_padding(img)

          if self.resize_type == 0:
              img, [ratio_h, ratio_w] = self.resize_image_type0(
                  img, limit_side_len, limit_type
              )
          elif self.resize_type == 2:
              img, [ratio_h, ratio_w] = self.resize_image_type2(img)
          else:
              img, [ratio_h, ratio_w] = self.resize_image_type1(img)
          return img, np.array([src_h, src_w, ratio_h, ratio_w])

      def image_padding(self, im, value=0):
          """Pad ``im`` on the bottom/right with ``value`` to at least 32x32.

          The result is a ``uint8``-based array of shape
          ``(max(32, h), max(32, w), c)`` with the input copied into the
          top-left corner.
          """
          h, w, c = im.shape
          im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
          im_pad[:h, :w, :] = im
          return im_pad

      def resize_image_type1(self, img):
          """Resize to the fixed ``self.image_shape`` (``(height, width)``).

          When ``self.keep_ratio`` is ``True`` the height is fixed and the width
          follows the aspect ratio, rounded up to a multiple of 32.

          Returns:
              ``(img, [ratio_h, ratio_w])``.
          """
          resize_h, resize_w = self.image_shape
          ori_h, ori_w = img.shape[:2]  # (h, w, c)
          if self.keep_ratio is True:
              resize_w = ori_w * resize_h / ori_h
              resize_w = math.ceil(resize_w / 32) * 32
          ratio_h = float(resize_h) / ori_h
          ratio_w = float(resize_w) / ori_w
          img = cv2.resize(img, (int(resize_w), int(resize_h)))
          return img, [ratio_h, ratio_w]

      def resize_image_type0(
          self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
      ):
          """Resize so both sides are multiples of 32, as the network requires.

          Args:
              img: array with shape ``[h, w, c]``.
              limit_side_len: side limit; a falsy value falls back to
                  ``self.limit_side_len``.
              limit_type: one of ``"max"`` (shrink if the longer side exceeds
                  the limit), ``"min"`` (enlarge if the shorter side is below
                  the limit) or ``"resize_long"`` (always scale the longer side
                  to the limit); a falsy value falls back to
                  ``self.limit_type``.

          Returns:
              ``(img, [ratio_h, ratio_w])``.

          Raises:
              ValueError: if ``limit_type`` is not one of the supported values.
              Exception: whatever ``cv2.resize`` raises is logged and re-raised.
          """
          limit_side_len = limit_side_len or self.limit_side_len
          limit_type = limit_type or self.limit_type
          h, w = img.shape[:2]

          if limit_type == "max":
              longest = max(h, w)
              ratio = float(limit_side_len) / longest if longest > limit_side_len else 1.0
          elif limit_type == "min":
              shortest = min(h, w)
              ratio = float(limit_side_len) / shortest if shortest < limit_side_len else 1.0
          elif limit_type == "resize_long":
              ratio = float(limit_side_len) / max(h, w)
          else:
              raise ValueError(f"not support limit type: {limit_type!r}")

          resize_h = int(h * ratio)
          resize_w = int(w * ratio)
          # Round to the nearest multiple of 32, never going below 32, so the
          # target size is always strictly positive.
          resize_h = max(int(round(resize_h / 32) * 32), 32)
          resize_w = max(int(round(resize_w / 32) * 32), 32)

          try:
              img = cv2.resize(img, (resize_w, resize_h))
          except Exception:
              # Report the failing geometry and let the caller decide what to
              # do; a library must not terminate the interpreter.
              logging.error(
                  f"Failed to resize image of shape {img.shape} "
                  f"to (w={resize_w}, h={resize_h})"
              )
              raise
          ratio_h = resize_h / float(h)
          ratio_w = resize_w / float(w)
          return img, [ratio_h, ratio_w]

      def resize_image_type2(self, img):
          """Scale the longer side to ``self.resize_long``.

          Both target sides are then rounded up to a multiple of 128.

          Returns:
              ``(img, [ratio_h, ratio_w])``.
          """
          h, w, _ = img.shape
          ratio = float(self.resize_long) / max(h, w)
          resize_h = int(h * ratio)
          resize_w = int(w * ratio)

          max_stride = 128
          resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
          resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
          img = cv2.resize(img, (int(resize_w), int(resize_h)))
          ratio_h = resize_h / float(h)
          ratio_w = resize_w / float(w)
          return img, [ratio_h, ratio_w]
  class NormalizeImage:
      """Normalize images: multiply by a scale, subtract the mean, divide by std."""

      def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
          """Store the normalization constants as ``float32`` values.

          Args:
              scale: multiplier applied to the pixels first; a string is
                  evaluated as a Python expression, ``None`` means ``1/255``.
              mean: per-channel mean; ``None`` means ``[0.485, 0.456, 0.406]``.
              std: per-channel std; ``None`` means ``[0.229, 0.224, 0.225]``.
              order: ``"chw"`` shapes the statistics as ``(3, 1, 1)``; any
                  other value shapes them as ``(1, 1, 3)``.
              **kwargs: accepted and ignored.
          """
          super().__init__()
          # NOTE(review): a string scale (e.g. "1./255.") goes through eval();
          # make sure it only ever comes from trusted configuration.
          if isinstance(scale, str):
              scale = eval(scale)
          if scale is None:
              scale = 1.0 / 255.0
          self.scale = np.float32(scale)

          if mean is None:
              mean = [0.485, 0.456, 0.406]
          if std is None:
              std = [0.229, 0.224, 0.225]

          if order == "chw":
              stat_shape = (3, 1, 1)
          else:
              stat_shape = (1, 1, 3)
          self.mean = np.array(mean).reshape(stat_shape).astype("float32")
          self.std = np.array(std).reshape(stat_shape).astype("float32")

      def __call__(self, imgs):
          """Return a list with the normalized ``float32`` copy of each image."""
          normalized = []
          for image in imgs:
              scaled = image.astype("float32") * self.scale
              normalized.append((scaled - self.mean) / self.std)
          return normalized
  class DBPostProcess:
      """Post-processing for Differentiable Binarization (DB) text detection.

      Thresholds a prediction map, extracts contours and turns them into text
      regions scaled to the source image: 4-point boxes when
      ``box_type == "quad"``, free-form polygons when ``box_type == "poly"``.
      """

      def __init__(
          self,
          thresh=0.3,
          box_thresh=0.7,
          max_candidates=1000,
          unclip_ratio=2.0,
          use_dilation=False,
          score_mode="fast",
          box_type="quad",
          **kwargs
      ):
          """Store the default post-processing parameters.

          Args:
              thresh: binarization threshold applied to the prediction map.
              box_thresh: minimum mean score for a region to be kept.
              max_candidates: maximum number of contours examined per image.
              unclip_ratio: expansion factor used by :meth:`unclip`.
              use_dilation: dilate the binary map with a 2x2 kernel first.
              score_mode: ``"fast"`` or ``"slow"`` (see the ``box_score_*``
                  methods).
              box_type: ``"quad"`` or ``"poly"``; validated when processing.
              **kwargs: accepted and ignored.
          """
          super().__init__()
          self.thresh = thresh
          self.box_thresh = box_thresh
          self.max_candidates = max_candidates
          self.unclip_ratio = unclip_ratio
          self.min_size = 3
          self.score_mode = score_mode
          self.box_type = box_type
          assert score_mode in [
              "slow",
              "fast",
          ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
          self.use_dilation = use_dilation

      @staticmethod
      def _find_contours(bitmap):
          """Return the contour list of a binary map.

          ``cv2.findContours`` returns either ``(contours, hierarchy)`` or
          ``(image, contours, hierarchy)`` depending on the OpenCV version; the
          contours are the second-to-last element in both layouts.
          """
          outs = cv2.findContours(
              (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
          )
          return outs[-2]

      @staticmethod
      def _rescale_box(box, width, height, dest_width, dest_height):
          """Map points from bitmap coordinates to the destination image size.

          Coordinates are rounded and clipped to ``[0, dest_width]`` and
          ``[0, dest_height]``; a new array is returned.
          """
          box = np.array(box)
          box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
          box[:, 1] = np.clip(
              np.round(box[:, 1] / height * dest_height), 0, dest_height
          )
          return box

      def polygons_from_bitmap(
          self,
          pred,
          _bitmap,
          dest_width,
          dest_height,
          box_thresh,
          max_candidates,
          unclip_ratio,
      ):
          """Extract polygons from a binary map.

          Args:
              pred: score map, indexed as ``[y, x]``.
              _bitmap: binary map with shape ``(H, W)`` and values in {0, 1}.
              dest_width: width of the image the polygons are scaled to.
              dest_height: height of the image the polygons are scaled to.
              box_thresh: minimum mean score for a polygon to be kept.
              max_candidates: maximum number of contours examined.
              unclip_ratio: expansion factor passed to :meth:`unclip`.

          Returns:
              ``(boxes, scores)``: a list of ``(N, 2)`` point arrays and the
              list of their scores (always computed with
              :meth:`box_score_fast`).
          """
          height, width = _bitmap.shape
          boxes = []
          scores = []

          contours = self._find_contours(_bitmap)
          for contour in contours[:max_candidates]:
              epsilon = 0.002 * cv2.arcLength(contour, True)
              approx = cv2.approxPolyDP(contour, epsilon, True)
              points = approx.reshape((-1, 2))
              if points.shape[0] < 4:
                  continue

              score = self.box_score_fast(pred, points)
              if box_thresh > score:
                  continue

              box = self.unclip(points, unclip_ratio)
              # Skip expansions that do not yield exactly one usable path.
              if len(box) > 1:
                  continue
              box = box.reshape(-1, 2)
              if len(box) == 0:
                  continue

              _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
              if sside < self.min_size + 2:
                  continue

              boxes.append(
                  self._rescale_box(box, width, height, dest_width, dest_height)
              )
              scores.append(score)
          return boxes, scores

      def boxes_from_bitmap(
          self,
          pred,
          _bitmap,
          dest_width,
          dest_height,
          box_thresh,
          max_candidates,
          unclip_ratio,
      ):
          """Extract 4-point boxes from a binary map.

          Args:
              pred: score map, indexed as ``[y, x]``.
              _bitmap: binary map with shape ``(H, W)`` and values in {0, 1}.
              dest_width: width of the image the boxes are scaled to.
              dest_height: height of the image the boxes are scaled to.
              box_thresh: minimum mean score for a box to be kept.
              max_candidates: maximum number of contours examined.
              unclip_ratio: expansion factor passed to :meth:`unclip`.

          Returns:
              ``(boxes, scores)``: an ``int16`` array built from the kept
              ``(4, 2)`` boxes and the list of their scores.
          """
          height, width = _bitmap.shape
          contours = self._find_contours(_bitmap)
          num_contours = min(len(contours), max_candidates)

          boxes = []
          scores = []
          for index in range(num_contours):
              contour = contours[index]
              points, sside = self.get_mini_boxes(contour)
              if sside < self.min_size:
                  continue
              points = np.array(points)

              if self.score_mode == "fast":
                  score = self.box_score_fast(pred, points.reshape(-1, 2))
              else:
                  score = self.box_score_slow(pred, contour)
              if box_thresh > score:
                  continue

              expanded = self.unclip(points, unclip_ratio)
              if expanded.size == 0:
                  # Nothing left after offsetting; there is no box to measure.
                  continue
              box, sside = self.get_mini_boxes(expanded.reshape(-1, 1, 2))
              if sside < self.min_size + 2:
                  continue

              box = self._rescale_box(box, width, height, dest_width, dest_height)
              boxes.append(box.astype(np.int16))
              scores.append(score)
          return np.array(boxes, dtype=np.int16), scores

      def unclip(self, box, unclip_ratio):
          """Expand ``box`` outwards by ``area * unclip_ratio / perimeter``.

          Returns:
              The offset path(s) as an array. When the paths cannot be stacked
              into one array (``ValueError``), only the first path is returned.
          """
          poly = Polygon(box)
          distance = poly.area * unclip_ratio / poly.length
          offset = pyclipper.PyclipperOffset()
          offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
          paths = offset.Execute(distance)
          try:
              expanded = np.array(paths)
          except ValueError:
              expanded = np.array(paths[0])
          return expanded

      def get_mini_boxes(self, contour):
          """Return the minimum-area rectangle of ``contour``.

          Returns:
              ``(box, short_side)`` where ``box`` lists the four corners as
              [left point with smaller y, right point with smaller y,
              right point with larger y, left point with larger y] and
              ``short_side`` is the smaller of the rectangle's two side lengths.
          """
          bounding_box = cv2.minAreaRect(contour)
          # Sort by x: the first two points are the left pair, the last two the
          # right pair; each pair is then ordered by y below.
          points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

          if points[1][1] > points[0][1]:
              index_1, index_4 = 0, 1
          else:
              index_1, index_4 = 1, 0
          if points[3][1] > points[2][1]:
              index_2, index_3 = 2, 3
          else:
              index_2, index_3 = 3, 2

          box = [points[index_1], points[index_2], points[index_3], points[index_4]]
          return box, min(bounding_box[1])

      def box_score_fast(self, bitmap, _box):
          """Mean of ``bitmap`` inside the polygon ``_box``.

          The polygon is rasterized inside its axis-aligned bounding rectangle
          (clipped to the map) and used as the mask for the mean.
          """
          h, w = bitmap.shape[:2]
          box = _box.copy()
          xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
          xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
          ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
          ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)

          mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
          # Shift the polygon into the coordinate frame of the cropped window.
          box[:, 0] = box[:, 0] - xmin
          box[:, 1] = box[:, 1] - ymin
          cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
          return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

      def box_score_slow(self, bitmap, contour):
          """Mean of ``bitmap`` inside the original ``contour`` polygon.

          Same masking scheme as :meth:`box_score_fast`, but applied to the full
          contour instead of its minimum-area rectangle.
          """
          h, w = bitmap.shape[:2]
          contour = contour.copy()
          contour = np.reshape(contour, (-1, 2))

          xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
          xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
          ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
          ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

          mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
          # Shift the contour into the coordinate frame of the cropped window.
          contour[:, 0] = contour[:, 0] - xmin
          contour[:, 1] = contour[:, 1] - ymin
          cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
          return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

      def __call__(
          self,
          preds,
          img_shapes,
          thresh: Union[float, None] = None,
          box_thresh: Union[float, None] = None,
          max_candidates: Union[int, None] = None,
          unclip_ratio: Union[float, None] = None,
          use_dilation: Union[bool, None] = None,
      ):
          """Post-process a batch of predictions.

          Args:
              preds: sequence whose first element holds one prediction map per
                  image.
              img_shapes: per-image ``[src_h, src_w, ratio_h, ratio_w]``.
              thresh, box_thresh, max_candidates, unclip_ratio, use_dilation:
                  per-call overrides; ``None`` selects the value given at
                  construction. Explicit falsy values such as ``0.0`` or
                  ``False`` are honoured.

          Returns:
              ``(boxes, scores)``: two lists with one entry per image.
          """
          # Only None means "use the configured default", so that callers can
          # explicitly pass 0 / False.
          thresh = self.thresh if thresh is None else thresh
          box_thresh = self.box_thresh if box_thresh is None else box_thresh
          max_candidates = (
              self.max_candidates if max_candidates is None else max_candidates
          )
          unclip_ratio = self.unclip_ratio if unclip_ratio is None else unclip_ratio
          use_dilation = self.use_dilation if use_dilation is None else use_dilation

          boxes, scores = [], []
          for pred, img_shape in zip(preds[0], img_shapes):
              box, score = self.process(
                  pred,
                  img_shape,
                  thresh,
                  box_thresh,
                  max_candidates,
                  unclip_ratio,
                  use_dilation,
              )
              boxes.append(box)
              scores.append(score)
          return boxes, scores

      def process(
          self,
          pred,
          img_shape,
          thresh,
          box_thresh,
          max_candidates,
          unclip_ratio,
          use_dilation,
      ):
          """Post-process the prediction of a single image.

          Args:
              pred: prediction whose first channel (``pred[0]``) is the score
                  map.
              img_shape: ``[src_h, src_w, ratio_h, ratio_w]``; only the source
                  height and width are used here.
              thresh: binarization threshold.
              box_thresh: minimum mean score for a region to be kept.
              max_candidates: maximum number of contours examined.
              unclip_ratio: expansion factor passed to :meth:`unclip`.
              use_dilation: dilate the binary map with a 2x2 kernel first.

          Returns:
              ``(boxes, scores)`` as returned by :meth:`polygons_from_bitmap` or
              :meth:`boxes_from_bitmap`.

          Raises:
              ValueError: if ``self.box_type`` is neither ``"quad"`` nor
                  ``"poly"``.
          """
          pred = pred[0, :, :]
          segmentation = pred > thresh
          src_h, src_w, ratio_h, ratio_w = img_shape

          if use_dilation:
              mask = cv2.dilate(
                  np.array(segmentation).astype(np.uint8),
                  np.array([[1, 1], [1, 1]]),
              )
          else:
              mask = segmentation

          if self.box_type == "poly":
              boxes, scores = self.polygons_from_bitmap(
                  pred, mask, src_w, src_h, box_thresh, max_candidates, unclip_ratio
              )
          elif self.box_type == "quad":
              boxes, scores = self.boxes_from_bitmap(
                  pred, mask, src_w, src_h, box_thresh, max_candidates, unclip_ratio
              )
          else:
              raise ValueError("box_type can only be one of ['quad', 'poly']")
          return boxes, scores