xycut.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. from typing import List
  2. import cv2
  3. import numpy as np
  4. def projection_by_bboxes(boxes: np.array, axis: int) -> np.ndarray:
  5. """
  6. 通过一组 bbox 获得投影直方图,最后以 per-pixel 形式输出
  7. Args:
  8. boxes: [N, 4]
  9. axis: 0-x坐标向水平方向投影, 1-y坐标向垂直方向投影
  10. Returns:
  11. 1D 投影直方图,长度为投影方向坐标的最大值(我们不需要图片的实际边长,因为只是要找文本框的间隔)
  12. """
  13. assert axis in [0, 1]
  14. length = np.max(boxes[:, axis::2])
  15. res = np.zeros(length, dtype=int)
  16. # TODO: how to remove for loop?
  17. for start, end in boxes[:, axis::2]:
  18. res[start:end] += 1
  19. return res
  20. # from: https://dothinking.github.io/2021-06-19-%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%E7%AE%97%E6%B3%95/#:~:text=%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%EF%BC%88Recursive%20XY,%EF%BC%8C%E5%8F%AF%E4%BB%A5%E5%88%92%E5%88%86%E6%AE%B5%E8%90%BD%E3%80%81%E8%A1%8C%E3%80%82
  21. def split_projection_profile(arr_values: np.array, min_value: float, min_gap: float):
  22. """Split projection profile:
  23. ```
  24. ┌──┐
  25. arr_values │ │ ┌─┐───
  26. ┌──┐ │ │ │ │ |
  27. │ │ │ │ ┌───┐ │ │min_value
  28. │ │<- min_gap ->│ │ │ │ │ │ |
  29. ────┴──┴─────────────┴──┴─┴───┴─┴─┴─┴───
  30. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
  31. ```
  32. Args:
  33. arr_values (np.array): 1-d array representing the projection profile.
  34. min_value (float): Ignore the profile if `arr_value` is less than `min_value`.
  35. min_gap (float): Ignore the gap if less than this value.
  36. Returns:
  37. tuple: Start indexes and end indexes of split groups.
  38. """
  39. # all indexes with projection height exceeding the threshold
  40. arr_index = np.where(arr_values > min_value)[0]
  41. if not len(arr_index):
  42. return
  43. # find zero intervals between adjacent projections
  44. # | | ||
  45. # ||||<- zero-interval -> |||||
  46. arr_diff = arr_index[1:] - arr_index[0:-1]
  47. arr_diff_index = np.where(arr_diff > min_gap)[0]
  48. arr_zero_intvl_start = arr_index[arr_diff_index]
  49. arr_zero_intvl_end = arr_index[arr_diff_index + 1]
  50. # convert to index of projection range:
  51. # the start index of zero interval is the end index of projection
  52. arr_start = np.insert(arr_zero_intvl_end, 0, arr_index[0])
  53. arr_end = np.append(arr_zero_intvl_start, arr_index[-1])
  54. arr_end += 1 # end index will be excluded as index slice
  55. return arr_start, arr_end
  56. def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int]):
  57. """
  58. Args:
  59. boxes: (N, 4)
  60. indices: 递归过程中始终表示 box 在原始数据中的索引
  61. res: 保存输出结果
  62. """
  63. # 向 y 轴投影
  64. assert len(boxes) == len(indices)
  65. _indices = boxes[:, 1].argsort()
  66. y_sorted_boxes = boxes[_indices]
  67. y_sorted_indices = indices[_indices]
  68. # debug_vis(y_sorted_boxes, y_sorted_indices)
  69. y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
  70. pos_y = split_projection_profile(y_projection, 0, 1)
  71. if not pos_y:
  72. return
  73. arr_y0, arr_y1 = pos_y
  74. for r0, r1 in zip(arr_y0, arr_y1):
  75. # [r0, r1] 表示按照水平切分,有 bbox 的区域,对这些区域会再进行垂直切分
  76. _indices = (r0 <= y_sorted_boxes[:, 1]) & (y_sorted_boxes[:, 1] < r1)
  77. y_sorted_boxes_chunk = y_sorted_boxes[_indices]
  78. y_sorted_indices_chunk = y_sorted_indices[_indices]
  79. _indices = y_sorted_boxes_chunk[:, 0].argsort()
  80. x_sorted_boxes_chunk = y_sorted_boxes_chunk[_indices]
  81. x_sorted_indices_chunk = y_sorted_indices_chunk[_indices]
  82. # 往 x 方向投影
  83. x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
  84. pos_x = split_projection_profile(x_projection, 0, 1)
  85. if not pos_x:
  86. continue
  87. arr_x0, arr_x1 = pos_x
  88. if len(arr_x0) == 1:
  89. # x 方向无法切分
  90. res.extend(x_sorted_indices_chunk)
  91. continue
  92. # x 方向上能分开,继续递归调用
  93. for c0, c1 in zip(arr_x0, arr_x1):
  94. _indices = (c0 <= x_sorted_boxes_chunk[:, 0]) & (
  95. x_sorted_boxes_chunk[:, 0] < c1
  96. )
  97. recursive_xy_cut(
  98. x_sorted_boxes_chunk[_indices], x_sorted_indices_chunk[_indices], res
  99. )
  100. def points_to_bbox(points):
  101. assert len(points) == 8
  102. # [x1,y1,x2,y2,x3,y3,x4,y4]
  103. left = min(points[::2])
  104. right = max(points[::2])
  105. top = min(points[1::2])
  106. bottom = max(points[1::2])
  107. left = max(left, 0)
  108. top = max(top, 0)
  109. right = max(right, 0)
  110. bottom = max(bottom, 0)
  111. return [left, top, right, bottom]
  112. def bbox2points(bbox):
  113. left, top, right, bottom = bbox
  114. return [left, top, right, top, right, bottom, left, bottom]
  115. def vis_polygon(img, points, thickness=2, color=None):
  116. br2bl_color = color
  117. tl2tr_color = color
  118. tr2br_color = color
  119. bl2tl_color = color
  120. cv2.line(
  121. img,
  122. (points[0][0], points[0][1]),
  123. (points[1][0], points[1][1]),
  124. color=tl2tr_color,
  125. thickness=thickness,
  126. )
  127. cv2.line(
  128. img,
  129. (points[1][0], points[1][1]),
  130. (points[2][0], points[2][1]),
  131. color=tr2br_color,
  132. thickness=thickness,
  133. )
  134. cv2.line(
  135. img,
  136. (points[2][0], points[2][1]),
  137. (points[3][0], points[3][1]),
  138. color=br2bl_color,
  139. thickness=thickness,
  140. )
  141. cv2.line(
  142. img,
  143. (points[3][0], points[3][1]),
  144. (points[0][0], points[0][1]),
  145. color=bl2tl_color,
  146. thickness=thickness,
  147. )
  148. return img
  149. def vis_points(
  150. img: np.ndarray, points, texts: List[str] = None, color=(0, 200, 0)
  151. ) -> np.ndarray:
  152. """
  153. Args:
  154. img:
  155. points: [N, 8] 8: x1,y1,x2,y2,x3,y3,x3,y4
  156. texts:
  157. color:
  158. Returns:
  159. """
  160. points = np.array(points)
  161. if texts is not None:
  162. assert len(texts) == points.shape[0]
  163. for i, _points in enumerate(points):
  164. vis_polygon(img, _points.reshape(-1, 2), thickness=2, color=color)
  165. bbox = points_to_bbox(_points)
  166. left, top, right, bottom = bbox
  167. cx = (left + right) // 2
  168. cy = (top + bottom) // 2
  169. txt = texts[i]
  170. font = cv2.FONT_HERSHEY_SIMPLEX
  171. cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
  172. img = cv2.rectangle(
  173. img,
  174. (cx - 5 * len(txt), cy - cat_size[1] - 5),
  175. (cx - 5 * len(txt) + cat_size[0], cy - 5),
  176. color,
  177. -1,
  178. )
  179. img = cv2.putText(
  180. img,
  181. txt,
  182. (cx - 5 * len(txt), cy - 5),
  183. font,
  184. 0.5,
  185. (255, 255, 255),
  186. thickness=1,
  187. lineType=cv2.LINE_AA,
  188. )
  189. return img
  190. def vis_polygons_with_index(image, points):
  191. texts = [str(i) for i in range(len(points))]
  192. res_img = vis_points(image.copy(), points, texts)
  193. return res_img