transforms.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import cv2
  17. import copy
  18. import math
  19. import pyclipper
  20. import numpy as np
  21. from PIL import Image
  22. from shapely.geometry import Polygon
  23. from ....utils import logging
  24. from ...base.predictor.io.writers import ImageWriter
  25. from ...base.predictor.io.readers import ImageReader
  26. from ...base.predictor import BaseTransform
  27. from .keys import TextDetKeys as K
  28. __all__ = [
  29. 'DetResizeForTest', 'NormalizeImage', 'DBPostProcess', 'SaveTextDetResults'
  30. ]
  31. class DetResizeForTest(BaseTransform):
  32. """ DetResizeForTest """
  33. def __init__(self, **kwargs):
  34. super(DetResizeForTest, self).__init__()
  35. self.resize_type = 0
  36. self.keep_ratio = False
  37. if 'image_shape' in kwargs:
  38. self.image_shape = kwargs['image_shape']
  39. self.resize_type = 1
  40. if 'keep_ratio' in kwargs:
  41. self.keep_ratio = kwargs['keep_ratio']
  42. elif 'limit_side_len' in kwargs:
  43. self.limit_side_len = kwargs['limit_side_len']
  44. self.limit_type = kwargs.get('limit_type', 'min')
  45. elif 'resize_long' in kwargs:
  46. self.resize_type = 2
  47. self.resize_long = kwargs.get('resize_long', 960)
  48. else:
  49. self.limit_side_len = 736
  50. self.limit_type = 'min'
  51. def apply(self, data):
  52. """ apply """
  53. img = data[K.IMAGE]
  54. src_h, src_w, _ = img.shape
  55. if sum([src_h, src_w]) < 64:
  56. img = self.image_padding(img)
  57. if self.resize_type == 0:
  58. # img, shape = self.resize_image_type0(img)
  59. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  60. elif self.resize_type == 2:
  61. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  62. else:
  63. # img, shape = self.resize_image_type1(img)
  64. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  65. data[K.IMAGE] = img
  66. data[K.SHAPE] = np.array([src_h, src_w, ratio_h, ratio_w])
  67. return data
  68. @classmethod
  69. def get_input_keys(cls):
  70. """ get input keys """
  71. return [K.IMAGE]
  72. @classmethod
  73. def get_output_keys(cls):
  74. """ get output keys """
  75. return [K.IMAGE, K.SHAPE]
  76. def image_padding(self, im, value=0):
  77. """ padding image """
  78. h, w, c = im.shape
  79. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  80. im_pad[:h, :w, :] = im
  81. return im_pad
  82. def resize_image_type1(self, img):
  83. """ resize the image """
  84. resize_h, resize_w = self.image_shape
  85. ori_h, ori_w = img.shape[:2] # (h, w, c)
  86. if self.keep_ratio is True:
  87. resize_w = ori_w * resize_h / ori_h
  88. N = math.ceil(resize_w / 32)
  89. resize_w = N * 32
  90. ratio_h = float(resize_h) / ori_h
  91. ratio_w = float(resize_w) / ori_w
  92. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  93. # return img, np.array([ori_h, ori_w])
  94. return img, [ratio_h, ratio_w]
  95. def resize_image_type0(self, img):
  96. """
  97. resize image to a size multiple of 32 which is required by the network
  98. args:
  99. img(array): array with shape [h, w, c]
  100. return(tuple):
  101. img, (ratio_h, ratio_w)
  102. """
  103. limit_side_len = self.limit_side_len
  104. h, w, c = img.shape
  105. # limit the max side
  106. if self.limit_type == 'max':
  107. if max(h, w) > limit_side_len:
  108. if h > w:
  109. ratio = float(limit_side_len) / h
  110. else:
  111. ratio = float(limit_side_len) / w
  112. else:
  113. ratio = 1.
  114. elif self.limit_type == 'min':
  115. if min(h, w) < limit_side_len:
  116. if h < w:
  117. ratio = float(limit_side_len) / h
  118. else:
  119. ratio = float(limit_side_len) / w
  120. else:
  121. ratio = 1.
  122. elif self.limit_type == 'resize_long':
  123. ratio = float(limit_side_len) / max(h, w)
  124. else:
  125. raise Exception('not support limit type, image ')
  126. resize_h = int(h * ratio)
  127. resize_w = int(w * ratio)
  128. resize_h = max(int(round(resize_h / 32) * 32), 32)
  129. resize_w = max(int(round(resize_w / 32) * 32), 32)
  130. try:
  131. if int(resize_w) <= 0 or int(resize_h) <= 0:
  132. return None, (None, None)
  133. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  134. except:
  135. logging.info(img.shape, resize_w, resize_h)
  136. sys.exit(0)
  137. ratio_h = resize_h / float(h)
  138. ratio_w = resize_w / float(w)
  139. return img, [ratio_h, ratio_w]
  140. def resize_image_type2(self, img):
  141. """ resize image size """
  142. h, w, _ = img.shape
  143. resize_w = w
  144. resize_h = h
  145. if resize_h > resize_w:
  146. ratio = float(self.resize_long) / resize_h
  147. else:
  148. ratio = float(self.resize_long) / resize_w
  149. resize_h = int(resize_h * ratio)
  150. resize_w = int(resize_w * ratio)
  151. max_stride = 128
  152. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  153. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  154. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  155. ratio_h = resize_h / float(h)
  156. ratio_w = resize_w / float(w)
  157. return img, [ratio_h, ratio_w]
  158. class NormalizeImage(BaseTransform):
  159. """ normalize image such as substract mean, divide std
  160. """
  161. def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
  162. if isinstance(scale, str):
  163. scale = eval(scale)
  164. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  165. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  166. std = std if std is not None else [0.229, 0.224, 0.225]
  167. shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
  168. self.mean = np.array(mean).reshape(shape).astype('float32')
  169. self.std = np.array(std).reshape(shape).astype('float32')
  170. def apply(self, data):
  171. """ apply """
  172. img = data[K.IMAGE]
  173. from PIL import Image
  174. if isinstance(img, Image.Image):
  175. img = np.array(img)
  176. assert isinstance(img,
  177. np.ndarray), "invalid input 'img' in NormalizeImage"
  178. data[K.IMAGE] = (
  179. img.astype('float32') * self.scale - self.mean) / self.std
  180. return data
  181. @classmethod
  182. def get_input_keys(cls):
  183. """ get input keys """
  184. return [K.IMAGE]
  185. @classmethod
  186. def get_output_keys(cls):
  187. """ get output keys """
  188. return [K.IMAGE]
  189. class DBPostProcess(BaseTransform):
  190. """
  191. The post process for Differentiable Binarization (DB).
  192. """
  193. def __init__(self,
  194. thresh=0.3,
  195. box_thresh=0.7,
  196. max_candidates=1000,
  197. unclip_ratio=2.0,
  198. use_dilation=False,
  199. score_mode="fast",
  200. box_type='quad',
  201. **kwargs):
  202. self.thresh = thresh
  203. self.box_thresh = box_thresh
  204. self.max_candidates = max_candidates
  205. self.unclip_ratio = unclip_ratio
  206. self.min_size = 3
  207. self.score_mode = score_mode
  208. self.box_type = box_type
  209. assert score_mode in [
  210. "slow", "fast"
  211. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  212. self.dilation_kernel = None if not use_dilation else np.array([[1, 1],
  213. [1, 1]])
  214. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  215. """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} """
  216. bitmap = _bitmap
  217. height, width = bitmap.shape
  218. boxes = []
  219. scores = []
  220. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
  221. cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  222. for contour in contours[:self.max_candidates]:
  223. epsilon = 0.002 * cv2.arcLength(contour, True)
  224. approx = cv2.approxPolyDP(contour, epsilon, True)
  225. points = approx.reshape((-1, 2))
  226. if points.shape[0] < 4:
  227. continue
  228. score = self.box_score_fast(pred, points.reshape(-1, 2))
  229. if self.box_thresh > score:
  230. continue
  231. if points.shape[0] > 2:
  232. box = self.unclip(points, self.unclip_ratio)
  233. if len(box) > 1:
  234. continue
  235. else:
  236. continue
  237. box = box.reshape(-1, 2)
  238. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  239. if sside < self.min_size + 2:
  240. continue
  241. box = np.array(box)
  242. box[:, 0] = np.clip(
  243. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  244. box[:, 1] = np.clip(
  245. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  246. boxes.append(box.tolist())
  247. scores.append(score)
  248. return boxes, scores
  249. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  250. """ _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1} """
  251. bitmap = _bitmap
  252. height, width = bitmap.shape
  253. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  254. cv2.CHAIN_APPROX_SIMPLE)
  255. if len(outs) == 3:
  256. img, contours, _ = outs[0], outs[1], outs[2]
  257. elif len(outs) == 2:
  258. contours, _ = outs[0], outs[1]
  259. num_contours = min(len(contours), self.max_candidates)
  260. boxes = []
  261. scores = []
  262. for index in range(num_contours):
  263. contour = contours[index]
  264. points, sside = self.get_mini_boxes(contour)
  265. if sside < self.min_size:
  266. continue
  267. points = np.array(points)
  268. if self.score_mode == "fast":
  269. score = self.box_score_fast(pred, points.reshape(-1, 2))
  270. else:
  271. score = self.box_score_slow(pred, contour)
  272. if self.box_thresh > score:
  273. continue
  274. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  275. box, sside = self.get_mini_boxes(box)
  276. if sside < self.min_size + 2:
  277. continue
  278. box = np.array(box)
  279. box[:, 0] = np.clip(
  280. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  281. box[:, 1] = np.clip(
  282. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  283. boxes.append(box.astype(np.int16))
  284. scores.append(score)
  285. return np.array(boxes, dtype=np.int16), scores
  286. def unclip(self, box, unclip_ratio):
  287. """ unclip """
  288. poly = Polygon(box)
  289. distance = poly.area * unclip_ratio / poly.length
  290. offset = pyclipper.PyclipperOffset()
  291. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  292. expanded = np.array(offset.Execute(distance))
  293. return expanded
  294. def get_mini_boxes(self, contour):
  295. """ get mini boxes """
  296. bounding_box = cv2.minAreaRect(contour)
  297. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  298. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  299. if points[1][1] > points[0][1]:
  300. index_1 = 0
  301. index_4 = 1
  302. else:
  303. index_1 = 1
  304. index_4 = 0
  305. if points[3][1] > points[2][1]:
  306. index_2 = 2
  307. index_3 = 3
  308. else:
  309. index_2 = 3
  310. index_3 = 2
  311. box = [
  312. points[index_1], points[index_2], points[index_3], points[index_4]
  313. ]
  314. return box, min(bounding_box[1])
  315. def box_score_fast(self, bitmap, _box):
  316. """ box_score_fast: use bbox mean score as the mean score """
  317. h, w = bitmap.shape[:2]
  318. box = _box.copy()
  319. xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
  320. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
  321. ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
  322. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
  323. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  324. box[:, 0] = box[:, 0] - xmin
  325. box[:, 1] = box[:, 1] - ymin
  326. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  327. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  328. def box_score_slow(self, bitmap, contour):
  329. """ box_score_slow: use polyon mean score as the mean score """
  330. h, w = bitmap.shape[:2]
  331. contour = contour.copy()
  332. contour = np.reshape(contour, (-1, 2))
  333. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  334. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  335. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  336. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  337. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  338. contour[:, 0] = contour[:, 0] - xmin
  339. contour[:, 1] = contour[:, 1] - ymin
  340. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  341. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  342. def apply(self, data):
  343. """ apply """
  344. pred = data[K.PROB_MAP]
  345. shape_list = [data[K.SHAPE]]
  346. pred = pred[0][:, 0, :, :]
  347. segmentation = pred > self.thresh
  348. boxes_batch = []
  349. for batch_index in range(pred.shape[0]):
  350. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  351. if self.dilation_kernel is not None:
  352. mask = cv2.dilate(
  353. np.array(segmentation[batch_index]).astype(np.uint8),
  354. self.dilation_kernel)
  355. else:
  356. mask = segmentation[batch_index]
  357. if self.box_type == 'poly':
  358. boxes, scores = self.polygons_from_bitmap(pred[batch_index],
  359. mask, src_w, src_h)
  360. elif self.box_type == 'quad':
  361. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
  362. src_w, src_h)
  363. else:
  364. raise ValueError("box_type can only be one of ['quad', 'poly']")
  365. data[K.DT_POLYS] = boxes
  366. data[K.DT_SCORES] = scores
  367. return data
  368. @classmethod
  369. def get_input_keys(cls):
  370. """ get input keys """
  371. return [K.PROB_MAP]
  372. @classmethod
  373. def get_output_keys(cls):
  374. """ get output keys """
  375. return [K.DT_POLYS, K.DT_SCORES]
  376. class CropByPolys(BaseTransform):
  377. """Crop Image by Polys
  378. """
  379. def __init__(self, det_box_type='quad'):
  380. super().__init__()
  381. self.det_box_type = det_box_type
  382. def apply(self, data):
  383. """ apply """
  384. ori_im = data[K.ORI_IM]
  385. # TODO
  386. # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
  387. dt_boxes = np.array(data[K.DT_POLYS])
  388. img_crop_list = []
  389. for bno in range(len(dt_boxes)):
  390. tmp_box = copy.deepcopy(dt_boxes[bno])
  391. if self.det_box_type == "quad":
  392. img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
  393. else:
  394. img_crop = self.get_minarea_rect_crop(ori_im, tmp_box)
  395. img_crop_list.append(img_crop)
  396. data[K.SUB_IMGS] = img_crop_list
  397. return data
  398. @classmethod
  399. def get_input_keys(cls):
  400. """ get input keys """
  401. return [K.IM_PATH, K.DT_POLYS]
  402. @classmethod
  403. def get_output_keys(cls):
  404. """ get output keys """
  405. return [K.SUB_IMGS]
  406. def sorted_boxes(self, dt_boxes):
  407. """
  408. Sort text boxes in order from top to bottom, left to right
  409. args:
  410. dt_boxes(array):detected text boxes with shape [4, 2]
  411. return:
  412. sorted boxes(array) with shape [4, 2]
  413. """
  414. dt_boxes = np.array(dt_boxes)
  415. num_boxes = dt_boxes.shape[0]
  416. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  417. _boxes = list(sorted_boxes)
  418. for i in range(num_boxes - 1):
  419. for j in range(i, -1, -1):
  420. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
  421. _boxes[j + 1][0][0] < _boxes[j][0][0]):
  422. tmp = _boxes[j]
  423. _boxes[j] = _boxes[j + 1]
  424. _boxes[j + 1] = tmp
  425. else:
  426. break
  427. return _boxes
  428. def get_minarea_rect_crop(self, img, points):
  429. """get_minarea_rect_crop
  430. """
  431. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  432. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  433. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  434. if points[1][1] > points[0][1]:
  435. index_a = 0
  436. index_d = 1
  437. else:
  438. index_a = 1
  439. index_d = 0
  440. if points[3][1] > points[2][1]:
  441. index_b = 2
  442. index_c = 3
  443. else:
  444. index_b = 3
  445. index_c = 2
  446. box = [
  447. points[index_a], points[index_b], points[index_c], points[index_d]
  448. ]
  449. crop_img = self.get_rotate_crop_image(img, np.array(box))
  450. return crop_img
  451. def get_rotate_crop_image(self, img, points):
  452. """
  453. img_height, img_width = img.shape[0:2]
  454. left = int(np.min(points[:, 0]))
  455. right = int(np.max(points[:, 0]))
  456. top = int(np.min(points[:, 1]))
  457. bottom = int(np.max(points[:, 1]))
  458. img_crop = img[top:bottom, left:right, :].copy()
  459. points[:, 0] = points[:, 0] - left
  460. points[:, 1] = points[:, 1] - top
  461. """
  462. assert len(points) == 4, "shape of points must be 4*2"
  463. img_crop_width = int(
  464. max(
  465. np.linalg.norm(points[0] - points[1]),
  466. np.linalg.norm(points[2] - points[3])))
  467. img_crop_height = int(
  468. max(
  469. np.linalg.norm(points[0] - points[3]),
  470. np.linalg.norm(points[1] - points[2])))
  471. pts_std = np.float32([
  472. [0, 0],
  473. [img_crop_width, 0],
  474. [img_crop_width, img_crop_height],
  475. [0, img_crop_height],
  476. ])
  477. M = cv2.getPerspectiveTransform(points, pts_std)
  478. dst_img = cv2.warpPerspective(
  479. img,
  480. M,
  481. (img_crop_width, img_crop_height),
  482. borderMode=cv2.BORDER_REPLICATE,
  483. flags=cv2.INTER_CUBIC, )
  484. dst_img_height, dst_img_width = dst_img.shape[0:2]
  485. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  486. dst_img = np.rot90(dst_img)
  487. return dst_img
  488. class SaveTextDetResults(BaseTransform):
  489. """ Save Text Det Results """
  490. _DEFAULT_FILE_NAME = 'text_det_out.png'
  491. def __init__(self, save_dir, file_name=None):
  492. super().__init__()
  493. self.save_dir = save_dir
  494. if file_name is None:
  495. file_name = self._DEFAULT_FILE_NAME
  496. self.file_name = file_name
  497. # We use pillow backend to save both numpy arrays and PIL Image objects
  498. self._writer = ImageWriter(backend='opencv')
  499. def apply(self, data):
  500. """ apply """
  501. if self.save_dir is None:
  502. logging.warning(
  503. "The `save_dir` has been set to None, so the text detection result won't to be saved."
  504. )
  505. return data
  506. save_path = os.path.join(self.save_dir, self.file_name)
  507. bbox_res = data[K.DT_POLYS]
  508. vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
  509. self._writer.write(save_path, vis_img)
  510. return data
  511. @classmethod
  512. def get_input_keys(cls):
  513. """ get input keys """
  514. return [K.IM_PATH, K.DT_POLYS, K.DT_SCORES]
  515. @classmethod
  516. def get_output_keys(cls):
  517. """ get output keys """
  518. return []
  519. def draw_rectangle(self, img_path, boxes):
  520. """ draw rectangle """
  521. boxes = np.array(boxes)
  522. img = cv2.imread(img_path)
  523. img_show = img.copy()
  524. for box in boxes.astype(int):
  525. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  526. cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
  527. return img_show