# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import copy
import math

import cv2
import numpy as np
import pyclipper
from numpy.linalg import norm
from PIL import Image
from shapely.geometry import Polygon

from ....utils import logging
from ...base.predictor import BaseTransform
from ...base.predictor.io.readers import ImageReader
from ...base.predictor.io.writers import ImageWriter
from .keys import TextDetKeys as K
from .utils import AutoRectifier

__all__ = [
    "DetResizeForTest",
    "NormalizeImage",
    "DBPostProcess",
    "SaveTextDetResults",
    "PrintResult",
]


class DetResizeForTest(BaseTransform):
    """DetResizeForTest"""

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if "image_shape" in kwargs:
            self.image_shape = kwargs["image_shape"]
            self.resize_type = 1
            if "keep_ratio" in kwargs:
                self.keep_ratio = kwargs["keep_ratio"]
        elif "limit_side_len" in kwargs:
            self.limit_side_len = kwargs["limit_side_len"]
            self.limit_type = kwargs.get("limit_type", "min")
        elif "resize_long" in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get("resize_long", 960)
        else:
            self.limit_side_len = 736
            self.limit_type = "min"

    def apply(self, data):
        """apply"""
        img = data[K.IMAGE]
        src_h, src_w, _ = img.shape
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)

        data[K.IMAGE] = img
        data[K.SHAPE] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.IMAGE]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return [K.IMAGE, K.SHAPE]

    def image_padding(self, im, value=0):
        """Pad the image so that both sides are at least 32 pixels."""
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        """Resize the image to a fixed target shape, optionally keeping the aspect ratio."""
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            N = math.ceil(resize_w / 32)
            resize_w = N * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        Resize the image to a size that is a multiple of 32, as required by the network.
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # Limit the longest or shortest side, depending on `limit_type`.
        if self.limit_type == "max":
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "min":
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.0
        elif self.limit_type == "resize_long":
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise ValueError("Unsupported limit_type: {}".format(self.limit_type))

        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            logging.info(
                "Resize failed: img.shape={}, resize_w={}, resize_h={}".format(
                    img.shape, resize_w, resize_h
                )
            )
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        """Resize the long side to `resize_long`, then round both sides up to a multiple of 128."""
        h, w, _ = img.shape
        resize_w = w
        resize_h = h
        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]
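
    # Worked example for the default path (illustrative numbers only, not taken from a
    # real pipeline run): with limit_side_len=736 and limit_type="min",
    # resize_image_type0 scales a 450 x 800 image as follows:
    #   ratio    = 736 / 450 ≈ 1.6356          (the shorter side is brought up to 736)
    #   resize_h = int(450 * ratio) = 736  -> round(736 / 32) * 32  = 736
    #   resize_w = int(800 * ratio) = 1308 -> round(1308 / 32) * 32 = 41 * 32 = 1312
    # The returned ratios, ratio_h ≈ 1.636 and ratio_w = 1.64, are stored in K.SHAPE
    # together with the source size so that later stages can map results back to the
    # original image coordinates.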


class NormalizeImage(BaseTransform):
    """Normalize the image, e.g. subtract the mean and divide by the std."""

    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
        super().__init__()
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]
        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

    def apply(self, data):
        """apply"""
        img = data[K.IMAGE]
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        data[K.IMAGE] = (img.astype("float32") * self.scale - self.mean) / self.std
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.IMAGE]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return [K.IMAGE]
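
    # Quick sanity check of the normalization above (illustrative numbers only): with
    # the default scale of 1/255 and the ImageNet statistics mean=0.485, std=0.229 for
    # the first channel, a raw pixel value of 128 maps to
    #   (128 / 255 - 0.485) / 0.229 ≈ (0.502 - 0.485) / 0.229 ≈ 0.074,
    # i.e. the output is roughly zero-mean and unit-variance, which is what the
    # detection backbone expects.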


class DBPostProcess(BaseTransform):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(
        self,
        thresh=0.3,
        box_thresh=0.7,
        max_candidates=1000,
        unclip_ratio=2.0,
        use_dilation=False,
        score_mode="fast",
        box_type="quad",
        **kwargs
    ):
        super().__init__()
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow",
            "fast",
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
        bitmap = _bitmap
        height, width = bitmap.shape
        boxes = []
        scores = []
        contours, _ = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours[: self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue
            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)
            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue
            box = np.array(box)
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
        bitmap = _bitmap
        height, width = bitmap.shape
        outs = cv2.findContours(
            (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(outs) == 3:
            # OpenCV 3.x returns (image, contours, hierarchy).
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            # OpenCV 4.x returns (contours, hierarchy).
            contours, _ = outs[0], outs[1]
        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue
            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height
            )
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box, unclip_ratio):
        """unclip"""
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded
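
    # Note on the unclip distance (illustrative numbers, not from a real image): DB
    # shrinks text regions during training, so the offset above grows each detected
    # kernel back by distance = area * unclip_ratio / perimeter. For a 100 x 20 pixel
    # rectangle with the default unclip_ratio of 2.0 this gives
    #   distance = (100 * 20) * 2.0 / (2 * (100 + 20)) ≈ 16.7 pixels,
    # so every edge is pushed outwards by roughly 17 pixels before the minimum-area
    # rectangle is recomputed in boxes_from_bitmap / polygons_from_bitmap.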

    def get_mini_boxes(self, contour):
        """get mini boxes"""
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])
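
    # The reordering above is worth spelling out: cv2.boxPoints returns the four
    # corners of the minimum-area rectangle in a rotation-dependent order, so the
    # points are first sorted by x and then disambiguated by y, giving a box that
    # reads [top-left, top-right, bottom-right, bottom-left]. The second return
    # value, min(bounding_box[1]), is the shorter side of the rectangle and is what
    # callers compare against self.min_size to drop degenerate detections.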

    def box_score_fast(self, bitmap, _box):
        """box_score_fast: use the mean probability inside the bounding box as the box score"""
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        """box_score_slow: use the mean probability inside the polygon as the box score"""
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin
        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]

    def apply(self, data):
        """apply"""
        pred = data[K.PROB_MAP]
        shape_list = [data[K.SHAPE]]
        pred = pred[0][:, 0, :, :]
        segmentation = pred > self.thresh

        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel,
                )
            else:
                mask = segmentation[batch_index]
            if self.box_type == "poly":
                boxes, scores = self.polygons_from_bitmap(
                    pred[batch_index], mask, src_w, src_h
                )
            elif self.box_type == "quad":
                boxes, scores = self.boxes_from_bitmap(
                    pred[batch_index], mask, src_w, src_h
                )
            else:
                raise ValueError("box_type can only be one of ['quad', 'poly']")
            data[K.DT_POLYS] = boxes
            data[K.DT_SCORES] = scores
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.PROB_MAP, K.SHAPE]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return [K.DT_POLYS, K.DT_SCORES]
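
# A minimal sketch of how these transforms are typically chained. The surrounding
# predictor normally does this wiring; the dict keys follow TextDetKeys, but the
# concrete sequence below is an assumption for illustration only:
#
#   data = {K.IMAGE: bgr_image, K.ORI_IM: bgr_image.copy(), K.IM_PATH: "sample.jpg"}
#   for tf in (
#       DetResizeForTest(limit_side_len=960, limit_type="max"),
#       NormalizeImage(order="hwc"),
#       ...,  # model inference fills K.PROB_MAP
#       DBPostProcess(thresh=0.3, box_thresh=0.6, unclip_ratio=1.5),
#   ):
#       data = tf.apply(data)
#   # data[K.DT_POLYS] / data[K.DT_SCORES] now hold the detected boxes and scores.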


class CropByPolys(BaseTransform):
    """Crop Image by Polys"""

    def __init__(self, det_box_type="quad"):
        super().__init__()
        self.det_box_type = det_box_type

    def apply(self, data):
        """apply"""
        ori_im = data[K.ORI_IM]
        if self.det_box_type == "quad":
            dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
            dt_boxes = np.array(dt_boxes)
            img_crop_list = []
            for bno in range(len(dt_boxes)):
                tmp_box = copy.deepcopy(dt_boxes[bno])
                img_crop = self.get_minarea_rect_crop(ori_im, tmp_box)
                img_crop_list.append(img_crop)
        elif self.det_box_type == "poly":
            img_crop_list = []
            dt_boxes = data[K.DT_POLYS]
            for bno in range(len(dt_boxes)):
                tmp_box = copy.deepcopy(dt_boxes[bno])
                img_crop = self.get_poly_rect_crop(ori_im.copy(), tmp_box)
                img_crop_list.append(img_crop)
        else:
            raise NotImplementedError(
                "det_box_type can only be one of ['quad', 'poly']"
            )
        data[K.SUB_IMGS] = img_crop_list
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.ORI_IM, K.DT_POLYS]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return [K.SUB_IMGS]

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes in order from top to bottom, left to right.
        args:
            dt_boxes(array): detected text boxes with shape [N, 4, 2]
        return:
            sorted boxes(list) of boxes, each with shape [4, 2]
        """
        dt_boxes = np.array(dt_boxes)
        num_boxes = dt_boxes.shape[0]
        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        _boxes = list(sorted_boxes)

        for i in range(num_boxes - 1):
            for j in range(i, -1, -1):
                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
                    _boxes[j + 1][0][0] < _boxes[j][0][0]
                ):
                    tmp = _boxes[j]
                    _boxes[j] = _boxes[j + 1]
                    _boxes[j + 1] = tmp
                else:
                    break
        return _boxes
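
    # Example of the intended ordering (coordinates are made up for illustration):
    # given top-left corners (100, 52), (10, 48) and (20, 200), the initial sort by
    # (y, x) places the (10, 48) box first; the follow-up pass then keeps boxes whose
    # y values differ by less than 10 pixels ordered by x, so the final reading order
    # is (10, 48), (100, 52), (20, 200) — one text line left to right, then the next
    # line below it.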

    def get_minarea_rect_crop(self, img, points):
        """get_minarea_rect_crop"""
        bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_a, index_b, index_c, index_d = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_a = 0
            index_d = 1
        else:
            index_a = 1
            index_d = 0
        if points[3][1] > points[2][1]:
            index_b = 2
            index_c = 3
        else:
            index_b = 3
            index_c = 2

        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
        crop_img = self.get_rotate_crop_image(img, np.array(box))
        return crop_img

    def get_rotate_crop_image(self, img, points):
        """Crop the quadrilateral defined by `points` (4 x 2, clockwise from the
        top-left corner) and warp it to an axis-aligned rectangle via a perspective
        transform; near-vertical crops are rotated so the text reads horizontally."""
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3]),
            )
        )
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2]),
            )
        )
        pts_std = np.float32(
            [
                [0, 0],
                [img_crop_width, 0],
                [img_crop_width, img_crop_height],
                [0, img_crop_height],
            ]
        )
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M,
            (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC,
        )
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img
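
    # How the warp above behaves, on made-up numbers: for a slanted quad whose top and
    # bottom edges measure ~120 px and whose left and right edges measure ~32 px, the
    # target rectangle becomes 120 x 32 and cv2.getPerspectiveTransform maps the four
    # source corners onto its corners, removing rotation and perspective in one step.
    # The final height/width >= 1.5 check is a heuristic for vertical text: such crops
    # are rotated 90 degrees so the downstream recognizer sees roughly horizontal lines.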

    def reorder_poly_edge(self, points):
        """Get the respective points composing head edge, tail edge, top
        sideline and bottom sideline.

        Args:
            points (ndarray): The points composing a text polygon.

        Returns:
            head_edge (ndarray): The two points composing the head edge of text
                polygon.
            tail_edge (ndarray): The two points composing the tail edge of text
                polygon.
            top_sideline (ndarray): The points composing top curved sideline of
                text polygon.
            bot_sideline (ndarray): The points composing bottom curved sideline
                of text polygon.
        """
        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2

        orientation_thr = 2.0  # an empirical hyperparameter
        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
        head_edge, tail_edge = points[head_inds], points[tail_inds]

        pad_points = np.vstack([points, points])
        if tail_inds[1] < 1:
            tail_inds[1] = len(points)
        sideline1 = pad_points[head_inds[1] : tail_inds[1]]
        sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
        return head_edge, tail_edge, sideline1, sideline2

    def vector_slope(self, vec):
        """Return the absolute slope |dy / dx| of a 2D vector."""
        assert len(vec) == 2
        return abs(vec[1] / (vec[0] + 1e-8))

    def find_head_tail(self, points, orientation_thr):
        """Find the head edge and tail edge of a text polygon.

        Args:
            points (ndarray): The points composing a text polygon.
            orientation_thr (float): The threshold for distinguishing between
                head edge and tail edge among the horizontal and vertical edges
                of a quadrangle.

        Returns:
            head_inds (list): The indexes of two points composing head edge.
            tail_inds (list): The indexes of two points composing tail edge.
        """
        assert points.ndim == 2
        assert points.shape[0] >= 4
        assert points.shape[1] == 2
        assert isinstance(orientation_thr, float)

        if len(points) > 4:
            pad_points = np.vstack([points, points[0]])
            edge_vec = pad_points[1:] - pad_points[:-1]

            theta_sum = []
            adjacent_vec_theta = []
            for i, edge_vec1 in enumerate(edge_vec):
                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
                adjacent_edge_vec = edge_vec[adjacent_ind]
                temp_theta_sum = np.sum(
                    self.vector_angle(edge_vec1, adjacent_edge_vec)
                )
                temp_adjacent_theta = self.vector_angle(
                    adjacent_edge_vec[0], adjacent_edge_vec[1]
                )
                theta_sum.append(temp_theta_sum)
                adjacent_vec_theta.append(temp_adjacent_theta)
            theta_sum_score = np.array(theta_sum) / np.pi
            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
            poly_center = np.mean(points, axis=0)
            edge_dist = np.maximum(
                norm(pad_points[1:] - poly_center, axis=-1),
                norm(pad_points[:-1] - poly_center, axis=-1),
            )
            dist_score = edge_dist / np.max(edge_dist)
            position_score = np.zeros(len(edge_vec))
            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
            score += 0.35 * dist_score
            if len(points) % 2 == 0:
                position_score[(len(score) // 2 - 1)] += 1
                position_score[-1] += 1
            score += 0.1 * position_score
            pad_score = np.concatenate([score, score])
            score_matrix = np.zeros((len(score), len(score) - 3))
            x = np.arange(len(score) - 3) / float(len(score) - 4)
            gaussian = (
                1.0
                / (np.sqrt(2.0 * np.pi) * 0.5)
                * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
            )
            gaussian = gaussian / np.max(gaussian)
            for i in range(len(score)):
                score_matrix[i, :] = (
                    score[i]
                    + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
                )

            head_start, tail_increment = np.unravel_index(
                score_matrix.argmax(), score_matrix.shape
            )
            tail_start = (head_start + tail_increment + 2) % len(points)
            head_end = (head_start + 1) % len(points)
            tail_end = (tail_start + 1) % len(points)

            if head_end > tail_end:
                head_start, tail_start = tail_start, head_start
                head_end, tail_end = tail_end, head_end
            head_inds = [head_start, head_end]
            tail_inds = [tail_start, tail_end]
        else:
            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
                points[3] - points[2]
            ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
                points[0] - points[3]
            ):
                horizontal_edge_inds = [[0, 1], [2, 3]]
                vertical_edge_inds = [[3, 0], [1, 2]]
            else:
                horizontal_edge_inds = [[3, 0], [1, 2]]
                vertical_edge_inds = [[0, 1], [2, 3]]

            vertical_len_sum = norm(
                points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
            ) + norm(
                points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
            )
            horizontal_len_sum = norm(
                points[horizontal_edge_inds[0][0]]
                - points[horizontal_edge_inds[0][1]]
            ) + norm(
                points[horizontal_edge_inds[1][0]]
                - points[horizontal_edge_inds[1][1]]
            )

            if vertical_len_sum > horizontal_len_sum * orientation_thr:
                head_inds = horizontal_edge_inds[0]
                tail_inds = horizontal_edge_inds[1]
            else:
                head_inds = vertical_edge_inds[0]
                tail_inds = vertical_edge_inds[1]
        return head_inds, tail_inds

    def vector_angle(self, vec1, vec2):
        """Return the angle (in radians) between two vectors or batches of vectors."""
        if vec1.ndim > 1:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
        if vec2.ndim > 1:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
        else:
            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
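
    # For reference, vector_angle is just arccos of the dot product of unit vectors;
    # e.g. (illustrative values) vec1 = (1, 0) and vec2 = (1, 1) give
    # arccos(1 / sqrt(2)) = pi / 4 ≈ 0.785 rad. The 1e-8 terms and the clip to
    # [-1, 1] only guard against division by zero and floating-point overshoot.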

    def get_minarea_rect(self, img, points):
        bounding_box = cv2.minAreaRect(points)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_a, index_b, index_c, index_d = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_a = 0
            index_d = 1
        else:
            index_a = 1
            index_d = 0
        if points[3][1] > points[2][1]:
            index_b = 2
            index_c = 3
        else:
            index_b = 3
            index_c = 2

        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
        crop_img = self.get_rotate_crop_image(img, np.array(box))
        return crop_img, box

    def sample_points_on_bbox_bp(self, line, n=50):
        """Resample n points on a line.

        Args:
            line (ndarray): The points composing a line.
            n (int): The resampled points number.

        Returns:
            resampled_line (ndarray): The points composing the resampled line.
        """
        # Validate the input arguments.
        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 0

        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
        total_length = sum(length_list)
        length_cumsum = np.cumsum([0.0] + length_list)
        delta_length = total_length / (float(n) + 1e-8)
        current_edge_ind = 0
        resampled_line = [line[0]]

        for i in range(1, n):
            current_line_len = i * delta_length
            while (
                current_edge_ind + 1 < len(length_cumsum)
                and current_line_len >= length_cumsum[current_edge_ind + 1]
            ):
                current_edge_ind += 1
            current_edge_end_shift = (
                current_line_len - length_cumsum[current_edge_ind]
            )
            if current_edge_ind >= len(length_list):
                break
            end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
            current_point = (
                line[current_edge_ind]
                + (line[current_edge_ind + 1] - line[current_edge_ind])
                * end_shift_ratio
            )
            resampled_line.append(current_point)
        resampled_line.append(line[-1])
        resampled_line = np.array(resampled_line)
        return resampled_line
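
    # Arc-length resampling in a nutshell (made-up numbers): for a 3-point polyline
    # with segment lengths [30, 10] and n=4, delta_length = 40 / 4 = 10, so points are
    # emitted at cumulative distances 0, 10, 20, 30, plus the original end point —
    # evenly spaced along the curve regardless of how unevenly the input vertices were
    # placed. This is what makes the top and bottom sidelines comparable before they
    # are passed to AutoRectifier in get_poly_rect_crop below.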

    def sample_points_on_bbox(self, line, n=50):
        """Merge clusters of closely spaced points on a line.

        Consecutive points whose spacing is well below the mean segment length are
        grouped together and replaced by their mean position. The `n` argument is
        kept for interface compatibility but is not used here.

        Args:
            line (ndarray): The points composing a line.
            n (int): Unused.

        Returns:
            resampled_line (ndarray): The points composing the simplified line.
        """
        assert line.ndim == 2
        assert line.shape[0] >= 2
        assert line.shape[1] == 2
        assert isinstance(n, int)
        assert n > 0

        length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
        total_length = sum(length_list)
        mean_length = total_length / (len(length_list) + 1e-8)

        # Group consecutive points that are closer than 0.9x the mean spacing.
        group = [[0]]
        for i in range(len(length_list)):
            point_id = i + 1
            if length_list[i] < 0.9 * mean_length:
                for g in group:
                    if i in g:
                        g.append(point_id)
                        break
            else:
                g = [point_id]
                group.append(g)
        top_tail_len = norm(line[0] - line[-1])
        if top_tail_len < 0.9 * mean_length:
            group[0].extend(g)
            group.remove(g)

        # Replace each group with the mean position of its members.
        mean_positions = []
        for indices in group:
            x_sum = 0
            y_sum = 0
            for index in indices:
                x, y = line[index]
                x_sum += x
                y_sum += y
            num_points = len(indices)
            mean_x = x_sum / num_points
            mean_y = y_sum / num_points
            mean_positions.append((mean_x, mean_y))
        resampled_line = np.array(mean_positions)
        return resampled_line

    def get_poly_rect_crop(self, img, points):
        """
        Rectify and crop irregular or curved text regions using the detected polygon.

        args:
            img: image as an ndarray
            points: polygon coordinates as an ndarray of shape N*2
        return:
            the rectified crop as an ndarray
        """
        points = np.array(points).astype(np.int32).reshape(-1, 2)
        temp_crop_img, temp_box = self.get_minarea_rect(img, points)

        # Compute the IoU between the minimum-area rectangle and the polygon.
        def get_union(pD, pG):
            return Polygon(pD).union(Polygon(pG)).area

        def get_intersection_over_union(pD, pG):
            return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)

        def get_intersection(pD, pG):
            return Polygon(pD).intersection(Polygon(pG)).area

        cal_IoU = get_intersection_over_union(points, temp_box)
        if cal_IoU >= 0.7:
            # The polygon is close enough to its minimum-area rectangle, so the
            # plain perspective crop is sufficient.
            points = self.sample_points_on_bbox_bp(points, 31)
            return temp_crop_img

        points_sample = self.sample_points_on_bbox(points)
        points_sample = points_sample.astype(np.int32)
        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)

        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)

        sideline_mean_shift = np.mean(resample_top_line, axis=0) - np.mean(
            resample_bot_line, axis=0
        )
        if sideline_mean_shift[1] > 0:
            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line

        rectifier = AutoRectifier()
        new_points = np.concatenate([resample_top_line, resample_bot_line])
        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
        if len(img.shape) == 2:
            img = np.stack((img,) * 3, axis=-1)
        img_crop, image = rectifier.run(img, new_points_list, mode="homography")
        return img_crop[0]
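
    # Summary of the curved-text path above: when the polygon is nearly rectangular
    # (IoU with its minimum-area rectangle >= 0.7, the empirical cut-off used here),
    # a single perspective warp suffices. Otherwise the polygon is split into a top
    # and a bottom sideline, each sideline is resampled to 15 evenly spaced points,
    # and AutoRectifier (imported from .utils) is run in "homography" mode on the
    # combined point list to straighten the region before recognition.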


class SaveTextDetResults(BaseTransform):
    """Save Text Det Results"""

    def __init__(self, save_dir, task="quad"):
        super().__init__()
        self.save_dir = save_dir
        self.task = task
        # Use the OpenCV backend, which handles the numpy arrays produced above.
        self._writer = ImageWriter(backend="opencv")

    def apply(self, data):
        """apply"""
        if self.save_dir is None:
            logging.warning(
                "The `save_dir` has been set to None, so the text detection result won't be saved."
            )
            return data
        fn = os.path.basename(data["input_path"])
        save_path = os.path.join(self.save_dir, fn)
        bbox_res = data[K.DT_POLYS]
        if self.task == "quad":
            vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
        else:
            vis_img = self.draw_polyline(data[K.IM_PATH], bbox_res)
        self._writer.write(save_path, vis_img)
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.IM_PATH, K.DT_POLYS, K.DT_SCORES]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return []

    def draw_rectangle(self, img_path, boxes):
        """draw rectangle"""
        boxes = np.array(boxes)
        img = cv2.imread(img_path)
        img_show = img.copy()
        for box in boxes.astype(int):
            box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
            cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
        return img_show

    def draw_polyline(self, img_path, boxes):
        """draw polyline"""
        img = cv2.imread(img_path)
        img_show = img.copy()
        for box in boxes:
            box = np.array(box).astype(int)
            box = np.reshape(box, [-1, 1, 2]).astype(np.int64)
            cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
        return img_show


class PrintResult(BaseTransform):
    """Print Result Transform"""

    def apply(self, data):
        """apply"""
        logging.info("The prediction result is:")
        logging.info(data[K.DT_POLYS])
        return data

    @classmethod
    def get_input_keys(cls):
        """get input keys"""
        return [K.DT_POLYS]

    @classmethod
    def get_output_keys(cls):
        """get output keys"""
        return []