utils_table_line_rec.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. import math
  2. import cv2
  3. import numpy as np
  4. from scipy.spatial import distance as dist
  5. from skimage import measure
  6. def transform_preds(coords, center, scale, output_size, rot=0):
  7. target_coords = np.zeros(coords.shape)
  8. trans = get_affine_transform(center, scale, rot, output_size, inv=1)
  9. for p in range(coords.shape[0]):
  10. target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
  11. return target_coords
  12. def get_affine_transform(
  13. center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0
  14. ):
  15. if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
  16. scale = np.array([scale, scale], dtype=np.float32)
  17. scale_tmp = scale
  18. src_w = scale_tmp[0]
  19. dst_w = output_size[0]
  20. dst_h = output_size[1]
  21. rot_rad = np.pi * rot / 180
  22. src_dir = get_dir([0, src_w * -0.5], rot_rad)
  23. dst_dir = np.array([0, dst_w * -0.5], np.float32)
  24. src = np.zeros((3, 2), dtype=np.float32)
  25. dst = np.zeros((3, 2), dtype=np.float32)
  26. src[0, :] = center + scale_tmp * shift
  27. src[1, :] = center + src_dir + scale_tmp * shift
  28. dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
  29. dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
  30. src[2:, :] = get_3rd_point(src[0, :], src[1, :])
  31. dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
  32. if inv:
  33. trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
  34. else:
  35. trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
  36. return trans
  37. def affine_transform(pt, t):
  38. new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32).T
  39. new_pt = np.dot(t, new_pt)
  40. return new_pt[:2]
  41. def get_dir(src_point, rot_rad):
  42. sn, cs = np.sin(rot_rad), np.cos(rot_rad)
  43. src_result = [0, 0]
  44. src_result[0] = src_point[0] * cs - src_point[1] * sn
  45. src_result[1] = src_point[0] * sn + src_point[1] * cs
  46. return src_result
  47. def get_3rd_point(a, b):
  48. direct = a - b
  49. return b + np.array([-direct[1], direct[0]], dtype=np.float32)
  50. def get_table_line(binimg, axis=0, lineW=10):
  51. ##获取表格线
  52. ##axis=0 横线
  53. ##axis=1 竖线
  54. labels = measure.label(binimg > 0, connectivity=2) # 8连通区域标记
  55. regions = measure.regionprops(labels)
  56. if axis == 1:
  57. lineboxes = [
  58. min_area_rect(line.coords)
  59. for line in regions
  60. if line.bbox[2] - line.bbox[0] > lineW
  61. ]
  62. else:
  63. lineboxes = [
  64. min_area_rect(line.coords)
  65. for line in regions
  66. if line.bbox[3] - line.bbox[1] > lineW
  67. ]
  68. return lineboxes
  69. def min_area_rect(coords):
  70. """
  71. 多边形外接矩形
  72. """
  73. rect = cv2.minAreaRect(coords[:, ::-1])
  74. box = cv2.boxPoints(rect)
  75. box = box.reshape((8,)).tolist()
  76. box = image_location_sort_box(box)
  77. x1, y1, x2, y2, x3, y3, x4, y4 = box
  78. degree, w, h, cx, cy = calculate_center_rotate_angle(box)
  79. if w < h:
  80. xmin = (x1 + x2) / 2
  81. xmax = (x3 + x4) / 2
  82. ymin = (y1 + y2) / 2
  83. ymax = (y3 + y4) / 2
  84. else:
  85. xmin = (x1 + x4) / 2
  86. xmax = (x2 + x3) / 2
  87. ymin = (y1 + y4) / 2
  88. ymax = (y2 + y3) / 2
  89. # degree,w,h,cx,cy = solve(box)
  90. # x1,y1,x2,y2,x3,y3,x4,y4 = box
  91. # return {'degree':degree,'w':w,'h':h,'cx':cx,'cy':cy}
  92. return [xmin, ymin, xmax, ymax]
  93. def image_location_sort_box(box):
  94. x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
  95. pts = (x1, y1), (x2, y2), (x3, y3), (x4, y4)
  96. pts = np.array(pts, dtype="float32")
  97. (x1, y1), (x2, y2), (x3, y3), (x4, y4) = _order_points(pts)
  98. return [x1, y1, x2, y2, x3, y3, x4, y4]
  99. def calculate_center_rotate_angle(box):
  100. """
  101. 绕 cx,cy点 w,h 旋转 angle 的坐标,能一定程度缓解图片的内部倾斜,但是还是依赖模型稳妥
  102. x = cx-w/2
  103. y = cy-h/2
  104. x1-cx = -w/2*cos(angle) +h/2*sin(angle)
  105. y1 -cy= -w/2*sin(angle) -h/2*cos(angle)
  106. h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle)
  107. w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle)
  108. (hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy)
  109. """
  110. x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
  111. cx = (x1 + x3 + x2 + x4) / 4.0
  112. cy = (y1 + y3 + y4 + y2) / 4.0
  113. w = (
  114. np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
  115. + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)
  116. ) / 2
  117. h = (
  118. np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
  119. + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)
  120. ) / 2
  121. # x = cx-w/2
  122. # y = cy-h/2
  123. sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
  124. angle = np.arcsin(sinA)
  125. return angle, w, h, cx, cy
  126. def _order_points(pts):
  127. # 根据x坐标对点进行排序
  128. """
  129. ---------------------
  130. 本项目中是为了排序后得到[(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)]
  131. 作者:Tong_T
  132. 来源:CSDN
  133. 原文:https://blog.csdn.net/Tong_T/article/details/81907132
  134. 版权声明:本文为博主原创文章,转载请附上博文链接!
  135. """
  136. x_sorted = pts[np.argsort(pts[:, 0]), :]
  137. left_most = x_sorted[:2, :]
  138. right_most = x_sorted[2:, :]
  139. left_most = left_most[np.argsort(left_most[:, 1]), :]
  140. (tl, bl) = left_most
  141. distance = dist.cdist(tl[np.newaxis], right_most, "euclidean")[0]
  142. (br, tr) = right_most[np.argsort(distance)[::-1], :]
  143. return np.array([tl, tr, br, bl], dtype="float32")
  144. def sqrt(p1, p2):
  145. return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
  146. def adjust_lines(lines, alph=50, angle=50):
  147. lines_n = len(lines)
  148. new_lines = []
  149. for i in range(lines_n):
  150. x1, y1, x2, y2 = lines[i]
  151. cx1, cy1 = (x1 + x2) / 2, (y1 + y2) / 2
  152. for j in range(lines_n):
  153. if i != j:
  154. x3, y3, x4, y4 = lines[j]
  155. cx2, cy2 = (x3 + x4) / 2, (y3 + y4) / 2
  156. if (x3 < cx1 < x4 or y3 < cy1 < y4) or (
  157. x1 < cx2 < x2 or y1 < cy2 < y2
  158. ): # 判断两个横线在y方向的投影重不重合
  159. continue
  160. else:
  161. r = sqrt((x1, y1), (x3, y3))
  162. k = abs((y3 - y1) / (x3 - x1 + 1e-10))
  163. a = math.atan(k) * 180 / math.pi
  164. if r < alph and a < angle:
  165. new_lines.append((x1, y1, x3, y3))
  166. r = sqrt((x1, y1), (x4, y4))
  167. k = abs((y4 - y1) / (x4 - x1 + 1e-10))
  168. a = math.atan(k) * 180 / math.pi
  169. if r < alph and a < angle:
  170. new_lines.append((x1, y1, x4, y4))
  171. r = sqrt((x2, y2), (x3, y3))
  172. k = abs((y3 - y2) / (x3 - x2 + 1e-10))
  173. a = math.atan(k) * 180 / math.pi
  174. if r < alph and a < angle:
  175. new_lines.append((x2, y2, x3, y3))
  176. r = sqrt((x2, y2), (x4, y4))
  177. k = abs((y4 - y2) / (x4 - x2 + 1e-10))
  178. a = math.atan(k) * 180 / math.pi
  179. if r < alph and a < angle:
  180. new_lines.append((x2, y2, x4, y4))
  181. return new_lines
  182. def final_adjust_lines(rowboxes, colboxes):
  183. nrow = len(rowboxes)
  184. ncol = len(colboxes)
  185. for i in range(nrow):
  186. for j in range(ncol):
  187. rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], alpha=20, angle=30)
  188. colboxes[j] = line_to_line(colboxes[j], rowboxes[i], alpha=20, angle=30)
  189. return rowboxes, colboxes
  190. def draw_lines(im, bboxes, color=(0, 0, 0), lineW=3):
  191. """
  192. boxes: bounding boxes
  193. """
  194. tmp = np.copy(im)
  195. c = color
  196. h, w = im.shape[:2]
  197. for box in bboxes:
  198. x1, y1, x2, y2 = box[:4]
  199. cv2.line(
  200. tmp, (int(x1), int(y1)), (int(x2), int(y2)), c, lineW, lineType=cv2.LINE_AA
  201. )
  202. return tmp
  203. def line_to_line(points1, points2, alpha=10, angle=30):
  204. """
  205. 线段之间的距离
  206. """
  207. x1, y1, x2, y2 = points1
  208. ox1, oy1, ox2, oy2 = points2
  209. xy = np.array([(x1, y1), (x2, y2)], dtype="float32")
  210. A1, B1, C1 = fit_line(xy)
  211. oxy = np.array([(ox1, oy1), (ox2, oy2)], dtype="float32")
  212. A2, B2, C2 = fit_line(oxy)
  213. flag1 = point_line_cor(np.array([x1, y1], dtype="float32"), A2, B2, C2)
  214. flag2 = point_line_cor(np.array([x2, y2], dtype="float32"), A2, B2, C2)
  215. if (flag1 > 0 and flag2 > 0) or (flag1 < 0 and flag2 < 0): # 横线或者竖线在竖线或者横线的同一侧
  216. if (A1 * B2 - A2 * B1) != 0:
  217. x = (B1 * C2 - B2 * C1) / (A1 * B2 - A2 * B1)
  218. y = (A2 * C1 - A1 * C2) / (A1 * B2 - A2 * B1)
  219. # x, y = round(x, 2), round(y, 2)
  220. p = (x, y) # 横线与竖线的交点
  221. r0 = sqrt(p, (x1, y1))
  222. r1 = sqrt(p, (x2, y2))
  223. if min(r0, r1) < alpha: # 若交点与线起点或者终点的距离小于alpha,则延长线到交点
  224. if r0 < r1:
  225. k = abs((y2 - p[1]) / (x2 - p[0] + 1e-10))
  226. a = math.atan(k) * 180 / math.pi
  227. if a < angle or abs(90 - a) < angle:
  228. points1 = np.array([p[0], p[1], x2, y2], dtype="float32")
  229. else:
  230. k = abs((y1 - p[1]) / (x1 - p[0] + 1e-10))
  231. a = math.atan(k) * 180 / math.pi
  232. if a < angle or abs(90 - a) < angle:
  233. points1 = np.array([x1, y1, p[0], p[1]], dtype="float32")
  234. return points1
  235. def min_area_rect_box(
  236. regions, flag=True, W=0, H=0, filtersmall=False, adjust_box=False
  237. ):
  238. """
  239. 多边形外接矩形
  240. """
  241. boxes = []
  242. for region in regions:
  243. if region.bbox_area > H * W * 3 / 4: # 过滤大的单元格
  244. continue
  245. rect = cv2.minAreaRect(region.coords[:, ::-1])
  246. box = cv2.boxPoints(rect)
  247. box = box.reshape((8,)).tolist()
  248. box = image_location_sort_box(box)
  249. x1, y1, x2, y2, x3, y3, x4, y4 = box
  250. angle, w, h, cx, cy = calculate_center_rotate_angle(box)
  251. # if adjustBox:
  252. # x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w + 5, h + 5, angle=0, degree=None)
  253. # x1, x4 = max(x1, 0), max(x4, 0)
  254. # y1, y2 = max(y1, 0), max(y2, 0)
  255. # if w > 32 and h > 32 and flag:
  256. # if abs(angle / np.pi * 180) < 20:
  257. # if filtersmall and (w < 10 or h < 10):
  258. # continue
  259. # boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
  260. # else:
  261. if w * h < 0.5 * W * H:
  262. if filtersmall and (
  263. w < 15 or h < 15
  264. ): # or w / h > 30 or h / w > 30): # 过滤小的单元格
  265. continue
  266. boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
  267. return boxes
  268. def point_line_cor(p, A, B, C):
  269. ##判断点与线之间的位置关系
  270. # 一般式直线方程(Ax+By+c)=0
  271. x, y = p
  272. r = A * x + B * y + C
  273. return r
  274. def fit_line(p):
  275. """A = Y2 - Y1
  276. B = X1 - X2
  277. C = X2*Y1 - X1*Y2
  278. AX+BY+C=0
  279. 直线一般方程
  280. """
  281. x1, y1 = p[0]
  282. x2, y2 = p[1]
  283. A = y2 - y1
  284. B = x1 - x2
  285. C = x2 * y1 - x1 * y2
  286. return A, B, C