# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import copy
import math

import cv2
import numpy as np
import pyclipper
from PIL import Image
from shapely.geometry import Polygon

from ...utils.io import ImageReader
from ....utils import logging
from ...results import TextDetResult
from ..base import BaseComponent

__all__ = ["DetResizeForTest", "NormalizeImage", "DBPostProcess", "CropByPolys"]

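# This module provides the building blocks of the text-detection stage:
# DetResizeForTest and NormalizeImage preprocess the input image, DBPostProcess
# turns the DB model's probability map into text boxes or polygons, and
# CropByPolys crops each detected region, typically for the recognition stage.
#
# A rough sketch of the intended order (hypothetical variable names; the actual
# wiring is done by the pipeline via the components' INPUT/OUTPUT keys):
#
#     data = DetResizeForTest(limit_side_len=960, limit_type="max").apply(img)
#     norm = NormalizeImage().apply(data["img"])
#     # ... run the detection model on norm["img"] to obtain `pred` ...
#     res = DBPostProcess().apply(pred, data["img_shape"], img_path)
#     crops = CropByPolys().apply(img_path, dt_polys)  # dt_polys from the detection result
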
class DetResizeForTest(BaseComponent):
    """Resize the input image for text detection inference.

    The resize strategy is selected by the constructor kwargs:
    - ``limit_side_len`` / ``limit_type`` (resize_type 0, the default): scale
      according to ``limit_type`` ("min", "max", or "resize_long") so the chosen
      side meets ``limit_side_len``, then round both sides to multiples of 32.
    - ``image_shape`` (resize_type 1): resize to a fixed (h, w), optionally keeping
      the aspect ratio via ``keep_ratio``.
    - ``resize_long`` (resize_type 2): scale the longer side to ``resize_long`` and
      round both sides up to multiples of 128.
    """

    INPUT_KEYS = ["img"]
    OUTPUT_KEYS = ["img", "img_shape"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_shape": "img_shape"}

    def __init__(self, **kwargs):
        super().__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if "image_shape" in kwargs:
            self.image_shape = kwargs["image_shape"]
            self.resize_type = 1
            if "keep_ratio" in kwargs:
                self.keep_ratio = kwargs["keep_ratio"]
        elif "limit_side_len" in kwargs:
            self.limit_side_len = kwargs["limit_side_len"]
            self.limit_type = kwargs.get("limit_type", "min")
        elif "resize_long" in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get("resize_long", 960)
        else:
            self.limit_side_len = 736
            self.limit_type = "min"

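    # Example: DetResizeForTest(limit_side_len=960, limit_type="max") scales the
    # image so that its longer side is about 960 (snapped to a multiple of 32),
    # while DetResizeForTest(image_shape=[736, 1280]) resizes to a fixed 736 x 1280.
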
    def apply(self, img):
        """Resize the image and return it with the source size and resize ratios."""
        src_h, src_w, _ = img.shape
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)
        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        # img_shape packs the source size and the resize ratios so that later
        # components (e.g. DBPostProcess) can map predictions back to the source image.
        return {"img": img, "img_shape": np.array([src_h, src_w, ratio_h, ratio_w])}

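    # Inputs whose height plus width is under 64 are padded up to at least 32 x 32
    # in apply() before resizing, so that very small images still survive the
    # detector's downsampling.
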
    def image_padding(self, im, value=0):
        """padding image"""
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """resize the image"""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            N = math.ceil(resize_w / 32)
            resize_w = N * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        Resize the image to a size that is a multiple of 32, as required by the network.
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # compute the scaling ratio according to limit_type
        if self.limit_type == "max":
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "min":
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "resize_long":
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception("Unsupported limit_type: {}".format(self.limit_type))

        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            logging.info(
                "Resize failed: shape={}, resize_w={}, resize_h={}".format(
                    img.shape, resize_w, resize_h
                )
            )
            sys.exit(0)

        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

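    # Example for resize_image_type0 with the defaults (limit_type="min",
    # limit_side_len=736): a 500 x 1000 image is scaled by 736 / 500 = 1.472 to
    # 736 x 1472, which already lies on the 32-pixel grid, so ratio_h and ratio_w
    # are both 1.472.
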
    def resize_image_type2(self, img):
        """resize image size"""
        h, w, _ = img.shape
        resize_w = w
        resize_h = h
        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)
        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]


class NormalizeImage(BaseComponent):
    """Normalize the image, e.g. subtract the mean and divide by the std."""

    INPUT_KEYS = ["img"]
    OUTPUT_KEYS = ["img"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
        super().__init__()
        if isinstance(scale, str):
            # allow string configs such as "1./255."
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]
        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

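    # The default mean/std are the ImageNet statistics. For example, a pixel value
    # of 255 in the first channel maps to (255 * 1/255 - 0.485) / 0.229, about 2.25.
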
    def apply(self, img):
        """apply"""
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        img = (img.astype("float32") * self.scale - self.mean) / self.std
        return {"img": img}


class DBPostProcess(BaseComponent):
    """
    The post process for Differentiable Binarization (DB).
    """

    INPUT_KEYS = ["pred", "img_shape", "img_path"]
    OUTPUT_KEYS = ["text_det_res"]
    DEAULT_INPUTS = {"pred": "pred", "img_shape": "img_shape", "img_path": "img_path"}
    DEAULT_OUTPUTS = {"text_det_res": "text_det_res"}

    def __init__(
        self,
        thresh=0.3,
        box_thresh=0.7,
        max_candidates=1000,
        unclip_ratio=2.0,
        use_dilation=False,
        score_mode="fast",
        box_type="quad",
        **kwargs
    ):
        super().__init__()
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow",
            "fast",
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])

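    # Both box-extraction paths below follow the same recipe: binarize the
    # probability map with `thresh`, find contours, score each candidate region
    # against the probability map, drop candidates below `box_thresh`, expand the
    # remaining regions with unclip(), and finally rescale the coordinates from
    # the bitmap to the source image size.
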
    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """_bitmap: single map with shape (H, W), whose values are binarized as {0, 1}"""
        bitmap = _bitmap
        height, width = bitmap.shape
        boxes = []
        scores = []
        contours, _ = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours[: self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue
            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)
            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue
            box = np.array(box)
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """_bitmap: single map with shape (H, W), whose values are binarized as {0, 1}"""
        bitmap = _bitmap
        height, width = bitmap.shape
        outs = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(outs) == 3:
            # OpenCV 3.x returns (image, contours, hierarchy)
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            # OpenCV 4.x (and 2.x) returns (contours, hierarchy)
            contours, _ = outs[0], outs[1]
        num_contours = min(len(contours), self.max_candidates)
        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue
            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

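    # unclip() expands a detected region before taking its minimum-area rectangle.
    # The offset distance follows the DB paper: D' = A * r / L, where A is the
    # polygon area, L its perimeter, and r the unclip_ratio.
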
    def unclip(self, box, unclip_ratio):
        """unclip"""
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

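    # get_mini_boxes() returns the minimum-area rectangle of a contour as four
    # points ordered [top-left, top-right, bottom-right, bottom-left], together
    # with the length of its shorter side.
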
    def get_mini_boxes(self, contour):
        """get mini boxes"""
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])

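    # box_score_fast masks the probability map with the 4-point quad (within its
    # axis-aligned bounding crop), while box_score_slow uses the full contour as
    # the mask, which is more precise but slower.
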
    def box_score_fast(self, bitmap, _box):
        """box_score_fast: use the mean score within the box as the box score"""
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        """box_score_slow: use the mean score within the polygon as the box score"""
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))
        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin
        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def apply(self, pred, img_shape, img_path):
        """apply"""
        # pred is expected to be a (1, H, W) probability map for a single image;
        # img_shape is [src_h, src_w, ratio_h, ratio_w] as produced by DetResizeForTest.
        pred = pred[0, :, :]
        segmentation = pred > self.thresh
        src_h, src_w, ratio_h, ratio_w = img_shape
        if self.dilation_kernel is not None:
            mask = cv2.dilate(
                np.array(segmentation).astype(np.uint8),
                self.dilation_kernel,
            )
        else:
            mask = segmentation
        if self.box_type == "poly":
            boxes, scores = self.polygons_from_bitmap(pred, mask, src_w, src_h)
        elif self.box_type == "quad":
            boxes, scores = self.boxes_from_bitmap(pred, mask, src_w, src_h)
        else:
            raise ValueError("box_type can only be one of ['quad', 'poly']")
        text_det_res = TextDetResult(
            {"img_path": img_path, "dt_polys": boxes, "dt_scores": scores}
        )
        return {"text_det_res": text_det_res}


class CropByPolys(BaseComponent):
    """Crop Image by Polys"""

    INPUT_KEYS = ["img_path", "dt_polys"]
    OUTPUT_KEYS = ["img"]
    DEAULT_INPUTS = {"img_path": "img_path", "dt_polys": "dt_polys"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, det_box_type="quad"):
        super().__init__()
        self.det_box_type = det_box_type
        self._reader = ImageReader(backend="opencv")

    def apply(self, img_path, dt_polys):
        """apply"""
        img = self._reader.read(img_path)
        dt_boxes = np.array(dt_polys)
        # TODO: sort the boxes top-to-bottom, left-to-right (see self.sorted_boxes)
        output_list = []
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.det_box_type == "quad":
                img_crop = self.get_rotate_crop_image(img, tmp_box)
            else:
                img_crop = self.get_minarea_rect_crop(img, tmp_box)
            output_list.append(
                {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
            )
        return output_list

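    # sorted_boxes orders detections in reading order: a primary sort by the
    # top-left corner (y, then x), followed by a local bubble pass that swaps
    # neighbouring boxes whose top y coordinates differ by less than 10 pixels
    # but are out of left-to-right order.
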
    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes in order from top to bottom, left to right.
        args:
            dt_boxes(array): detected text boxes with shape [N, 4, 2]
        return:
            sorted boxes(list) of arrays with shape [4, 2]
        """
        dt_boxes = np.array(dt_boxes)
        num_boxes = dt_boxes.shape[0]
        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        _boxes = list(sorted_boxes)
        for i in range(num_boxes - 1):
            for j in range(i, -1, -1):
                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
                    _boxes[j + 1][0][0] < _boxes[j][0][0]
                ):
                    tmp = _boxes[j]
                    _boxes[j] = _boxes[j + 1]
                    _boxes[j + 1] = tmp
                else:
                    break
        return _boxes

    def get_minarea_rect_crop(self, img, points):
        """get_minarea_rect_crop"""
        bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_a, index_b, index_c, index_d = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_a = 0
            index_d = 1
        else:
            index_a = 1
            index_d = 0
        if points[3][1] > points[2][1]:
            index_b = 2
            index_c = 3
        else:
            index_b = 3
            index_c = 2

        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
        crop_img = self.get_rotate_crop_image(img, np.array(box))
        return crop_img

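    # get_rotate_crop_image warps the quadrilateral onto an axis-aligned rectangle
    # with a perspective transform; if the result is much taller than wide
    # (height / width >= 1.5) it is assumed to be vertical text and rotated by 90
    # degrees.
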
    def get_rotate_crop_image(self, img, points):
        """
        Crop the quadrilateral region defined by `points` (4 x 2) and warp it onto
        an axis-aligned rectangle sized by the quad's edge lengths.
        """
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3]),
            )
        )
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2]),
            )
        )
        pts_std = np.float32(
            [
                [0, 0],
                [img_crop_width, 0],
                [img_crop_width, img_crop_height],
                [0, img_crop_height],
            ]
        )
        # cv2.getPerspectiveTransform expects float32 point arrays
        points = np.array(points, dtype=np.float32)
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M,
            (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC,
        )
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img
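
    # Minimal usage sketch for this component (hypothetical path and boxes; the
    # pipeline normally drives it through the INPUT/OUTPUT keys):
    #
    #     cropper = CropByPolys(det_box_type="quad")
    #     crops = cropper.apply("example.jpg", dt_polys)  # dt_polys: list of 4x2 quads
    #     # each entry is {"img": <cropped ndarray>, "img_size": [width, height]}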