table_recover_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. from typing import Any, Dict, List, Union, Tuple
  2. import numpy as np
  3. import shapely
  4. from shapely.geometry import MultiPoint, Polygon
  5. def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray:
  6. """
  7. Sort text boxes in order from top to bottom, left to right
  8. args:
  9. dt_boxes(array):detected text boxes with shape (N, 4, 2)
  10. return:
  11. sorted boxes(array) with shape (N, 4, 2)
  12. """
  13. num_boxes = dt_boxes.shape[0]
  14. dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  15. _boxes = list(dt_boxes)
  16. # 解决相邻框,后边比前面y轴小,则会被排到前面去的问题
  17. for i in range(num_boxes - 1):
  18. for j in range(i, -1, -1):
  19. if (
  20. abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
  21. and _boxes[j + 1][0][0] < _boxes[j][0][0]
  22. ):
  23. _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
  24. else:
  25. break
  26. return np.array(_boxes)
  27. def calculate_iou(
  28. box1: Union[np.ndarray, List], box2: Union[np.ndarray, List]
  29. ) -> float:
  30. """
  31. :param box1: Iterable [xmin,ymin,xmax,ymax]
  32. :param box2: Iterable [xmin,ymin,xmax,ymax]
  33. :return: iou: float 0-1
  34. """
  35. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  36. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  37. # 不相交直接退出检测
  38. if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
  39. return 0.0
  40. # 计算交集
  41. inter_x1 = max(b1_x1, b2_x1)
  42. inter_y1 = max(b1_y1, b2_y1)
  43. inter_x2 = min(b1_x2, b2_x2)
  44. inter_y2 = min(b1_y2, b2_y2)
  45. i_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
  46. # 计算并集
  47. b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
  48. b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
  49. u_area = b1_area + b2_area - i_area
  50. # 避免除零错误,如果区域小到乘积为0,认为是错误识别,直接去掉
  51. if u_area == 0:
  52. return 1
  53. # 检查完全包含
  54. iou = i_area / u_area
  55. return iou
  56. def is_box_contained(
  57. box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2
  58. ) -> Union[int, None]:
  59. """
  60. :param box1: Iterable [xmin,ymin,xmax,ymax]
  61. :param box2: Iterable [xmin,ymin,xmax,ymax]
  62. :return: 1: box1 is contained 2: box2 is contained None: no contain these
  63. """
  64. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  65. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  66. # 不相交直接退出检测
  67. if b1_x2 < b2_x1 or b1_x1 > b2_x2 or b1_y2 < b2_y1 or b1_y1 > b2_y2:
  68. return None
  69. # 计算box2的总面积
  70. b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
  71. b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
  72. # 计算box1和box2的交集
  73. intersect_x1 = max(b1_x1, b2_x1)
  74. intersect_y1 = max(b1_y1, b2_y1)
  75. intersect_x2 = min(b1_x2, b2_x2)
  76. intersect_y2 = min(b1_y2, b2_y2)
  77. # 计算交集的面积
  78. intersect_area = max(0, intersect_x2 - intersect_x1) * max(
  79. 0, intersect_y2 - intersect_y1
  80. )
  81. # 计算外面的面积
  82. b1_outside_area = b1_area - intersect_area
  83. b2_outside_area = b2_area - intersect_area
  84. # 计算外面的面积占box2总面积的比例
  85. ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
  86. ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
  87. if ratio_b1 < threshold:
  88. return 1
  89. if ratio_b2 < threshold:
  90. return 2
  91. # 判断比例是否大于阈值
  92. return None
  93. def is_single_axis_contained(
  94. box1: Union[np.ndarray, List],
  95. box2: Union[np.ndarray, List],
  96. axis="x",
  97. threshold: float = 0.2,
  98. ) -> Union[int, None]:
  99. """
  100. :param box1: Iterable [xmin,ymin,xmax,ymax]
  101. :param box2: Iterable [xmin,ymin,xmax,ymax]
  102. :return: 1: box1 is contained 2: box2 is contained None: no contain these
  103. """
  104. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  105. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  106. # 计算轴重叠大小
  107. if axis == "x":
  108. b1_area = b1_x2 - b1_x1
  109. b2_area = b2_x2 - b2_x1
  110. i_area = min(b1_x2, b2_x2) - max(b1_x1, b2_x1)
  111. else:
  112. b1_area = b1_y2 - b1_y1
  113. b2_area = b2_y2 - b2_y1
  114. i_area = min(b1_y2, b2_y2) - max(b1_y1, b2_y1)
  115. # 计算外面的面积
  116. b1_outside_area = b1_area - i_area
  117. b2_outside_area = b2_area - i_area
  118. ratio_b1 = b1_outside_area / b1_area if b1_area > 0 else 0
  119. ratio_b2 = b2_outside_area / b2_area if b2_area > 0 else 0
  120. if ratio_b1 < threshold:
  121. return 1
  122. if ratio_b2 < threshold:
  123. return 2
  124. return None
  125. def sorted_ocr_boxes(
  126. dt_boxes: Union[np.ndarray, list], threshold: float = 0.2
  127. ) -> Tuple[Union[np.ndarray, list], List[int]]:
  128. """
  129. Sort text boxes in order from top to bottom, left to right
  130. args:
  131. dt_boxes(array):detected text boxes with (xmin, ymin, xmax, ymax)
  132. return:
  133. sorted boxes(array) with (xmin, ymin, xmax, ymax)
  134. """
  135. num_boxes = len(dt_boxes)
  136. if num_boxes <= 0:
  137. return dt_boxes, []
  138. indexed_boxes = [(box, idx) for idx, box in enumerate(dt_boxes)]
  139. sorted_boxes_with_idx = sorted(indexed_boxes, key=lambda x: (x[0][1], x[0][0]))
  140. _boxes, indices = zip(*sorted_boxes_with_idx)
  141. indices = list(indices)
  142. _boxes = [dt_boxes[i] for i in indices]
  143. # 避免输出和输入格式不对应,与函数功能不符合
  144. if isinstance(dt_boxes, np.ndarray):
  145. _boxes = np.array(_boxes)
  146. for i in range(num_boxes - 1):
  147. for j in range(i, -1, -1):
  148. c_idx = is_single_axis_contained(
  149. _boxes[j], _boxes[j + 1], axis="y", threshold=threshold
  150. )
  151. if (
  152. c_idx is not None
  153. and _boxes[j + 1][0] < _boxes[j][0]
  154. and abs(_boxes[j][1] - _boxes[j + 1][1]) < 20
  155. ):
  156. _boxes[j], _boxes[j + 1] = _boxes[j + 1].copy(), _boxes[j].copy()
  157. indices[j], indices[j + 1] = indices[j + 1], indices[j]
  158. else:
  159. break
  160. return _boxes, indices
  161. def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]:
  162. """
  163. 将poly_box转换为box_4_1
  164. :param poly_box:
  165. :return:
  166. """
  167. return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]]
  168. def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.ndarray):
  169. """
  170. :param dt_rec_boxes: [[(4.2), text, score]]
  171. :param pred_bboxes: shap (4,2)
  172. :return:
  173. """
  174. matched = {}
  175. not_match_orc_boxes = []
  176. for i, gt_box in enumerate(dt_rec_boxes):
  177. for j, pred_box in enumerate(pred_bboxes):
  178. pred_box = [pred_box[0][0], pred_box[0][1], pred_box[2][0], pred_box[2][1]]
  179. ocr_boxes = gt_box[0]
  180. # xmin,ymin,xmax,ymax
  181. ocr_box = (
  182. ocr_boxes[0][0],
  183. ocr_boxes[0][1],
  184. ocr_boxes[2][0],
  185. ocr_boxes[2][1],
  186. )
  187. contained = is_box_contained(ocr_box, pred_box, 0.6)
  188. if contained == 1 or calculate_iou(ocr_box, pred_box) > 0.8:
  189. if j not in matched:
  190. matched[j] = [gt_box]
  191. else:
  192. matched[j].append(gt_box)
  193. else:
  194. not_match_orc_boxes.append(gt_box)
  195. return matched, not_match_orc_boxes
  196. def gather_ocr_list_by_row(ocr_list: List[Any], threshold: float = 0.2) -> List[Any]:
  197. """
  198. Groups OCR results by row based on the vertical (y-axis) overlap of their bounding boxes.
  199. Args:
  200. ocr_list (List[Any]): A list of OCR results, where each item is a list containing a bounding box
  201. in the format [xmin, ymin, xmax, ymax] and the recognized text.
  202. threshold (float, optional): The threshold for determining if two boxes are in the same row,
  203. based on their y-axis overlap. Default is 0.2.
  204. Returns:
  205. List[Any]: A new list of OCR results where texts in the same row are merged, and their bounding
  206. boxes are updated to encompass the merged text.
  207. """
  208. for i in range(len(ocr_list)):
  209. if not ocr_list[i]:
  210. continue
  211. for j in range(i + 1, len(ocr_list)):
  212. if not ocr_list[j]:
  213. continue
  214. cur = ocr_list[i]
  215. next = ocr_list[j]
  216. cur_box = cur[0]
  217. next_box = next[0]
  218. c_idx = is_single_axis_contained(
  219. cur[0], next[0], axis="y", threshold=threshold
  220. )
  221. if c_idx:
  222. dis = max(next_box[0] - cur_box[2], 0)
  223. blank_str = int(dis / 10) * " "
  224. cur[1] = cur[1] + blank_str + next[1]
  225. xmin = min(cur_box[0], next_box[0])
  226. xmax = max(cur_box[2], next_box[2])
  227. ymin = min(cur_box[1], next_box[1])
  228. ymax = max(cur_box[3], next_box[3])
  229. cur_box[0] = xmin
  230. cur_box[1] = ymin
  231. cur_box[2] = xmax
  232. cur_box[3] = ymax
  233. ocr_list[j] = None
  234. ocr_list = [x for x in ocr_list if x]
  235. return ocr_list
  236. def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float:
  237. """计算两个多边形的IOU
  238. Args:
  239. poly1 (np.ndarray): (4, 2)
  240. poly2 (np.ndarray): (4, 2)
  241. Returns:
  242. float: iou
  243. """
  244. poly1 = Polygon(a).convex_hull
  245. poly2 = Polygon(b).convex_hull
  246. union_poly = np.concatenate((a, b))
  247. if not poly1.intersects(poly2):
  248. return 0.0
  249. try:
  250. inter_area = poly1.intersection(poly2).area
  251. union_area = MultiPoint(union_poly).convex_hull.area
  252. except shapely.geos.TopologicalError:
  253. print("shapely.geos.TopologicalError occured, iou set to 0")
  254. return 0.0
  255. if union_area == 0:
  256. return 0.0
  257. return float(inter_area) / union_area
  258. def merge_adjacent_polys(polygons: np.ndarray) -> np.ndarray:
  259. """合并相邻iou大于阈值的框"""
  260. combine_iou_thresh = 0.1
  261. pair_polygons = list(zip(polygons, polygons[1:, ...]))
  262. pair_ious = np.array([compute_poly_iou(p1, p2) for p1, p2 in pair_polygons])
  263. idxs = np.argwhere(pair_ious >= combine_iou_thresh)
  264. if idxs.size <= 0:
  265. return polygons
  266. polygons = combine_two_poly(polygons, idxs)
  267. # 注意:递归调用
  268. polygons = merge_adjacent_polys(polygons)
  269. return polygons
  270. def combine_two_poly(polygons: np.ndarray, idxs: np.ndarray) -> np.ndarray:
  271. del_idxs, insert_boxes = [], []
  272. idxs = idxs.squeeze(-1)
  273. for idx in idxs:
  274. # idx 和 idx + 1 是重合度过高的
  275. # 合并,取两者各个点的最大值
  276. new_poly = []
  277. pre_poly, pos_poly = polygons[idx], polygons[idx + 1]
  278. # 四个点,每个点逐一比较
  279. new_poly.append(np.minimum(pre_poly[0], pos_poly[0]))
  280. x_2 = min(pre_poly[1][0], pos_poly[1][0])
  281. y_2 = max(pre_poly[1][1], pos_poly[1][1])
  282. new_poly.append([x_2, y_2])
  283. # 第3个点
  284. new_poly.append(np.maximum(pre_poly[2], pos_poly[2]))
  285. # 第4个点
  286. x_4 = max(pre_poly[3][0], pos_poly[3][0])
  287. y_4 = min(pre_poly[3][1], pos_poly[3][1])
  288. new_poly.append([x_4, y_4])
  289. new_poly = np.array(new_poly)
  290. # 删除已经合并的两个框,插入新的框
  291. del_idxs.extend([idx, idx + 1])
  292. insert_boxes.append(new_poly)
  293. # 整合合并后的框
  294. polygons = np.delete(polygons, del_idxs, axis=0)
  295. insert_boxes = np.array(insert_boxes)
  296. polygons = np.append(polygons, insert_boxes, axis=0)
  297. polygons = sorted_boxes(polygons)
  298. return polygons
  299. def plot_html_table(
  300. logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]]
  301. ) -> str:
  302. # 初始化最大行数和列数
  303. max_row = 0
  304. max_col = 0
  305. # 计算最大行数和列数
  306. for point in logi_points:
  307. max_row = max(max_row, point[1] + 1) # 加1是因为结束下标是包含在内的
  308. max_col = max(max_col, point[3] + 1) # 加1是因为结束下标是包含在内的
  309. # 创建一个二维数组来存储 sorted_logi_points 中的元素
  310. grid = [[None] * max_col for _ in range(max_row)]
  311. valid_start_row = (1 << 16) - 1
  312. valid_start_col = (1 << 16) - 1
  313. valid_end_col = 0
  314. # 将 sorted_logi_points 中的元素填充到 grid 中
  315. for i, logic_point in enumerate(logi_points):
  316. row_start, row_end, col_start, col_end = (
  317. logic_point[0],
  318. logic_point[1],
  319. logic_point[2],
  320. logic_point[3],
  321. )
  322. ocr_rec_text_list = cell_box_map.get(i)
  323. if ocr_rec_text_list and "".join(ocr_rec_text_list):
  324. valid_start_row = min(row_start, valid_start_row)
  325. valid_start_col = min(col_start, valid_start_col)
  326. valid_end_col = max(col_end, valid_end_col)
  327. for row in range(row_start, row_end + 1):
  328. for col in range(col_start, col_end + 1):
  329. grid[row][col] = (i, row_start, row_end, col_start, col_end)
  330. # 创建表格
  331. table_html = "<html><body><table>"
  332. # 遍历每行
  333. for row in range(max_row):
  334. if row < valid_start_row:
  335. continue
  336. temp = "<tr>"
  337. # 遍历每一列
  338. for col in range(max_col):
  339. if col < valid_start_col or col > valid_end_col:
  340. continue
  341. if not grid[row][col]:
  342. temp += "<td></td>"
  343. else:
  344. i, row_start, row_end, col_start, col_end = grid[row][col]
  345. if not cell_box_map.get(i):
  346. continue
  347. if row == row_start and col == col_start:
  348. ocr_rec_text = cell_box_map.get(i)
  349. text = "<br>".join(ocr_rec_text)
  350. # 如果是起始单元格
  351. row_span = row_end - row_start + 1
  352. col_span = col_end - col_start + 1
  353. cell_content = (
  354. f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
  355. )
  356. temp += cell_content
  357. table_html = table_html + temp + "</tr>"
  358. table_html += "</table></body></html>"
  359. return table_html