crop_image_regions.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Tuple
  15. import copy
  16. import numpy as np
  17. import cv2
  18. from shapely.geometry import Polygon
  19. from numpy.linalg import norm
  20. from .base_operator import BaseOperator
  21. from ....utils.io import ImageReader
  22. from .seal_det_warp import AutoRectifier
  23. class CropByBoxes(BaseOperator):
  24. """Crop Image by Boxes"""
  25. entities = "CropByBoxes"
  26. def __init__(self) -> None:
  27. """Initializes the class."""
  28. super().__init__()
  29. def __call__(self, img: np.ndarray, boxes: list[dict]) -> list[dict]:
  30. """
  31. Process the input image and bounding boxes to produce a list of cropped images
  32. with their corresponding bounding box coordinates and labels.
  33. Args:
  34. img (np.ndarray): The input image as a NumPy array.
  35. boxes (list[dict]): A list of dictionaries, each containing bounding box
  36. information including 'cls_id' (class ID), 'coordinate' (bounding box
  37. coordinates as a list or tuple, left, top, right, bottom),
  38. and optionally 'label' (label text).
  39. Returns:
  40. list[dict]: A list of dictionaries, each containing a cropped image ('img'),
  41. the original bounding box coordinates ('box'), and the label ('label').
  42. """
  43. output_list = []
  44. for bbox_info in boxes:
  45. label_id = bbox_info["cls_id"]
  46. box = bbox_info["coordinate"]
  47. label = bbox_info.get("label", label_id)
  48. xmin, ymin, xmax, ymax = [int(i) for i in box]
  49. img_crop = img[ymin:ymax, xmin:xmax].copy()
  50. output_list.append({"img": img_crop, "box": box, "label": label})
  51. return output_list
  52. class CropByPolys(BaseOperator):
  53. """Crop Image by Polys"""
  54. entities = "CropByPolys"
  55. def __init__(self, det_box_type: str = "quad") -> None:
  56. """
  57. Initializes the operator with a default detection box type.
  58. Args:
  59. det_box_type (str, optional): The type of detection box, quad or poly. Defaults to "quad".
  60. """
  61. super().__init__()
  62. self.det_box_type = det_box_type
  63. def __call__(self, img: np.ndarray, dt_polys: list[list]) -> list[dict]:
  64. """
  65. Call method to crop images based on detection boxes.
  66. Args:
  67. img (nd.ndarray): The input image.
  68. dt_polys (list[list]): List of detection polygons.
  69. Returns:
  70. list[dict]: A list of dictionaries containing cropped images and their sizes.
  71. Raises:
  72. NotImplementedError: If det_box_type is not 'quad' or 'poly'.
  73. """
  74. if self.det_box_type == "quad":
  75. dt_boxes = np.array(dt_polys)
  76. output_list = []
  77. for bno in range(len(dt_boxes)):
  78. tmp_box = copy.deepcopy(dt_boxes[bno])
  79. img_crop = self.get_minarea_rect_crop(img, tmp_box)
  80. output_list.append(img_crop)
  81. elif self.det_box_type == "poly":
  82. output_list = []
  83. dt_boxes = dt_polys
  84. for bno in range(len(dt_boxes)):
  85. tmp_box = copy.deepcopy(dt_boxes[bno])
  86. img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
  87. output_list.append(img_crop)
  88. else:
  89. raise NotImplementedError
  90. return output_list
  91. def get_minarea_rect_crop(self, img: np.ndarray, points: np.ndarray) -> np.ndarray:
  92. """
  93. Get the minimum area rectangle crop from the given image and points.
  94. Args:
  95. img (np.ndarray): The input image.
  96. points (np.ndarray): A list of points defining the shape to be cropped.
  97. Returns:
  98. np.ndarray: The cropped image with the minimum area rectangle.
  99. """
  100. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  101. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  102. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  103. if points[1][1] > points[0][1]:
  104. index_a = 0
  105. index_d = 1
  106. else:
  107. index_a = 1
  108. index_d = 0
  109. if points[3][1] > points[2][1]:
  110. index_b = 2
  111. index_c = 3
  112. else:
  113. index_b = 3
  114. index_c = 2
  115. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  116. crop_img = self.get_rotate_crop_image(img, np.array(box))
  117. return crop_img
  118. def get_rotate_crop_image(self, img: np.ndarray, points: list) -> np.ndarray:
  119. """
  120. Crop and rotate the input image based on the given four points to form a perspective-transformed image.
  121. Args:
  122. img (np.ndarray): The input image array.
  123. points (list): A list of four 2D points defining the crop region in the image.
  124. Returns:
  125. np.ndarray: The transformed image array.
  126. """
  127. assert len(points) == 4, "shape of points must be 4*2"
  128. img_crop_width = int(
  129. max(
  130. np.linalg.norm(points[0] - points[1]),
  131. np.linalg.norm(points[2] - points[3]),
  132. )
  133. )
  134. img_crop_height = int(
  135. max(
  136. np.linalg.norm(points[0] - points[3]),
  137. np.linalg.norm(points[1] - points[2]),
  138. )
  139. )
  140. pts_std = np.float32(
  141. [
  142. [0, 0],
  143. [img_crop_width, 0],
  144. [img_crop_width, img_crop_height],
  145. [0, img_crop_height],
  146. ]
  147. )
  148. M = cv2.getPerspectiveTransform(points, pts_std)
  149. dst_img = cv2.warpPerspective(
  150. img,
  151. M,
  152. (img_crop_width, img_crop_height),
  153. borderMode=cv2.BORDER_REPLICATE,
  154. flags=cv2.INTER_CUBIC,
  155. )
  156. dst_img_height, dst_img_width = dst_img.shape[0:2]
  157. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  158. dst_img = np.rot90(dst_img)
  159. return dst_img
  160. def reorder_poly_edge(
  161. self, points: np.ndarray
  162. ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  163. """Get the respective points composing head edge, tail edge, top
  164. sideline and bottom sideline.
  165. Args:
  166. points (ndarray): The points composing a text polygon.
  167. Returns:
  168. head_edge (ndarray): The two points composing the head edge of text
  169. polygon.
  170. tail_edge (ndarray): The two points composing the tail edge of text
  171. polygon.
  172. top_sideline (ndarray): The points composing top curved sideline of
  173. text polygon.
  174. bot_sideline (ndarray): The points composing bottom curved sideline
  175. of text polygon.
  176. """
  177. assert points.ndim == 2
  178. assert points.shape[0] >= 4
  179. assert points.shape[1] == 2
  180. orientation_thr = 2.0 # 一个经验超参数
  181. head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
  182. head_edge, tail_edge = points[head_inds], points[tail_inds]
  183. pad_points = np.vstack([points, points])
  184. if tail_inds[1] < 1:
  185. tail_inds[1] = len(points)
  186. sideline1 = pad_points[head_inds[1] : tail_inds[1]]
  187. sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
  188. return head_edge, tail_edge, sideline1, sideline2
  189. def vector_slope(self, vec: list) -> float:
  190. """
  191. Calculate the slope of a vector in 2D space.
  192. Args:
  193. vec (list): A list of two elements representing the coordinates of the vector.
  194. Returns:
  195. float: The slope of the vector.
  196. Raises:
  197. AssertionError: If the length of the vector is not equal to 2.
  198. """
  199. assert len(vec) == 2
  200. return abs(vec[1] / (vec[0] + 1e-8))
  201. def find_head_tail(
  202. self, points: np.ndarray, orientation_thr: float
  203. ) -> tuple[list, list]:
  204. """Find the head edge and tail edge of a text polygon.
  205. Args:
  206. points (ndarray): The points composing a text polygon.
  207. orientation_thr (float): The threshold for distinguishing between
  208. head edge and tail edge among the horizontal and vertical edges
  209. of a quadrangle.
  210. Returns:
  211. head_inds (list): The indexes of two points composing head edge.
  212. tail_inds (list): The indexes of two points composing tail edge.
  213. """
  214. assert points.ndim == 2
  215. assert points.shape[0] >= 4
  216. assert points.shape[1] == 2
  217. assert isinstance(orientation_thr, float)
  218. if len(points) > 4:
  219. pad_points = np.vstack([points, points[0]])
  220. edge_vec = pad_points[1:] - pad_points[:-1]
  221. theta_sum = []
  222. adjacent_vec_theta = []
  223. for i, edge_vec1 in enumerate(edge_vec):
  224. adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
  225. adjacent_edge_vec = edge_vec[adjacent_ind]
  226. temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
  227. temp_adjacent_theta = self.vector_angle(
  228. adjacent_edge_vec[0], adjacent_edge_vec[1]
  229. )
  230. theta_sum.append(temp_theta_sum)
  231. adjacent_vec_theta.append(temp_adjacent_theta)
  232. theta_sum_score = np.array(theta_sum) / np.pi
  233. adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
  234. poly_center = np.mean(points, axis=0)
  235. edge_dist = np.maximum(
  236. norm(pad_points[1:] - poly_center, axis=-1),
  237. norm(pad_points[:-1] - poly_center, axis=-1),
  238. )
  239. dist_score = edge_dist / np.max(edge_dist)
  240. position_score = np.zeros(len(edge_vec))
  241. score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
  242. score += 0.35 * dist_score
  243. if len(points) % 2 == 0:
  244. position_score[(len(score) // 2 - 1)] += 1
  245. position_score[-1] += 1
  246. score += 0.1 * position_score
  247. pad_score = np.concatenate([score, score])
  248. score_matrix = np.zeros((len(score), len(score) - 3))
  249. x = np.arange(len(score) - 3) / float(len(score) - 4)
  250. gaussian = (
  251. 1.0
  252. / (np.sqrt(2.0 * np.pi) * 0.5)
  253. * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
  254. )
  255. gaussian = gaussian / np.max(gaussian)
  256. for i in range(len(score)):
  257. score_matrix[i, :] = (
  258. score[i]
  259. + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
  260. )
  261. head_start, tail_increment = np.unravel_index(
  262. score_matrix.argmax(), score_matrix.shape
  263. )
  264. tail_start = (head_start + tail_increment + 2) % len(points)
  265. head_end = (head_start + 1) % len(points)
  266. tail_end = (tail_start + 1) % len(points)
  267. if head_end > tail_end:
  268. head_start, tail_start = tail_start, head_start
  269. head_end, tail_end = tail_end, head_end
  270. head_inds = [head_start, head_end]
  271. tail_inds = [tail_start, tail_end]
  272. else:
  273. if self.vector_slope(points[1] - points[0]) + self.vector_slope(
  274. points[3] - points[2]
  275. ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
  276. points[0] - points[3]
  277. ):
  278. horizontal_edge_inds = [[0, 1], [2, 3]]
  279. vertical_edge_inds = [[3, 0], [1, 2]]
  280. else:
  281. horizontal_edge_inds = [[3, 0], [1, 2]]
  282. vertical_edge_inds = [[0, 1], [2, 3]]
  283. vertical_len_sum = norm(
  284. points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
  285. ) + norm(
  286. points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
  287. )
  288. horizontal_len_sum = norm(
  289. points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
  290. ) + norm(
  291. points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
  292. )
  293. if vertical_len_sum > horizontal_len_sum * orientation_thr:
  294. head_inds = horizontal_edge_inds[0]
  295. tail_inds = horizontal_edge_inds[1]
  296. else:
  297. head_inds = vertical_edge_inds[0]
  298. tail_inds = vertical_edge_inds[1]
  299. return head_inds, tail_inds
  300. def vector_angle(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
  301. """
  302. Calculate the angle between two vectors.
  303. Args:
  304. vec1 (ndarray): The first vector.
  305. vec2 (ndarray): The second vector.
  306. Returns:
  307. float: The angle between the two vectors in radians.
  308. """
  309. if vec1.ndim > 1:
  310. unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
  311. else:
  312. unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
  313. if vec2.ndim > 1:
  314. unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
  315. else:
  316. unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
  317. return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
  318. def get_minarea_rect(
  319. self, img: np.ndarray, points: np.ndarray
  320. ) -> tuple[np.ndarray, list]:
  321. """
  322. Get the minimum area rectangle for the given points and crop the image accordingly.
  323. Args:
  324. img (np.ndarray): The input image.
  325. points (np.ndarray): The points to compute the minimum area rectangle for.
  326. Returns:
  327. tuple[np.ndarray, list]: The cropped image,
  328. and the list of points in the order of the bounding box.
  329. """
  330. bounding_box = cv2.minAreaRect(points)
  331. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  332. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  333. if points[1][1] > points[0][1]:
  334. index_a = 0
  335. index_d = 1
  336. else:
  337. index_a = 1
  338. index_d = 0
  339. if points[3][1] > points[2][1]:
  340. index_b = 2
  341. index_c = 3
  342. else:
  343. index_b = 3
  344. index_c = 2
  345. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  346. crop_img = self.get_rotate_crop_image(img, np.array(box))
  347. return crop_img, box
  348. def sample_points_on_bbox_bp(self, line, n=50):
  349. """Resample n points on a line.
  350. Args:
  351. line (ndarray): The points composing a line.
  352. n (int): The resampled points number.
  353. Returns:
  354. resampled_line (ndarray): The points composing the resampled line.
  355. """
  356. from numpy.linalg import norm
  357. # 断言检查输入参数的有效性
  358. assert line.ndim == 2
  359. assert line.shape[0] >= 2
  360. assert line.shape[1] == 2
  361. assert isinstance(n, int)
  362. assert n > 0
  363. length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
  364. total_length = sum(length_list)
  365. length_cumsum = np.cumsum([0.0] + length_list)
  366. delta_length = total_length / (float(n) + 1e-8)
  367. current_edge_ind = 0
  368. resampled_line = [line[0]]
  369. for i in range(1, n):
  370. current_line_len = i * delta_length
  371. while (
  372. current_edge_ind + 1 < len(length_cumsum)
  373. and current_line_len >= length_cumsum[current_edge_ind + 1]
  374. ):
  375. current_edge_ind += 1
  376. current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
  377. if current_edge_ind >= len(length_list):
  378. break
  379. end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
  380. current_point = (
  381. line[current_edge_ind]
  382. + (line[current_edge_ind + 1] - line[current_edge_ind])
  383. * end_shift_ratio
  384. )
  385. resampled_line.append(current_point)
  386. resampled_line.append(line[-1])
  387. resampled_line = np.array(resampled_line)
  388. return resampled_line
  389. def sample_points_on_bbox(self, line, n=50):
  390. """Resample n points on a line.
  391. Args:
  392. line (ndarray): The points composing a line.
  393. n (int): The resampled points number.
  394. Returns:
  395. resampled_line (ndarray): The points composing the resampled line.
  396. """
  397. assert line.ndim == 2
  398. assert line.shape[0] >= 2
  399. assert line.shape[1] == 2
  400. assert isinstance(n, int)
  401. assert n > 0
  402. length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
  403. total_length = sum(length_list)
  404. mean_length = total_length / (len(length_list) + 1e-8)
  405. group = [[0]]
  406. for i in range(len(length_list)):
  407. point_id = i + 1
  408. if length_list[i] < 0.9 * mean_length:
  409. for g in group:
  410. if i in g:
  411. g.append(point_id)
  412. break
  413. else:
  414. g = [point_id]
  415. group.append(g)
  416. top_tail_len = norm(line[0] - line[-1])
  417. if top_tail_len < 0.9 * mean_length:
  418. group[0].extend(g)
  419. group.remove(g)
  420. mean_positions = []
  421. for indices in group:
  422. x_sum = 0
  423. y_sum = 0
  424. for index in indices:
  425. x, y = line[index]
  426. x_sum += x
  427. y_sum += y
  428. num_points = len(indices)
  429. mean_x = x_sum / num_points
  430. mean_y = y_sum / num_points
  431. mean_positions.append((mean_x, mean_y))
  432. resampled_line = np.array(mean_positions)
  433. return resampled_line
  434. def get_poly_rect_crop(self, img, points):
  435. """
  436. 修改该函数,实现使用polygon,对不规则、弯曲文本的矫正以及crop
  437. args: img: 图片 ndarrary格式
  438. points: polygon格式的多点坐标 N*2 shape, ndarray格式
  439. return: 矫正后的图片 ndarray格式
  440. """
  441. points = np.array(points).astype(np.int32).reshape(-1, 2)
  442. temp_crop_img, temp_box = self.get_minarea_rect(img, points)
  443. # 计算最小外接矩形与polygon的IoU
  444. def get_union(pD, pG):
  445. return Polygon(pD).union(Polygon(pG)).area
  446. def get_intersection_over_union(pD, pG):
  447. return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)
  448. def get_intersection(pD, pG):
  449. return Polygon(pD).intersection(Polygon(pG)).area
  450. if not Polygon(points).is_valid:
  451. return temp_crop_img
  452. cal_IoU = get_intersection_over_union(points, temp_box)
  453. if cal_IoU >= 0.7:
  454. points = self.sample_points_on_bbox_bp(points, 31)
  455. return temp_crop_img
  456. points_sample = self.sample_points_on_bbox(points)
  457. points_sample = points_sample.astype(np.int32)
  458. head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
  459. resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
  460. resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
  461. sideline_mean_shift = np.mean(resample_top_line, axis=0) - np.mean(
  462. resample_bot_line, axis=0
  463. )
  464. if sideline_mean_shift[1] > 0:
  465. resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
  466. rectifier = AutoRectifier()
  467. new_points = np.concatenate([resample_top_line, resample_bot_line])
  468. new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
  469. if len(img.shape) == 2:
  470. img = np.stack((img,) * 3, axis=-1)
  471. img_crop, image = rectifier.run(img, new_points_list, mode="homography")
  472. return np.array(img_crop[0], dtype=np.uint8)