matcher.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from .matcher_utils import compute_iou, distance
  16. class TableMatch:
  17. def __init__(self, filter_ocr_result=True, use_master=False):
  18. self.filter_ocr_result = filter_ocr_result
  19. self.use_master = use_master
  20. def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
  21. if self.filter_ocr_result:
  22. dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
  23. matched_index = self.match_result(dt_boxes, cell_bboxes)
  24. pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
  25. return pred_html
  26. def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
  27. matched = {}
  28. for i, gt_box in enumerate(dt_boxes):
  29. distances = []
  30. for j, pred_box in enumerate(cell_bboxes):
  31. if len(pred_box) == 8:
  32. pred_box = [
  33. np.min(pred_box[0::2]),
  34. np.min(pred_box[1::2]),
  35. np.max(pred_box[0::2]),
  36. np.max(pred_box[1::2]),
  37. ]
  38. distances.append(
  39. (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
  40. ) # compute iou and l1 distance
  41. sorted_distances = distances.copy()
  42. # select det box by iou and l1 distance
  43. sorted_distances = sorted(
  44. sorted_distances, key=lambda item: (item[1], item[0])
  45. )
  46. # must > min_iou
  47. if sorted_distances[0][1] >= 1 - min_iou:
  48. continue
  49. if distances.index(sorted_distances[0]) not in matched:
  50. matched[distances.index(sorted_distances[0])] = [i]
  51. else:
  52. matched[distances.index(sorted_distances[0])].append(i)
  53. return matched
  54. def get_pred_html(self, pred_structures, matched_index, ocr_contents):
  55. end_html = []
  56. td_index = 0
  57. for tag in pred_structures:
  58. if "</td>" not in tag:
  59. end_html.append(tag)
  60. continue
  61. if "<td></td>" == tag:
  62. end_html.extend("<td>")
  63. if td_index in matched_index.keys():
  64. b_with = False
  65. if (
  66. "<b>" in ocr_contents[matched_index[td_index][0]]
  67. and len(matched_index[td_index]) > 1
  68. ):
  69. b_with = True
  70. end_html.extend("<b>")
  71. for i, td_index_index in enumerate(matched_index[td_index]):
  72. content = ocr_contents[td_index_index][0]
  73. if len(matched_index[td_index]) > 1:
  74. if len(content) == 0:
  75. continue
  76. if content[0] == " ":
  77. content = content[1:]
  78. if "<b>" in content:
  79. content = content[3:]
  80. if "</b>" in content:
  81. content = content[:-4]
  82. if len(content) == 0:
  83. continue
  84. if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
  85. content += " "
  86. end_html.extend(content)
  87. if b_with:
  88. end_html.extend("</b>")
  89. if "<td></td>" == tag:
  90. end_html.append("</td>")
  91. else:
  92. end_html.append(tag)
  93. td_index += 1
  94. # Filter <thead></thead><tbody></tbody> elements
  95. filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
  96. end_html = [v for v in end_html if v not in filter_elements]
  97. return "".join(end_html), end_html
  98. def decode_logic_points(self, pred_structures):
  99. logic_points = []
  100. current_row = 0
  101. current_col = 0
  102. max_rows = 0
  103. max_cols = 0
  104. occupied_cells = {} # 用于记录已经被占用的单元格
  105. def is_occupied(row, col):
  106. return (row, col) in occupied_cells
  107. def mark_occupied(row, col, rowspan, colspan):
  108. for r in range(row, row + rowspan):
  109. for c in range(col, col + colspan):
  110. occupied_cells[(r, c)] = True
  111. i = 0
  112. while i < len(pred_structures):
  113. token = pred_structures[i]
  114. if token == "<tr>":
  115. current_col = 0 # 每次遇到 <tr> 时,重置当前列号
  116. elif token == "</tr>":
  117. current_row += 1 # 行结束,行号增加
  118. elif token.startswith("<td"):
  119. colspan = 1
  120. rowspan = 1
  121. j = i
  122. if token != "<td></td>":
  123. j += 1
  124. # 提取 colspan 和 rowspan 属性
  125. while j < len(pred_structures) and not pred_structures[
  126. j
  127. ].startswith(">"):
  128. if "colspan=" in pred_structures[j]:
  129. colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
  130. elif "rowspan=" in pred_structures[j]:
  131. rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
  132. j += 1
  133. # 跳过已经处理过的属性 token
  134. i = j
  135. # 找到下一个未被占用的列
  136. while is_occupied(current_row, current_col):
  137. current_col += 1
  138. # 计算逻辑坐标
  139. r_start = current_row
  140. r_end = current_row + rowspan - 1
  141. col_start = current_col
  142. col_end = current_col + colspan - 1
  143. # 记录逻辑坐标
  144. logic_points.append([r_start, r_end, col_start, col_end])
  145. # 标记占用的单元格
  146. mark_occupied(r_start, col_start, rowspan, colspan)
  147. # 更新当前列号
  148. current_col += colspan
  149. # 更新最大行数和列数
  150. max_rows = max(max_rows, r_end + 1)
  151. max_cols = max(max_cols, col_end + 1)
  152. i += 1
  153. return logic_points
  154. def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
  155. y1 = cell_bboxes[:, 1::2].min()
  156. new_dt_boxes = []
  157. new_rec_res = []
  158. for box, rec in zip(dt_boxes, rec_res):
  159. if np.max(box[1::2]) < y1:
  160. continue
  161. new_dt_boxes.append(box)
  162. new_rec_res.append(rec)
  163. return new_dt_boxes, new_rec_res