text_det.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import cv2
import copy
import math
import pyclipper
import numpy as np
from numpy.linalg import norm
from PIL import Image
from shapely.geometry import Polygon

from ...utils.io import ImageReader
from ....utils import logging
from ..base import BaseComponent
from .seal_det_warp import AutoRectifier

__all__ = ["DetResizeForTest", "NormalizeImage", "DBPostProcess", "CropByPolys"]

class DetResizeForTest(BaseComponent):
    """DetResizeForTest"""

    INPUT_KEYS = ["img"]
    OUTPUT_KEYS = ["img", "img_shape"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img", "img_shape": "img_shape"}

    def __init__(self, **kwargs):
        super().__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if "image_shape" in kwargs:
            self.image_shape = kwargs["image_shape"]
            self.resize_type = 1
            if "keep_ratio" in kwargs:
                self.keep_ratio = kwargs["keep_ratio"]
        elif "limit_side_len" in kwargs:
            self.limit_side_len = kwargs["limit_side_len"]
            self.limit_type = kwargs.get("limit_type", "min")
        elif "resize_long" in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get("resize_long", 960)
        else:
            self.limit_side_len = 736
            self.limit_type = "min"

    def apply(self, img):
        """apply"""
        src_h, src_w, _ = img.shape
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            # img, shape = self.resize_image_type0(img)
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            # img, shape = self.resize_image_type1(img)
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)

        return {"img": img, "img_shape": np.array([src_h, src_w, ratio_h, ratio_w])}

    def image_padding(self, im, value=0):
        """padding image"""
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """resize the image"""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            N = math.ceil(resize_w / 32)
            resize_w = N * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        # return img, np.array([ori_h, ori_w])
        return img, [ratio_h, ratio_w]
    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == "max":
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "min":
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "resize_long":
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception("Unsupported limit_type: {}".format(self.limit_type))

        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except:
            logging.info(img.shape, resize_w, resize_h)
            sys.exit(0)

        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]
    def resize_image_type2(self, img):
        """resize image size"""
        h, w, _ = img.shape
        resize_w = w
        resize_h = h

        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]
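
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the constructor
# arguments select the three resize modes above. The dummy image and the
# `_demo_det_resize_modes` helper are hypothetical and only show how the
# returned ratios relate the resized image back to the source size.
def _demo_det_resize_modes():
    dummy = np.zeros((720, 1280, 3), dtype=np.uint8)
    # resize_type == 0: bound the shorter side by `limit_side_len`
    out_min = DetResizeForTest(limit_side_len=736, limit_type="min").apply(dummy)
    # resize_type == 1: resize to a fixed `image_shape` (h, w)
    out_fixed = DetResizeForTest(image_shape=[960, 960]).apply(dummy)
    # resize_type == 2: bound the longer side by `resize_long`
    out_long = DetResizeForTest(resize_long=960).apply(dummy)
    for out in (out_min, out_fixed, out_long):
        src_h, src_w, ratio_h, ratio_w = out["img_shape"]
        # ratio_h / ratio_w map coordinates from the source image to the resized one
        print(out["img"].shape, (src_h, src_w), (ratio_h, ratio_w))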

class NormalizeImage(BaseComponent):
    """normalize image such as subtract mean, divide std"""

    INPUT_KEYS = ["img"]
    OUTPUT_KEYS = ["img"]
    DEAULT_INPUTS = {"img": "img"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
        super().__init__()
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

    def apply(self, img):
        """apply"""
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        img = (img.astype("float32") * self.scale - self.mean) / self.std
        return {"img": img}
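
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): NormalizeImage applied
# to an HWC uint8 image, so `order="hwc"` is used here. The mean/std defaults
# above are the usual ImageNet statistics; `_demo_normalize_image` is a
# hypothetical helper for demonstration only.
def _demo_normalize_image():
    dummy = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    normalizer = NormalizeImage(scale=1.0 / 255.0, order="hwc")
    normalized = normalizer.apply(dummy)["img"]
    # After scaling to [0, 1] and standardizing, values are roughly zero-centered.
    print(normalized.shape, float(normalized.mean()))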

class DBPostProcess(BaseComponent):
    """
    The post process for Differentiable Binarization (DB).
    """

    INPUT_KEYS = ["pred", "img_shape"]
    OUTPUT_KEYS = ["dt_polys", "dt_scores"]
    DEAULT_INPUTS = {"pred": "pred", "img_shape": "img_shape"}
    DEAULT_OUTPUTS = {"dt_polys": "dt_polys", "dt_scores": "dt_scores"}

    def __init__(
        self,
        thresh=0.3,
        box_thresh=0.7,
        max_candidates=1000,
        unclip_ratio=2.0,
        use_dilation=False,
        score_mode="fast",
        box_type="quad",
        **kwargs
    ):
        super().__init__()
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow",
            "fast",
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""

        bitmap = _bitmap
        height, width = bitmap.shape

        boxes = []
        scores = []

        contours, _ = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours[: self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)

            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue

            box = np.array(box)
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores
    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""

        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(outs) == 3:
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores
    def unclip(self, box, unclip_ratio):
        """unclip"""
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        """get mini boxes"""
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])
    def box_score_fast(self, bitmap, _box):
        """box_score_fast: use bbox mean score as the mean score"""
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        """box_score_slow: use polygon mean score as the mean score"""
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin

        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
    def apply(self, pred, img_shape):
        """apply"""
        pred = pred[0][0, :, :]
        segmentation = pred > self.thresh

        src_h, src_w, ratio_h, ratio_w = img_shape
        if self.dilation_kernel is not None:
            mask = cv2.dilate(
                np.array(segmentation).astype(np.uint8),
                self.dilation_kernel,
            )
        else:
            mask = segmentation
        if self.box_type == "poly":
            boxes, scores = self.polygons_from_bitmap(pred, mask, src_w, src_h)
        elif self.box_type == "quad":
            boxes, scores = self.boxes_from_bitmap(pred, mask, src_w, src_h)
        else:
            raise ValueError("box_type can only be one of ['quad', 'poly']")

        return {"dt_polys": boxes, "dt_scores": scores}
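
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): running DBPostProcess on
# a synthetic probability map. The rectangle drawn below stands in for a text
# region predicted by a DB model; `pred` mimics the expected (1, 1, H, W) layout
# and `_demo_db_postprocess` is a hypothetical helper.
def _demo_db_postprocess():
    prob_map = np.zeros((160, 320), dtype=np.float32)
    prob_map[40:80, 60:260] = 0.9  # one high-confidence "text" region
    pred = prob_map[np.newaxis, np.newaxis, :, :]
    # img_shape carries the source size and the resize ratios from DetResizeForTest
    img_shape = np.array([160, 320, 1.0, 1.0])
    post = DBPostProcess(thresh=0.3, box_thresh=0.6, unclip_ratio=2.0, box_type="quad")
    result = post.apply(pred, img_shape)
    print(len(result["dt_polys"]), result["dt_scores"])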

class CropByPolys(BaseComponent):
    """Crop Image by Polys"""

    INPUT_KEYS = ["input_path", "dt_polys"]
    OUTPUT_KEYS = ["img"]
    DEAULT_INPUTS = {"input_path": "input_path", "dt_polys": "dt_polys"}
    DEAULT_OUTPUTS = {"img": "img"}

    def __init__(self, det_box_type="quad"):
        super().__init__()
        self.det_box_type = det_box_type
        self._reader = ImageReader(backend="opencv")

    def apply(self, input_path, dt_polys):
        """apply"""
        img = self._reader.read(input_path)

        if self.det_box_type == "quad":
            dt_boxes = np.array(dt_polys)
            output_list = []
            for bno in range(len(dt_boxes)):
                tmp_box = copy.deepcopy(dt_boxes[bno])
                img_crop = self.get_minarea_rect_crop(img, tmp_box)
                output_list.append(
                    {
                        "img": img_crop,
                        "img_size": [img_crop.shape[1], img_crop.shape[0]],
                    }
                )
        elif self.det_box_type == "poly":
            output_list = []
            dt_boxes = dt_polys
            for bno in range(len(dt_boxes)):
                tmp_box = copy.deepcopy(dt_boxes[bno])
                img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
                output_list.append(
                    {
                        "img": img_crop,
                        "img_size": [img_crop.shape[1], img_crop.shape[0]],
                    }
                )
        else:
            raise NotImplementedError

        return output_list

    def get_minarea_rect_crop(self, img, points):
        """get_minarea_rect_crop"""
        bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_a, index_b, index_c, index_d = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_a = 0
            index_d = 1
        else:
            index_a = 1
            index_d = 0
        if points[3][1] > points[2][1]:
            index_b = 2
            index_c = 3
        else:
            index_b = 3
            index_c = 2

        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
        crop_img = self.get_rotate_crop_image(img, np.array(box))
        return crop_img

    def get_rotate_crop_image(self, img, points):
        """
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        """
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3]),
            )
        )
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2]),
            )
        )
        pts_std = np.float32(
            [
                [0, 0],
                [img_crop_width, 0],
                [img_crop_width, img_crop_height],
                [0, img_crop_height],
            ]
        )
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M,
            (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC,
        )
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img
    def reorder_poly_edge(self, points):
        """Get the respective points composing head edge, tail edge, top
        sideline and bottom sideline.

        Args:
            points (ndarray): The points composing a text polygon.

        Returns:
            head_edge (ndarray): The two points composing the head edge of text
                polygon.
            tail_edge (ndarray): The two points composing the tail edge of text
                polygon.
            top_sideline (ndarray): The points composing top curved sideline of
                text polygon.
            bot_sideline (ndarray): The points composing bottom curved sideline
                of text polygon.
        """
        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2

        orientation_thr = 2.0  # an empirical hyperparameter
        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
        head_edge, tail_edge = points[head_inds], points[tail_inds]

        pad_points = np.vstack([points, points])
        if tail_inds[1] < 1:
            tail_inds[1] = len(points)
        sideline1 = pad_points[head_inds[1] : tail_inds[1]]
        sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
        return head_edge, tail_edge, sideline1, sideline2

    def vector_slope(self, vec):
        assert len(vec) == 2
        return abs(vec[1] / (vec[0] + 1e-8))

    def find_head_tail(self, points, orientation_thr):
        """Find the head edge and tail edge of a text polygon.

        Args:
            points (ndarray): The points composing a text polygon.
            orientation_thr (float): The threshold for distinguishing between
                head edge and tail edge among the horizontal and vertical edges
                of a quadrangle.

        Returns:
            head_inds (list): The indexes of two points composing head edge.
            tail_inds (list): The indexes of two points composing tail edge.
        """
        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2
        assert isinstance(orientation_thr, float)

        if len(points) > 4:
            pad_points = np.vstack([points, points[0]])
            edge_vec = pad_points[1:] - pad_points[:-1]

            theta_sum = []
            adjacent_vec_theta = []
            for i, edge_vec1 in enumerate(edge_vec):
                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
                adjacent_edge_vec = edge_vec[adjacent_ind]
                temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
                temp_adjacent_theta = self.vector_angle(
                    adjacent_edge_vec[0], adjacent_edge_vec[1]
                )
                theta_sum.append(temp_theta_sum)
                adjacent_vec_theta.append(temp_adjacent_theta)
            theta_sum_score = np.array(theta_sum) / np.pi
            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
            poly_center = np.mean(points, axis=0)
            edge_dist = np.maximum(
                norm(pad_points[1:] - poly_center, axis=-1),
                norm(pad_points[:-1] - poly_center, axis=-1),
            )
            dist_score = edge_dist / np.max(edge_dist)
            position_score = np.zeros(len(edge_vec))
            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
            score += 0.35 * dist_score
            if len(points) % 2 == 0:
                position_score[(len(score) // 2 - 1)] += 1
                position_score[-1] += 1
            score += 0.1 * position_score
            pad_score = np.concatenate([score, score])
            score_matrix = np.zeros((len(score), len(score) - 3))
            x = np.arange(len(score) - 3) / float(len(score) - 4)
            gaussian = (
                1.0
                / (np.sqrt(2.0 * np.pi) * 0.5)
                * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
            )
            gaussian = gaussian / np.max(gaussian)
            for i in range(len(score)):
                score_matrix[i, :] = (
                    score[i]
                    + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
                )

            head_start, tail_increment = np.unravel_index(
                score_matrix.argmax(), score_matrix.shape
            )
            tail_start = (head_start + tail_increment + 2) % len(points)
            head_end = (head_start + 1) % len(points)
            tail_end = (tail_start + 1) % len(points)

            if head_end > tail_end:
                head_start, tail_start = tail_start, head_start
                head_end, tail_end = tail_end, head_end
            head_inds = [head_start, head_end]
            tail_inds = [tail_start, tail_end]
        else:
            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
                points[3] - points[2]
            ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
                points[0] - points[3]
            ):
                horizontal_edge_inds = [[0, 1], [2, 3]]
                vertical_edge_inds = [[3, 0], [1, 2]]
            else:
                horizontal_edge_inds = [[3, 0], [1, 2]]
                vertical_edge_inds = [[0, 1], [2, 3]]

            vertical_len_sum = norm(
                points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
            ) + norm(
                points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
            )
            horizontal_len_sum = norm(
                points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
            ) + norm(
                points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
            )

            if vertical_len_sum > horizontal_len_sum * orientation_thr:
                head_inds = horizontal_edge_inds[0]
                tail_inds = horizontal_edge_inds[1]
            else:
                head_inds = vertical_edge_inds[0]
                tail_inds = vertical_edge_inds[1]

        return head_inds, tail_inds

    def vector_angle(self, vec1, vec2):
        if vec1.ndim > 1:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
        if vec2.ndim > 1:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
    def get_minarea_rect(self, img, points):
        bounding_box = cv2.minAreaRect(points)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_a, index_b, index_c, index_d = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_a = 0
            index_d = 1
        else:
            index_a = 1
            index_d = 0
        if points[3][1] > points[2][1]:
            index_b = 2
            index_c = 3
        else:
            index_b = 3
            index_c = 2

        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
        crop_img = self.get_rotate_crop_image(img, np.array(box))
        return crop_img, box

    def sample_points_on_bbox_bp(self, line, n=50):
        """Resample n points on a line.

        Args:
            line (ndarray): The points composing a line.
            n (int): The resampled points number.

        Returns:
            resampled_line (ndarray): The points composing the resampled line.
        """
        # sanity-check the input arguments
        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 0

        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
        total_length = sum(length_list)
        length_cumsum = np.cumsum([0.0] + length_list)
        delta_length = total_length / (float(n) + 1e-8)
        current_edge_ind = 0

        resampled_line = [line[0]]

        for i in range(1, n):
            current_line_len = i * delta_length
            while (
                current_edge_ind + 1 < len(length_cumsum)
                and current_line_len >= length_cumsum[current_edge_ind + 1]
            ):
                current_edge_ind += 1

            current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]

            if current_edge_ind >= len(length_list):
                break
            end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
            current_point = (
                line[current_edge_ind]
                + (line[current_edge_ind + 1] - line[current_edge_ind])
                * end_shift_ratio
            )
            resampled_line.append(current_point)
        resampled_line.append(line[-1])
        resampled_line = np.array(resampled_line)

        return resampled_line
    def sample_points_on_bbox(self, line, n=50):
        """Resample n points on a line.

        Args:
            line (ndarray): The points composing a line.
            n (int): The resampled points number.

        Returns:
            resampled_line (ndarray): The points composing the resampled line.
        """
        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 0

        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
        total_length = sum(length_list)
        mean_length = total_length / (len(length_list) + 1e-8)
        group = [[0]]
        for i in range(len(length_list)):
            point_id = i + 1
            if length_list[i] < 0.9 * mean_length:
                for g in group:
                    if i in g:
                        g.append(point_id)
                        break
            else:
                g = [point_id]
                group.append(g)
        top_tail_len = norm(line[0] - line[-1])
        if top_tail_len < 0.9 * mean_length:
            group[0].extend(g)
            group.remove(g)
        mean_positions = []
        for indices in group:
            x_sum = 0
            y_sum = 0
            for index in indices:
                x, y = line[index]
                x_sum += x
                y_sum += y
            num_points = len(indices)
            mean_x = x_sum / num_points
            mean_y = y_sum / num_points
            mean_positions.append((mean_x, mean_y))
        resampled_line = np.array(mean_positions)

        return resampled_line
    def get_poly_rect_crop(self, img, points):
        """
        Rectify and crop irregular or curved text regions using the polygon.
        args:
            img: image as an ndarray
            points: polygon vertices, an ndarray with shape N*2
        return:
            the rectified crop as an ndarray
        """
        points = np.array(points).astype(np.int32).reshape(-1, 2)
        temp_crop_img, temp_box = self.get_minarea_rect(img, points)

        # compute the IoU between the minimum-area rectangle and the polygon
        def get_union(pD, pG):
            return Polygon(pD).union(Polygon(pG)).area

        def get_intersection_over_union(pD, pG):
            return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)

        def get_intersection(pD, pG):
            return Polygon(pD).intersection(Polygon(pG)).area

        cal_IoU = get_intersection_over_union(points, temp_box)

        if cal_IoU >= 0.7:
            points = self.sample_points_on_bbox_bp(points, 31)
            return temp_crop_img

        points_sample = self.sample_points_on_bbox(points)
        points_sample = points_sample.astype(np.int32)
        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)

        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)

        sideline_mean_shift = np.mean(resample_top_line, axis=0) - np.mean(
            resample_bot_line, axis=0
        )
        if sideline_mean_shift[1] > 0:
            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
        rectifier = AutoRectifier()
        new_points = np.concatenate([resample_top_line, resample_bot_line])
        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())

        if len(img.shape) == 2:
            img = np.stack((img,) * 3, axis=-1)
        img_crop, image = rectifier.run(img, new_points_list, mode="homography")
        return img_crop[0]
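
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): cropping detected quads
# from an image file. "sample.png", the box coordinates, and the
# `_demo_crop_by_polys` helper are hypothetical; in the pipeline `dt_polys`
# normally comes from DBPostProcess.
def _demo_crop_by_polys():
    cropper = CropByPolys(det_box_type="quad")
    dt_polys = [[[60, 40], [260, 40], [260, 80], [60, 80]]]  # one axis-aligned quad
    crops = cropper.apply("sample.png", dt_polys)
    for crop in crops:
        # each entry holds the rectified crop and its (width, height)
        print(crop["img_size"])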

class SortBoxes(BaseComponent):

    YIELD_BATCH = False
    INPUT_KEYS = ["dt_polys"]
    OUTPUT_KEYS = ["dt_polys"]
    DEAULT_INPUTS = {"dt_polys": "dt_polys"}
    DEAULT_OUTPUTS = {"dt_polys": "dt_polys"}

    def apply(self, dt_polys):
        """
        Sort text boxes in order from top to bottom, left to right
        args:
            dt_boxes(array): detected text boxes with shape [4, 2]
        return:
            sorted boxes(array) with shape [4, 2]
        """
        dt_boxes = np.array(dt_polys)
        num_boxes = dt_boxes.shape[0]
        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        _boxes = list(sorted_boxes)

        for i in range(num_boxes - 1):
            for j in range(i, -1, -1):
                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
                    _boxes[j + 1][0][0] < _boxes[j][0][0]
                ):
                    tmp = _boxes[j]
                    _boxes[j] = _boxes[j + 1]
                    _boxes[j + 1] = tmp
                else:
                    break
        return {"dt_polys": [box.tolist() for box in _boxes]}
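
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): SortBoxes orders boxes
# top-to-bottom, then left-to-right for boxes on (roughly) the same line. The
# quads below and the `_demo_sort_boxes` helper are hypothetical.
def _demo_sort_boxes():
    dt_polys = [
        [[200, 12], [300, 12], [300, 40], [200, 40]],  # right box, same row
        [[10, 10], [110, 10], [110, 40], [10, 40]],  # left box, same row
        [[10, 100], [110, 100], [110, 130], [10, 130]],  # box on the next row
    ]
    sorted_polys = SortBoxes().apply(dt_polys)["dt_polys"]
    print([poly[0] for poly in sorted_polys])  # first corner of each sorted box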