utils_table_recover.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from typing import Any, Dict, List, Union, Tuple
  2. import numpy as np
  def calculate_iou(
      box1: Union[np.ndarray, List], box2: Union[np.ndarray, List]
  ) -> float:
      """Compute the intersection-over-union (IoU) of two axis-aligned boxes.

      :param box1: Iterable [xmin, ymin, xmax, ymax]
      :param box2: Iterable [xmin, ymin, xmax, ymax]
      :return: IoU as a float. 0.0 when the boxes do not touch; 1.0 when the
          union area is zero (e.g. both boxes have zero area). For well-formed
          boxes (xmin <= xmax, ymin <= ymax) the result lies in [0, 1].
      """
      b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
      b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
      # Early exit: boxes that do not touch share no area.
      if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
          return 0.0
      # Intersection area, clamped at 0 for boxes that only touch.
      inter_w = max(0, min(b1_x2, b2_x2) - max(b1_x1, b2_x1))
      inter_h = max(0, min(b1_y2, b2_y2) - max(b1_y1, b2_y1))
      i_area = inter_w * inter_h
      # Union area.
      b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
      b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
      u_area = b1_area + b2_area - i_area
      if u_area == 0:
          # Zero union (degenerate boxes): avoid dividing by zero and report a
          # full overlap. Returned as a float to honour the annotation.
          return 1.0
      return i_area / u_area
  32. def is_box_contained(
  33. box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
  34. ) -> Union[int, None]:
  35. """
  36. :param box1: Iterable [xmin,ymin,xmax,ymax]
  37. :param box2: Iterable [xmin,ymin,xmax,ymax]
  38. :return: 1: box1 is contained 2: box2 is contained None: no contain these
  39. """
  40. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  41. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  42. # 不相交直接退出检测
  43. if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
  44. return None
  45. # 计算box2的总面积
  46. b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
  47. b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
  48. # 计算box1和box2的交集
  49. intersect_x1 = max(b1_x1, b2_x1)
  50. intersect_y1 = max(b1_y1, b2_y1)
  51. intersect_x2 = min(b1_x2, b2_x2)
  52. intersect_y2 = min(b1_y2, b2_y2)
  53. # 计算交集的面积
  54. intersect_area = max(0, intersect_x2 - intersect_x1) * max(
  55. 0, intersect_y2 - intersect_y1
  56. )
  57. # 计算外面的面积
  58. b1_outside_area = b1_area - intersect_area
  59. b2_outside_area = b2_area - intersect_area
  60. # 计算外面的面积占box2总面积的比例
  61. ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
  62. ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
  63. if ratio_b1 < threshold:
  64. return 1
  65. if ratio_b2 < threshold:
  66. return 2
  67. # 判断比例是否大于阈值
  68. return None
  69. def is_single_axis_contained(
  70. box1: Union[np.ndarray, List],
  71. box2: Union[np.ndarray, List],
  72. axis="x",
  73. threhold: float = 0.2,
  74. ) -> Union[int, None]:
  75. """
  76. :param box1: Iterable [xmin,ymin,xmax,ymax]
  77. :param box2: Iterable [xmin,ymin,xmax,ymax]
  78. :return: 1: box1 is contained 2: box2 is contained None: no contain these
  79. """
  80. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  81. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  82. # 计算轴重叠大小
  83. if axis == "x":
  84. b1_area = b1_x2 - b1_x1
  85. b2_area = b2_x2 - b2_x1
  86. i_area = min(b1_x2, b2_x2) - max(b1_x1, b2_x1)
  87. else:
  88. b1_area = b1_y2 - b1_y1
  89. b2_area = b2_y2 - b2_y1
  90. i_area = min(b1_y2, b2_y2) - max(b1_y1, b2_y1)
  91. # 计算外面的面积
  92. b1_outside_area = b1_area - i_area
  93. b2_outside_area = b2_area - i_area
  94. ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
  95. ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
  96. if ratio_b1 < threhold:
  97. return 1
  98. if ratio_b2 < threhold:
  99. return 2
  100. return None
  101. def sorted_ocr_boxes(
  102. dt_boxes: Union[np.ndarray, list], threhold: float = 0.2
  103. ) -> Tuple[Union[np.ndarray, list], List[int]]:
  104. """
  105. Sort text boxes in order from top to bottom, left to right
  106. args:
  107. dt_boxes(array):detected text boxes with (xmin, ymin, xmax, ymax)
  108. return:
  109. sorted boxes(array) with (xmin, ymin, xmax, ymax)
  110. """
  111. num_boxes = len(dt_boxes)
  112. if num_boxes <= 0:
  113. return dt_boxes, []
  114. indexed_boxes = [(box, idx) for idx, box in enumerate(dt_boxes)]
  115. sorted_boxes_with_idx = sorted(indexed_boxes, key=lambda x: (x[0][1], x[0][0]))
  116. _boxes, indices = zip(*sorted_boxes_with_idx)
  117. indices = list(indices)
  118. _boxes = [dt_boxes[i] for i in indices]
  119. threahold = 20
  120. # 避免输出和输入格式不对应,与函数功能不符合
  121. if isinstance(dt_boxes, np.ndarray):
  122. _boxes = np.array(_boxes)
  123. for i in range(num_boxes - 1):
  124. for j in range(i, -1, -1):
  125. c_idx = is_single_axis_contained(
  126. _boxes[j], _boxes[j + 1], axis="y", threhold=threhold
  127. )
  128. if (
  129. c_idx is not None
  130. and _boxes[j + 1][0] < _boxes[j][0]
  131. and abs(_boxes[j][1] - _boxes[j + 1][1]) < threahold
  132. ):
  133. _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy()
  134. indices[j], indices[j + 1] = indices[j + 1], indices[j]
  135. else:
  136. break
  137. return _boxes, indices
  138. def box_4_1_poly_to_box_4_2(poly_box: Union[list, np.ndarray]) -> List[List[float]]:
  139. xmin, ymin, xmax, ymax = tuple(poly_box)
  140. return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
  141. def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]:
  142. """
  143. 将poly_box转换为box_4_1
  144. :param poly_box:
  145. :return:
  146. """
  147. return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]]
  148. def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.ndarray):
  149. """
  150. :param dt_rec_boxes: [[(4.2), text, score]]
  151. :param pred_bboxes: shap (4,2)
  152. :return:
  153. """
  154. matched = {}
  155. not_match_orc_boxes = []
  156. for i, gt_box in enumerate(dt_rec_boxes):
  157. for j, pred_box in enumerate(pred_bboxes):
  158. pred_box = [pred_box[0][0], pred_box[0][1], pred_box[2][0], pred_box[2][1]]
  159. ocr_boxes = gt_box[0]
  160. # xmin,ymin,xmax,ymax
  161. ocr_box = (
  162. ocr_boxes[0][0],
  163. ocr_boxes[0][1],
  164. ocr_boxes[2][0],
  165. ocr_boxes[2][1],
  166. )
  167. contained = is_box_contained(ocr_box, pred_box, 0.6)
  168. if contained == 1 or calculate_iou(ocr_box, pred_box) > 0.8:
  169. if j not in matched:
  170. matched[j] = [gt_box]
  171. else:
  172. matched[j].append(gt_box)
  173. else:
  174. not_match_orc_boxes.append(gt_box)
  175. return matched, not_match_orc_boxes
  176. def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[Any]:
  177. """
  178. :param ocr_list: [[[xmin,ymin,xmax,ymax], text]]
  179. :return:
  180. """
  181. threshold = 10
  182. for i in range(len(ocr_list)):
  183. if not ocr_list[i]:
  184. continue
  185. for j in range(i + 1, len(ocr_list)):
  186. if not ocr_list[j]:
  187. continue
  188. cur = ocr_list[i]
  189. next = ocr_list[j]
  190. cur_box = cur[0]
  191. next_box = next[0]
  192. c_idx = is_single_axis_contained(
  193. cur[0], next[0], axis="y", threhold=threhold
  194. )
  195. if c_idx:
  196. dis = max(next_box[0] - cur_box[2], 0)
  197. blank_str = int(dis / threshold) * " "
  198. cur[1] = cur[1] + blank_str + next[1]
  199. xmin = min(cur_box[0], next_box[0])
  200. xmax = max(cur_box[2], next_box[2])
  201. ymin = min(cur_box[1], next_box[1])
  202. ymax = max(cur_box[3], next_box[3])
  203. cur_box[0] = xmin
  204. cur_box[1] = ymin
  205. cur_box[2] = xmax
  206. cur_box[3] = ymax
  207. ocr_list[j] = None
  208. ocr_list = [x for x in ocr_list if x]
  209. return ocr_list
  210. def plot_html_table(
  211. logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]]
  212. ) -> str:
  213. # 初始化最大行数和列数
  214. max_row = 0
  215. max_col = 0
  216. # 计算最大行数和列数
  217. for point in logi_points:
  218. max_row = max(max_row, point[1] + 1) # 加1是因为结束下标是包含在内的
  219. max_col = max(max_col, point[3] + 1) # 加1是因为结束下标是包含在内的
  220. # 创建一个二维数组来存储 sorted_logi_points 中的元素
  221. grid = [[None] * max_col for _ in range(max_row)]
  222. valid_start_row = (1 << 16) - 1
  223. valid_start_col = (1 << 16) - 1
  224. valid_end_col = 0
  225. # 将 sorted_logi_points 中的元素填充到 grid 中
  226. for i, logic_point in enumerate(logi_points):
  227. row_start, row_end, col_start, col_end = (
  228. logic_point[0],
  229. logic_point[1],
  230. logic_point[2],
  231. logic_point[3],
  232. )
  233. ocr_rec_text_list = cell_box_map.get(i)
  234. if ocr_rec_text_list and "".join(ocr_rec_text_list):
  235. valid_start_row = min(row_start, valid_start_row)
  236. valid_start_col = min(col_start, valid_start_col)
  237. valid_end_col = max(col_end, valid_end_col)
  238. for row in range(row_start, row_end + 1):
  239. for col in range(col_start, col_end + 1):
  240. grid[row][col] = (i, row_start, row_end, col_start, col_end)
  241. # 创建表格
  242. table_html = "<html><body><table>"
  243. # 遍历每行
  244. for row in range(max_row):
  245. if row < valid_start_row:
  246. continue
  247. temp = "<tr>"
  248. # 遍历每一列
  249. for col in range(max_col):
  250. if col < valid_start_col or col > valid_end_col:
  251. continue
  252. if not grid[row][col]:
  253. temp += "<td></td>"
  254. else:
  255. i, row_start, row_end, col_start, col_end = grid[row][col]
  256. if not cell_box_map.get(i):
  257. continue
  258. if row == row_start and col == col_start:
  259. ocr_rec_text = cell_box_map.get(i)
  260. # text = "<br>".join(ocr_rec_text)
  261. text = "".join(ocr_rec_text)
  262. # 如果是起始单元格
  263. row_span = row_end - row_start + 1
  264. col_span = col_end - col_start + 1
  265. cell_content = (
  266. f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
  267. )
  268. temp += cell_content
  269. table_html = table_html + temp + "</tr>"
  270. table_html += "</table></body></html>"
  271. return table_html