crop_image_regions.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import copy
  15. from typing import List, Tuple
  16. import numpy as np
  17. from numpy.linalg import norm
  18. from .....utils.deps import class_requires_deps, is_dep_available
  19. from .base_operator import BaseOperator
  20. from .seal_det_warp import AutoRectifier
  21. if is_dep_available("opencv-contrib-python"):
  22. import cv2
  23. if is_dep_available("shapely"):
  24. from shapely.geometry import Polygon
  25. class CropByBoxes(BaseOperator):
  26. """Crop Image by Boxes"""
  27. entities = "CropByBoxes"
  28. def __init__(self) -> None:
  29. """Initializes the class."""
  30. super().__init__()
  31. def __call__(self, img: np.ndarray, boxes: List[dict]) -> List[dict]:
  32. """
  33. Process the input image and bounding boxes to produce a list of cropped images
  34. with their corresponding bounding box coordinates and labels.
  35. Args:
  36. img (np.ndarray): The input image as a NumPy array.
  37. boxes (list[dict]): A list of dictionaries, each containing bounding box
  38. information including 'cls_id' (class ID), 'coordinate' (bounding box
  39. coordinates as a list or tuple, left, top, right, bottom),
  40. and optionally 'label' (label text).
  41. Returns:
  42. list[dict]: A list of dictionaries, each containing a cropped image ('img'),
  43. the original bounding box coordinates ('box'), and the label ('label').
  44. """
  45. output_list = []
  46. for bbox_info in boxes:
  47. label_id = bbox_info["cls_id"]
  48. box = bbox_info["coordinate"]
  49. label = bbox_info.get("label", label_id)
  50. xmin, ymin, xmax, ymax = [int(i) for i in box]
  51. img_crop = img[ymin:ymax, xmin:xmax].copy()
  52. output_list.append({"img": img_crop, "box": box, "label": label})
  53. return output_list
