table_structure_unet.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import copy
  2. import math
  3. from typing import Optional, Dict, Any, Tuple
  4. import cv2
  5. import numpy as np
  6. from skimage import measure
  7. from .utils import OrtInferSession, resize_img
  8. from .utils_table_line_rec import (
  9. get_table_line,
  10. final_adjust_lines,
  11. min_area_rect_box,
  12. draw_lines,
  13. adjust_lines,
  14. )
  15. from.utils_table_recover import (
  16. sorted_ocr_boxes,
  17. box_4_2_poly_to_box_4_1,
  18. )
  19. class TSRUnet:
  20. def __init__(self, config: Dict):
  21. self.K = 1000
  22. self.MK = 4000
  23. self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
  24. self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
  25. self.inp_height = 1024
  26. self.inp_width = 1024
  27. self.session = OrtInferSession(config)
  28. def __call__(
  29. self, img: np.ndarray, **kwargs
  30. ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
  31. img_info = self.preprocess(img)
  32. pred = self.infer(img_info)
  33. polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
  34. if polygons.size == 0:
  35. return None, None
  36. polygons = polygons.reshape(polygons.shape[0], 4, 2)
  37. polygons[:, 3, :], polygons[:, 1, :] = (
  38. polygons[:, 1, :].copy(),
  39. polygons[:, 3, :].copy(),
  40. )
  41. rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
  42. rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
  43. rotated_polygons[:, 1, :].copy(),
  44. rotated_polygons[:, 3, :].copy(),
  45. )
  46. _, idx = sorted_ocr_boxes(
  47. [box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
  48. threhold=0.4,
  49. )
  50. polygons = polygons[idx]
  51. rotated_polygons = rotated_polygons[idx]
  52. return polygons, rotated_polygons
  53. def preprocess(self, img) -> Dict[str, Any]:
  54. scale = (self.inp_height, self.inp_width)
  55. img, _, _ = resize_img(img, scale, True)
  56. img = img.copy().astype(np.float32)
  57. assert img.dtype != np.uint8
  58. mean = np.float64(self.mean.reshape(1, -1))
  59. stdinv = 1 / np.float64(self.std.reshape(1, -1))
  60. cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
  61. cv2.subtract(img, mean, img) # inplace
  62. cv2.multiply(img, stdinv, img) # inplace
  63. img = img.transpose(2, 0, 1)
  64. images = img[None, :]
  65. return {"img": images}
  66. def infer(self, input):
  67. result = self.session(input["img"][None, ...])[0][0]
  68. result = result[0].astype(np.uint8)
  69. return result
  70. def postprocess(self, img, pred, **kwargs):
  71. row = kwargs.get("row", 50) if kwargs else 50
  72. col = kwargs.get("col", 30) if kwargs else 30
  73. h_lines_threshold = kwargs.get("h_lines_threshold", 100) if kwargs else 100
  74. v_lines_threshold = kwargs.get("v_lines_threshold", 15) if kwargs else 15
  75. angle = kwargs.get("angle", 50) if kwargs else 50
  76. enhance_box_line = kwargs.get("enhance_box_line", True) if kwargs else True
  77. morph_close = (
  78. kwargs.get("morph_close", enhance_box_line) if kwargs else enhance_box_line
  79. ) # 是否进行闭合运算以找到更多小的框
  80. more_h_lines = (
  81. kwargs.get("more_h_lines", enhance_box_line) if kwargs else enhance_box_line
  82. ) # 是否调整以找到更多的横线
  83. more_v_lines = (
  84. kwargs.get("more_v_lines", enhance_box_line) if kwargs else enhance_box_line
  85. ) # 是否调整以找到更多的横线
  86. extend_line = (
  87. kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
  88. ) # 是否进行线段延长使得端点连接
  89. # 是否进行旋转修正
  90. rotated_fix = kwargs.get("rotated_fix") if kwargs else True
  91. ori_shape = img.shape
  92. pred = np.uint8(pred)
  93. hpred = copy.deepcopy(pred) # 横线
  94. vpred = copy.deepcopy(pred) # 竖线
  95. whereh = np.where(hpred == 1)
  96. wherev = np.where(vpred == 2)
  97. hpred[wherev] = 0
  98. vpred[whereh] = 0
  99. hpred = cv2.resize(hpred, (ori_shape[1], ori_shape[0]))
  100. vpred = cv2.resize(vpred, (ori_shape[1], ori_shape[0]))
  101. h, w = pred.shape
  102. hors_k = int(math.sqrt(w) * 1.2)
  103. vert_k = int(math.sqrt(h) * 1.2)
  104. hkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1))
  105. vkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k))
  106. vpred = cv2.morphologyEx(
  107. vpred, cv2.MORPH_CLOSE, vkernel, iterations=1
  108. ) # 先膨胀后腐蚀的过程
  109. if morph_close:
  110. hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
  111. colboxes = get_table_line(vpred, axis=1, lineW=col) # 竖线
  112. rowboxes = get_table_line(hpred, axis=0, lineW=row) # 横线
  113. rboxes_row_, rboxes_col_ = [], []
  114. if more_h_lines:
  115. rboxes_row_ = adjust_lines(rowboxes, alph=h_lines_threshold, angle=angle)
  116. if more_v_lines:
  117. rboxes_col_ = adjust_lines(colboxes, alph=v_lines_threshold, angle=angle)
  118. rowboxes += rboxes_row_
  119. colboxes += rboxes_col_
  120. if extend_line:
  121. rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
  122. line_img = np.zeros(img.shape[:2], dtype="uint8")
  123. line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
  124. rotated_angle = self.cal_rotate_angle(line_img)
  125. if rotated_fix and abs(rotated_angle) > 0.3:
  126. rotated_line_img = self.rotate_image(line_img, rotated_angle)
  127. rotated_polygons = self.cal_region_boxes(rotated_line_img)
  128. polygons = self.unrotate_polygons(
  129. rotated_polygons, rotated_angle, line_img.shape
  130. )
  131. else:
  132. polygons = self.cal_region_boxes(line_img)
  133. rotated_polygons = polygons.copy()
  134. return polygons, rotated_polygons
  135. def cal_region_boxes(self, tmp):
  136. labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记
  137. regions = measure.regionprops(labels)
  138. ceilboxes = min_area_rect_box(
  139. regions,
  140. False,
  141. tmp.shape[1],
  142. tmp.shape[0],
  143. filtersmall=True,
  144. adjust_box=False,
  145. ) # 最后一个参数改为False
  146. return np.array(ceilboxes)
  147. def cal_rotate_angle(self, tmp):
  148. # 计算最外侧的旋转框
  149. contours, _ = cv2.findContours(tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  150. if not contours:
  151. return 0
  152. largest_contour = max(contours, key=cv2.contourArea)
  153. rect = cv2.minAreaRect(largest_contour)
  154. # 计算旋转角度
  155. angle = rect[2]
  156. if angle < -45:
  157. angle += 90
  158. elif angle > 45:
  159. angle -= 90
  160. return angle
  161. def rotate_image(self, image, angle):
  162. # 获取图像的中心点
  163. (h, w) = image.shape[:2]
  164. center = (w // 2, h // 2)
  165. # 计算旋转矩阵
  166. M = cv2.getRotationMatrix2D(center, angle, 1.0)
  167. # 进行旋转
  168. rotated_image = cv2.warpAffine(
  169. image, M, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE
  170. )
  171. return rotated_image
  172. def unrotate_polygons(
  173. self, polygons: np.ndarray, angle: float, img_shape: tuple
  174. ) -> np.ndarray:
  175. # 将多边形旋转回原始位置
  176. (h, w) = img_shape
  177. center = (w // 2, h // 2)
  178. M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)
  179. # 将 (N, 8) 转换为 (N, 4, 2)
  180. polygons_reshaped = polygons.reshape(-1, 4, 2)
  181. # 批量逆旋转
  182. unrotated_polygons = cv2.transform(polygons_reshaped, M_inv)
  183. # 将 (N, 4, 2) 转换回 (N, 8)
  184. unrotated_polygons = unrotated_polygons.reshape(-1, 8)
  185. return unrotated_polygons