matcher_utils.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # -*- encoding: utf-8 -*-
  15. # @Author: SWHL
  16. # @Contact: liekkaskono@163.com
  17. import copy
  18. import re
  19. def deal_isolate_span(thead_part):
  20. """
  21. Deal with isolate span cases in this function.
  22. It causes by wrong prediction in structure recognition model.
  23. eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
  24. :param thead_part:
  25. :return:
  26. """
  27. # 1. find out isolate span tokens.
  28. isolate_pattern = (
  29. r"<td></td> rowspan='(\d)+' colspan='(\d)+'></b></td>|"
  30. r"<td></td> colspan='(\d)+' rowspan='(\d)+'></b></td>|"
  31. r"<td></td> rowspan='(\d)+'></b></td>|"
  32. r"<td></td> colspan='(\d)+'></b></td>"
  33. )
  34. isolate_iter = re.finditer(isolate_pattern, thead_part)
  35. isolate_list = [i.group() for i in isolate_iter]
  36. # 2. find out span number, by step 1 result.
  37. span_pattern = (
  38. r" rowspan='(\d)+' colspan='(\d)+'|"
  39. r" colspan='(\d)+' rowspan='(\d)+'|"
  40. r" rowspan='(\d)+'|"
  41. r" colspan='(\d)+'"
  42. )
  43. corrected_list = []
  44. for isolate_item in isolate_list:
  45. span_part = re.search(span_pattern, isolate_item)
  46. spanStr_in_isolateItem = span_part.group()
  47. # 3. merge the span number into the span token format string.
  48. if spanStr_in_isolateItem is not None:
  49. corrected_item = f"<td{spanStr_in_isolateItem}></td>"
  50. corrected_list.append(corrected_item)
  51. else:
  52. corrected_list.append(None)
  53. # 4. replace original isolated token.
  54. for corrected_item, isolate_item in zip(corrected_list, isolate_list):
  55. if corrected_item is not None:
  56. thead_part = thead_part.replace(isolate_item, corrected_item)
  57. else:
  58. pass
  59. return thead_part
  60. def deal_duplicate_bb(thead_part):
  61. """
  62. Deal duplicate <b> or </b> after replace.
  63. Keep one <b></b> in a <td></td> token.
  64. :param thead_part:
  65. :return:
  66. """
  67. # 1. find out <td></td> in <thead></thead>.
  68. td_pattern = (
  69. r"<td rowspan='(\d)+' colspan='(\d)+'>(.+?)</td>|"
  70. r"<td colspan='(\d)+' rowspan='(\d)+'>(.+?)</td>|"
  71. r"<td rowspan='(\d)+'>(.+?)</td>|"
  72. r"<td colspan='(\d)+'>(.+?)</td>|"
  73. r"<td>(.*?)</td>"
  74. )
  75. td_iter = re.finditer(td_pattern, thead_part)
  76. td_list = [t.group() for t in td_iter]
  77. # 2. is multiply <b></b> in <td></td> or not?
  78. new_td_list = []
  79. for td_item in td_list:
  80. if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
  81. # multiply <b></b> in <td></td> case.
  82. # 1. remove all <b></b>
  83. td_item = td_item.replace("<b>", "").replace("</b>", "")
  84. # 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
  85. td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
  86. new_td_list.append(td_item)
  87. else:
  88. new_td_list.append(td_item)
  89. # 3. replace original thead part.
  90. for td_item, new_td_item in zip(td_list, new_td_list):
  91. thead_part = thead_part.replace(td_item, new_td_item)
  92. return thead_part
  93. def deal_bb(result_token):
  94. """
  95. In our opinion, <b></b> always occurs in <thead></thead> text's context.
  96. This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
  97. :param result_token:
  98. :return:
  99. """
  100. # find out <thead></thead> parts.
  101. thead_pattern = "<thead>(.*?)</thead>"
  102. if re.search(thead_pattern, result_token) is None:
  103. return result_token
  104. thead_part = re.search(thead_pattern, result_token).group()
  105. origin_thead_part = copy.deepcopy(thead_part)
  106. # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
  107. span_pattern = r"<td rowspan='(\d)+' colspan='(\d)+'>|<td colspan='(\d)+' rowspan='(\d)+'>|<td rowspan='(\d)+'>|<td colspan='(\d)+'>"
  108. span_iter = re.finditer(span_pattern, thead_part)
  109. span_list = [s.group() for s in span_iter]
  110. has_span_in_head = True if len(span_list) > 0 else False
  111. if not has_span_in_head:
  112. # <thead></thead> not include "rowspan" or "colspan" branch 1.
  113. # 1. replace <td> to <td><b>, and </td> to </b></td>
  114. # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
  115. # so we replace <b><b> to <b>, and </b></b> to </b>
  116. thead_part = (
  117. thead_part.replace("<td>", "<td><b>")
  118. .replace("</td>", "</b></td>")
  119. .replace("<b><b>", "<b>")
  120. .replace("</b></b>", "</b>")
  121. )
  122. else:
  123. # <thead></thead> include "rowspan" or "colspan" branch 2.
  124. # Firstly, we deal rowspan or colspan cases.
  125. # 1. replace > to ><b>
  126. # 2. replace </td> to </b></td>
  127. # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
  128. # so we replace <b><b> to <b>, and </b><b> to </b>
  129. # Secondly, deal ordinary cases like branch 1
  130. # replace ">" to "<b>"
  131. replaced_span_list = []
  132. for sp in span_list:
  133. replaced_span_list.append(sp.replace(">", "><b>"))
  134. for sp, rsp in zip(span_list, replaced_span_list):
  135. thead_part = thead_part.replace(sp, rsp)
  136. # replace "</td>" to "</b></td>"
  137. thead_part = thead_part.replace("</td>", "</b></td>")
  138. # remove duplicated <b> by re.sub
  139. mb_pattern = "(<b>)+"
  140. single_b_string = "<b>"
  141. thead_part = re.sub(mb_pattern, single_b_string, thead_part)
  142. mgb_pattern = "(</b>)+"
  143. single_gb_string = "</b>"
  144. thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
  145. # ordinary cases like branch 1
  146. thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
  147. # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
  148. # but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
  149. thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
  150. # deal with duplicated <b></b>
  151. thead_part = deal_duplicate_bb(thead_part)
  152. # deal with isolate span tokens, which causes by wrong predict by structure prediction.
  153. # eg.PMC5994107_011_00.png
  154. thead_part = deal_isolate_span(thead_part)
  155. # replace original result with new thead part.
  156. result_token = result_token.replace(origin_thead_part, thead_part)
  157. return result_token
  158. def deal_eb_token(master_token):
  159. """
  160. post process with <eb></eb>, <eb1></eb1>, ...
  161. emptyBboxTokenDict = {
  162. "[]": '<eb></eb>',
  163. "[' ']": '<eb1></eb1>',
  164. "['<b>', ' ', '</b>']": '<eb2></eb2>',
  165. "['\\u2028', '\\u2028']": '<eb3></eb3>',
  166. "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
  167. "['<b>', '</b>']": '<eb5></eb5>',
  168. "['<i>', ' ', '</i>']": '<eb6></eb6>',
  169. "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
  170. "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
  171. "['<i>', '</i>']": '<eb9></eb9>',
  172. "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
  173. }
  174. :param master_token:
  175. :return:
  176. """
  177. master_token = master_token.replace("<eb></eb>", "<td></td>")
  178. master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
  179. master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
  180. master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
  181. master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
  182. master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
  183. master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
  184. master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
  185. master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
  186. master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
  187. master_token = master_token.replace(
  188. "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
  189. )
  190. return master_token
  191. def distance(box_1, box_2):
  192. x1, y1, x2, y2 = box_1
  193. x3, y3, x4, y4 = box_2
  194. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  195. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  196. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  197. return dis + min(dis_2, dis_3)
  198. def compute_iou(rec1, rec2):
  199. """
  200. computing IoU
  201. :param rec1: (y0, x0, y1, x1), which reflects
  202. (top, left, bottom, right)
  203. :param rec2: (y0, x0, y1, x1)
  204. :return: scala value of IoU
  205. """
  206. # computing area of each rectangles
  207. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  208. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  209. # computing the sum_area
  210. sum_area = S_rec1 + S_rec2
  211. # find the each edge of intersect rectangle
  212. left_line = max(rec1[1], rec2[1])
  213. right_line = min(rec1[3], rec2[3])
  214. top_line = max(rec1[0], rec2[0])
  215. bottom_line = min(rec1[2], rec2[2])
  216. # judge if there is an intersect
  217. if left_line >= right_line or top_line >= bottom_line:
  218. return 0.0
  219. intersect = (right_line - left_line) * (bottom_line - top_line)
  220. return (intersect / (sum_area - intersect)) * 1.0