# NOTE(review): the decorator presumably verifies that the listed packages are
# installed before the class is used -- confirm against utils.deps.
@class_requires_deps("opencv-contrib-python", "shapely")
class CropByPolys(BaseOperator):
    """Crop image regions described by detected polygons."""

    entities = "CropByPolys"

    def __init__(self, det_box_type: str = "quad") -> None:
        """
        Store the detection box type that selects the cropping strategy.

        Args:
            det_box_type (str, optional): The type of detection box, "quad"
                or "poly". Defaults to "quad". The value is not validated
                here; an unsupported value only fails when the operator is
                called.
        """
        super().__init__()
        self.det_box_type = det_box_type
  66. def __call__(self, img: np.ndarray, dt_polys: List[list]) -> List[dict]:
  67. """
  68. Call method to crop images based on detection boxes.
  69. Args:
  70. img (nd.ndarray): The input image.
  71. dt_polys (list[list]): List of detection polygons.
  72. Returns:
  73. list[dict]: A list of dictionaries containing cropped images and their sizes.
  74. Raises:
  75. NotImplementedError: If det_box_type is not 'quad' or 'poly'.
  76. """
  77. if self.det_box_type == "quad":
  78. dt_boxes = np.array(dt_polys)
  79. output_list = []
  80. for bno in range(len(dt_boxes)):
  81. tmp_box = copy.deepcopy(dt_boxes[bno])
  82. img_crop = self.get_minarea_rect_crop(img, tmp_box)
  83. output_list.append(img_crop)
  84. elif self.det_box_type == "poly":
  85. output_list = []
  86. dt_boxes = dt_polys
  87. for bno in range(len(dt_boxes)):
  88. tmp_box = copy.deepcopy(dt_boxes[bno])
  89. img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
  90. output_list.append(img_crop)
  91. else:
  92. raise NotImplementedError
  93. return output_list
  94. def get_minarea_rect_crop(self, img: np.ndarray, points: np.ndarray) -> np.ndarray:
  95. """
  96. Get the minimum area rectangle crop from the given image and points.
  97. Args:
  98. img (np.ndarray): The input image.
  99. points (np.ndarray): A list of points defining the shape to be cropped.
  100. Returns:
  101. np.ndarray: The cropped image with the minimum area rectangle.
  102. """
  103. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  104. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  105. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  106. if points[1][1] > points[0][1]:
  107. index_a = 0
  108. index_d = 1
  109. else:
  110. index_a = 1
  111. index_d = 0
  112. if points[3][1] > points[2][1]:
  113. index_b = 2
  114. index_c = 3
  115. else:
  116. index_b = 3
  117. index_c = 2
  118. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  119. crop_img = self.get_rotate_crop_image(img, np.array(box))
  120. return crop_img
  121. def get_rotate_crop_image(self, img: np.ndarray, points: list) -> np.ndarray:
  122. """
  123. Crop and rotate the input image based on the given four points to form a perspective-transformed image.
  124. Args:
  125. img (np.ndarray): The input image array.
  126. points (list): A list of four 2D points defining the crop region in the image.
  127. Returns:
  128. np.ndarray: The transformed image array.
  129. """
  130. assert len(points) == 4, "shape of points must be 4*2"
  131. img_crop_width = int(
  132. max(
  133. np.linalg.norm(points[0] - points[1]),
  134. np.linalg.norm(points[2] - points[3]),
  135. )
  136. )
  137. img_crop_height = int(
  138. max(
  139. np.linalg.norm(points[0] - points[3]),
  140. np.linalg.norm(points[1] - points[2]),
  141. )
  142. )
  143. pts_std = np.float32(
  144. [
  145. [0, 0],
  146. [img_crop_width, 0],
  147. [img_crop_width, img_crop_height],
  148. [0, img_crop_height],
  149. ]
  150. )
  151. M = cv2.getPerspectiveTransform(points, pts_std)
  152. dst_img = cv2.warpPerspective(
  153. img,
  154. M,
  155. (img_crop_width, img_crop_height),
  156. borderMode=cv2.BORDER_REPLICATE,
  157. flags=cv2.INTER_CUBIC,
  158. )
  159. dst_img_height, dst_img_width = dst_img.shape[0:2]
  160. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  161. dst_img = np.rot90(dst_img)
  162. return dst_img
  163. def reorder_poly_edge(
  164. self, points: np.ndarray
  165. ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  166. """Get the respective points composing head edge, tail edge, top
  167. sideline and bottom sideline.
  168. Args:
  169. points (ndarray): The points composing a text polygon.
  170. Returns:
  171. head_edge (ndarray): The two points composing the head edge of text
  172. polygon.
  173. tail_edge (ndarray): The two points composing the tail edge of text
  174. polygon.
  175. top_sideline (ndarray): The points composing top curved sideline of
  176. text polygon.
  177. bot_sideline (ndarray): The points composing bottom curved sideline
  178. of text polygon.
  179. """
  180. assert points.ndim == 2
  181. assert points.shape[0] >= 4
  182. assert points.shape[1] == 2
  183. orientation_thr = 2.0 # 一个经验超参数
  184. head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
  185. head_edge, tail_edge = points[head_inds], points[tail_inds]
  186. pad_points = np.vstack([points, points])
  187. if tail_inds[1] < 1:
  188. tail_inds[1] = len(points)
  189. sideline1 = pad_points[head_inds[1] : tail_inds[1]]
  190. sideline2 = pad_points[tail_inds[1] : (head_inds[1] + len(points))]
  191. return head_edge, tail_edge, sideline1, sideline2
  192. def vector_slope(self, vec: list) -> float:
  193. """
  194. Calculate the slope of a vector in 2D space.
  195. Args:
  196. vec (list): A list of two elements representing the coordinates of the vector.
  197. Returns:
  198. float: The slope of the vector.
  199. Raises:
  200. AssertionError: If the length of the vector is not equal to 2.
  201. """
  202. assert len(vec) == 2
  203. return abs(vec[1] / (vec[0] + 1e-8))
  204. def find_head_tail(
  205. self, points: np.ndarray, orientation_thr: float
  206. ) -> Tuple[list, list]:
  207. """Find the head edge and tail edge of a text polygon.
  208. Args:
  209. points (ndarray): The points composing a text polygon.
  210. orientation_thr (float): The threshold for distinguishing between
  211. head edge and tail edge among the horizontal and vertical edges
  212. of a quadrangle.
  213. Returns:
  214. head_inds (list): The indexes of two points composing head edge.
  215. tail_inds (list): The indexes of two points composing tail edge.
  216. """
  217. assert points.ndim == 2
  218. assert points.shape[0] >= 4
  219. assert points.shape[1] == 2
  220. assert isinstance(orientation_thr, float)
  221. if len(points) > 4:
  222. pad_points = np.vstack([points, points[0]])
  223. edge_vec = pad_points[1:] - pad_points[:-1]
  224. theta_sum = []
  225. adjacent_vec_theta = []
  226. for i, edge_vec1 in enumerate(edge_vec):
  227. adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
  228. adjacent_edge_vec = edge_vec[adjacent_ind]
  229. temp_theta_sum = np.sum(self.vector_angle(edge_vec1, adjacent_edge_vec))
  230. temp_adjacent_theta = self.vector_angle(
  231. adjacent_edge_vec[0], adjacent_edge_vec[1]
  232. )
  233. theta_sum.append(temp_theta_sum)
  234. adjacent_vec_theta.append(temp_adjacent_theta)
  235. theta_sum_score = np.array(theta_sum) / np.pi
  236. adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
  237. poly_center = np.mean(points, axis=0)
  238. edge_dist = np.maximum(
  239. norm(pad_points[1:] - poly_center, axis=-1),
  240. norm(pad_points[:-1] - poly_center, axis=-1),
  241. )
  242. dist_score = edge_dist / np.max(edge_dist)
  243. position_score = np.zeros(len(edge_vec))
  244. score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
  245. score += 0.35 * dist_score
  246. if len(points) % 2 == 0:
  247. position_score[(len(score) // 2 - 1)] += 1
  248. position_score[-1] += 1
  249. score += 0.1 * position_score
  250. pad_score = np.concatenate([score, score])
  251. score_matrix = np.zeros((len(score), len(score) - 3))
  252. x = np.arange(len(score) - 3) / float(len(score) - 4)
  253. gaussian = (
  254. 1.0
  255. / (np.sqrt(2.0 * np.pi) * 0.5)
  256. * np.exp(-np.power((x - 0.5) / 0.5, 2.0) / 2)
  257. )
  258. gaussian = gaussian / np.max(gaussian)
  259. for i in range(len(score)):
  260. score_matrix[i, :] = (
  261. score[i]
  262. + pad_score[(i + 2) : (i + len(score) - 1)] * gaussian * 0.3
  263. )
  264. head_start, tail_increment = np.unravel_index(
  265. score_matrix.argmax(), score_matrix.shape
  266. )
  267. tail_start = (head_start + tail_increment + 2) % len(points)
  268. head_end = (head_start + 1) % len(points)
  269. tail_end = (tail_start + 1) % len(points)
  270. if head_end > tail_end:
  271. head_start, tail_start = tail_start, head_start
  272. head_end, tail_end = tail_end, head_end
  273. head_inds = [head_start, head_end]
  274. tail_inds = [tail_start, tail_end]
  275. else:
  276. if self.vector_slope(points[1] - points[0]) + self.vector_slope(
  277. points[3] - points[2]
  278. ) < self.vector_slope(points[2] - points[1]) + self.vector_slope(
  279. points[0] - points[3]
  280. ):
  281. horizontal_edge_inds = [[0, 1], [2, 3]]
  282. vertical_edge_inds = [[3, 0], [1, 2]]
  283. else:
  284. horizontal_edge_inds = [[3, 0], [1, 2]]
  285. vertical_edge_inds = [[0, 1], [2, 3]]
  286. vertical_len_sum = norm(
  287. points[vertical_edge_inds[0][0]] - points[vertical_edge_inds[0][1]]
  288. ) + norm(
  289. points[vertical_edge_inds[1][0]] - points[vertical_edge_inds[1][1]]
  290. )
  291. horizontal_len_sum = norm(
  292. points[horizontal_edge_inds[0][0]] - points[horizontal_edge_inds[0][1]]
  293. ) + norm(
  294. points[horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1][1]]
  295. )
  296. if vertical_len_sum > horizontal_len_sum * orientation_thr:
  297. head_inds = horizontal_edge_inds[0]
  298. tail_inds = horizontal_edge_inds[1]
  299. else:
  300. head_inds = vertical_edge_inds[0]
  301. tail_inds = vertical_edge_inds[1]
  302. return head_inds, tail_inds
  303. def vector_angle(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
  304. """
  305. Calculate the angle between two vectors.
  306. Args:
  307. vec1 (ndarray): The first vector.
  308. vec2 (ndarray): The second vector.
  309. Returns:
  310. float: The angle between the two vectors in radians.
  311. """
  312. if vec1.ndim > 1:
  313. unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
  314. else:
  315. unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
  316. if vec2.ndim > 1:
  317. unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
  318. else:
  319. unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
  320. return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
  321. def get_minarea_rect(
  322. self, img: np.ndarray, points: np.ndarray
  323. ) -> Tuple[np.ndarray, list]:
  324. """
  325. Get the minimum area rectangle for the given points and crop the image accordingly.
  326. Args:
  327. img (np.ndarray): The input image.
  328. points (np.ndarray): The points to compute the minimum area rectangle for.
  329. Returns:
  330. tuple[np.ndarray, list]: The cropped image,
  331. and the list of points in the order of the bounding box.
  332. """
  333. bounding_box = cv2.minAreaRect(points)
  334. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  335. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  336. if points[1][1] > points[0][1]:
  337. index_a = 0
  338. index_d = 1
  339. else:
  340. index_a = 1
  341. index_d = 0
  342. if points[3][1] > points[2][1]:
  343. index_b = 2
  344. index_c = 3
  345. else:
  346. index_b = 3
  347. index_c = 2
  348. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  349. crop_img = self.get_rotate_crop_image(img, np.array(box))
  350. return crop_img, box
  351. def sample_points_on_bbox_bp(self, line, n=50):
  352. """Resample n points on a line.
  353. Args:
  354. line (ndarray): The points composing a line.
  355. n (int): The resampled points number.
  356. Returns:
  357. resampled_line (ndarray): The points composing the resampled line.
  358. """
  359. from numpy.linalg import norm
  360. # 断言检查输入参数的有效性
  361. assert line.ndim == 2
  362. assert line.shape[0] >= 2
  363. assert line.shape[1] == 2
  364. assert isinstance(n, int)
  365. assert n > 0
  366. length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
  367. total_length = sum(length_list)
  368. length_cumsum = np.cumsum([0.0] + length_list)
  369. delta_length = total_length / (float(n) + 1e-8)
  370. current_edge_ind = 0
  371. resampled_line = [line[0]]
  372. for i in range(1, n):
  373. current_line_len = i * delta_length
  374. while (
  375. current_edge_ind + 1 < len(length_cumsum)
  376. and current_line_len >= length_cumsum[current_edge_ind + 1]
  377. ):
  378. current_edge_ind += 1
  379. current_edge_end_shift = current_line_len - length_cumsum[current_edge_ind]
  380. if current_edge_ind >= len(length_list):
  381. break
  382. end_shift_ratio = current_edge_end_shift / length_list[current_edge_ind]
  383. current_point = (
  384. line[current_edge_ind]
  385. + (line[current_edge_ind + 1] - line[current_edge_ind])
  386. * end_shift_ratio
  387. )
  388. resampled_line.append(current_point)
  389. resampled_line.append(line[-1])
  390. resampled_line = np.array(resampled_line)
  391. return resampled_line
  392. def sample_points_on_bbox(self, line, n=50):
  393. """Resample n points on a line.
  394. Args:
  395. line (ndarray): The points composing a line.
  396. n (int): The resampled points number.
  397. Returns:
  398. resampled_line (ndarray): The points composing the resampled line.
  399. """
  400. assert line.ndim == 2
  401. assert line.shape[0] >= 2
  402. assert line.shape[1] == 2
  403. assert isinstance(n, int)
  404. assert n > 0
  405. length_list = [norm(line[i + 1] - line[i]) for i in range(len(line) - 1)]
  406. total_length = sum(length_list)
  407. mean_length = total_length / (len(length_list) + 1e-8)
  408. group = [[0]]
  409. for i in range(len(length_list)):
  410. point_id = i + 1
  411. if length_list[i] < 0.9 * mean_length:
  412. for g in group:
  413. if i in g:
  414. g.append(point_id)
  415. break
  416. else:
  417. g = [point_id]
  418. group.append(g)
  419. top_tail_len = norm(line[0] - line[-1])
  420. if top_tail_len < 0.9 * mean_length:
  421. group[0].extend(g)
  422. group.remove(g)
  423. mean_positions = []
  424. for indices in group:
  425. x_sum = 0
  426. y_sum = 0
  427. for index in indices:
  428. x, y = line[index]
  429. x_sum += x
  430. y_sum += y
  431. num_points = len(indices)
  432. mean_x = x_sum / num_points
  433. mean_y = y_sum / num_points
  434. mean_positions.append((mean_x, mean_y))
  435. resampled_line = np.array(mean_positions)
  436. return resampled_line
  437. def get_poly_rect_crop(self, img, points):
  438. """
  439. 修改该函数,实现使用polygon,对不规则、弯曲文本的矫正以及crop
  440. args: img: 图片 ndarrary格式
  441. points: polygon格式的多点坐标 N*2 shape, ndarray格式
  442. return: 矫正后的图片 ndarray格式
  443. """
  444. points = np.array(points).astype(np.int32).reshape(-1, 2)
  445. temp_crop_img, temp_box = self.get_minarea_rect(img, points)
  446. # 计算最小外接矩形与polygon的IoU
  447. def get_union(pD, pG):
  448. return Polygon(pD).union(Polygon(pG)).area
  449. def get_intersection_over_union(pD, pG):
  450. return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)
  451. def get_intersection(pD, pG):
  452. return Polygon(pD).intersection(Polygon(pG)).area
  453. if not Polygon(points).is_valid:
  454. return temp_crop_img
  455. cal_IoU = get_intersection_over_union(points, temp_box)
  456. if cal_IoU >= 0.7:
  457. points = self.sample_points_on_bbox_bp(points, 31)
  458. return temp_crop_img
  459. points_sample = self.sample_points_on_bbox(points)
  460. points_sample = points_sample.astype(np.int32)
  461. head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
  462. resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
  463. resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
  464. sideline_mean_shift = np.mean(resample_top_line, axis=0) - np.mean(
  465. resample_bot_line, axis=0
  466. )
  467. if sideline_mean_shift[1] > 0:
  468. resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
  469. rectifier = AutoRectifier()
  470. new_points = np.concatenate([resample_top_line, resample_bot_line])
  471. new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
  472. if len(img.shape) == 2:
  473. img = np.stack((img,) * 3, axis=-1)
  474. img_crop, image = rectifier.run(img, new_points_list, mode="homography")
  475. return np.array(img_crop[0], dtype=np.uint8)