table_structure_unet.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import copy
  2. import math
  3. from typing import Optional, Dict, Any, Tuple
  4. import cv2
  5. import numpy as np
  6. from skimage import measure
  7. from .wired_table_rec_utils import OrtInferSession, resize_img
  8. from .table_line_rec_utils import (
  9. get_table_line,
  10. final_adjust_lines,
  11. min_area_rect_box,
  12. draw_lines,
  13. adjust_lines,
  14. )
  15. from .table_recover_utils import (
  16. sorted_ocr_boxes,
  17. box_4_2_poly_to_box_4_1,
  18. )
  19. class TSRUnet:
  20. def __init__(self, config: Dict):
  21. self.K = 1000
  22. self.MK = 4000
  23. self.mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
  24. self.std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
  25. self.inp_height = 1024
  26. self.inp_width = 1024
  27. self.session = OrtInferSession(config)
  28. def __call__(
  29. self, img: np.ndarray, **kwargs
  30. ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
  31. img_info = self.preprocess(img)
  32. pred = self.infer(img_info)
  33. polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
  34. if polygons.size == 0:
  35. return None, None
  36. polygons = polygons.reshape(polygons.shape[0], 4, 2)
  37. polygons[:, 3, :], polygons[:, 1, :] = (
  38. polygons[:, 1, :].copy(),
  39. polygons[:, 3, :].copy(),
  40. )
  41. rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
  42. rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
  43. rotated_polygons[:, 1, :].copy(),
  44. rotated_polygons[:, 3, :].copy(),
  45. )
  46. _, idx = sorted_ocr_boxes(
  47. [box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
  48. threshold=0.4,
  49. )
  50. polygons = polygons[idx]
  51. rotated_polygons = rotated_polygons[idx]
  52. return polygons, rotated_polygons
  53. def preprocess(self, img) -> Dict[str, Any]:
  54. scale = (self.inp_height, self.inp_width)
  55. img, _, _ = resize_img(img, scale, True)
  56. img = img.copy().astype(np.float32)
  57. assert img.dtype != np.uint8
  58. mean = np.float64(self.mean.reshape(1, -1))
  59. stdinv = 1 / np.float64(self.std.reshape(1, -1))
  60. cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
  61. cv2.subtract(img, mean, img) # inplace
  62. cv2.multiply(img, stdinv, img) # inplace
  63. img = img.transpose(2, 0, 1)
  64. images = img[None, :]
  65. return {"img": images}
  66. def infer(self, input):
  67. result = self.session(input["img"][None, ...])[0][0]
  68. result = result[0].astype(np.uint8)
  69. return result
  70. def postprocess(self, img, pred, **kwargs):
  71. row = kwargs.get("row", 50) if kwargs else 50
  72. col = kwargs.get("col", 30) if kwargs else 30
  73. h_lines_threshold = kwargs.get("h_lines_threshold", 100) if kwargs else 100
  74. v_lines_threshold = kwargs.get("v_lines_threshold", 15) if kwargs else 15
  75. angle = kwargs.get("angle", 50) if kwargs else 50
  76. enhance_box_line = kwargs.get("enhance_box_line", True) if kwargs else True
  77. morph_close = (
  78. kwargs.get("morph_close", enhance_box_line) if kwargs else enhance_box_line
  79. ) # 是否进行闭合运算以找到更多小的框
  80. more_h_lines = (
  81. kwargs.get("more_h_lines", enhance_box_line) if kwargs else enhance_box_line
  82. ) # 是否调整以找到更多的横线
  83. more_v_lines = (
  84. kwargs.get("more_v_lines", enhance_box_line) if kwargs else enhance_box_line
  85. ) # 是否调整以找到更多的横线
  86. extend_line = (
  87. kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
  88. ) # 是否进行线段延长使得端点连接
  89. ori_shape = img.shape
  90. pred = np.uint8(pred)
  91. hpred = copy.deepcopy(pred) # 横线
  92. vpred = copy.deepcopy(pred) # 竖线
  93. whereh = np.where(hpred == 1)
  94. wherev = np.where(vpred == 2)
  95. hpred[wherev] = 0
  96. vpred[whereh] = 0
  97. hpred = cv2.resize(hpred, (ori_shape[1], ori_shape[0]))
  98. vpred = cv2.resize(vpred, (ori_shape[1], ori_shape[0]))
  99. h, w = pred.shape
  100. hors_k = int(math.sqrt(w) * 1.2)
  101. vert_k = int(math.sqrt(h) * 1.2)
  102. hkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (hors_k, 1))
  103. vkernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, vert_k))
  104. vpred = cv2.morphologyEx(
  105. vpred, cv2.MORPH_CLOSE, vkernel, iterations=1
  106. ) # 先膨胀后腐蚀的过程
  107. if morph_close:
  108. hpred = cv2.morphologyEx(hpred, cv2.MORPH_CLOSE, hkernel, iterations=1)
  109. colboxes = get_table_line(vpred, axis=1, lineW=col) # 竖线
  110. rowboxes = get_table_line(hpred, axis=0, lineW=row) # 横线
  111. rboxes_row_, rboxes_col_ = [], []
  112. if more_h_lines:
  113. rboxes_row_ = adjust_lines(rowboxes, alph=h_lines_threshold, angle=angle)
  114. if more_v_lines:
  115. rboxes_col_ = adjust_lines(colboxes, alph=v_lines_threshold, angle=angle)
  116. rowboxes += rboxes_row_
  117. colboxes += rboxes_col_
  118. if extend_line:
  119. rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
  120. line_img = np.zeros(img.shape[:2], dtype="uint8")
  121. line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
  122. polygons = self.cal_region_boxes(line_img)
  123. rotated_polygons = polygons.copy()
  124. return polygons, rotated_polygons
  125. def cal_region_boxes(self, tmp):
  126. labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记
  127. regions = measure.regionprops(labels)
  128. ceilboxes = min_area_rect_box(
  129. regions,
  130. False,
  131. tmp.shape[1],
  132. tmp.shape[0],
  133. filtersmall=True,
  134. adjust_box=False,
  135. ) # 最后一个参数改为False
  136. return np.array(ceilboxes)