utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import re
  16. import copy
# Names exported by `from <this module> import *`.
__all__ = [
    "TableMatch",
    "convert_4point2rect",
    "get_ori_coordinate_for_table",
    "is_inside",
]
  23. def deal_eb_token(master_token):
  24. """
  25. post process with <eb></eb>, <eb1></eb1>, ...
  26. emptyBboxTokenDict = {
  27. "[]": '<eb></eb>',
  28. "[' ']": '<eb1></eb1>',
  29. "['<b>', ' ', '</b>']": '<eb2></eb2>',
  30. "['\\u2028', '\\u2028']": '<eb3></eb3>',
  31. "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
  32. "['<b>', '</b>']": '<eb5></eb5>',
  33. "['<i>', ' ', '</i>']": '<eb6></eb6>',
  34. "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
  35. "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
  36. "['<i>', '</i>']": '<eb9></eb9>',
  37. "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
  38. }
  39. :param master_token:
  40. :return:
  41. """
  42. master_token = master_token.replace("<eb></eb>", "<td></td>")
  43. master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
  44. master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
  45. master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
  46. master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
  47. master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
  48. master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
  49. master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
  50. master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
  51. master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
  52. master_token = master_token.replace(
  53. "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
  54. )
  55. return master_token
  56. def deal_bb(result_token):
  57. """
  58. In our opinion, <b></b> always occurs in <thead></thead> text's context.
  59. This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
  60. :param result_token:
  61. :return:
  62. """
  63. # find out <thead></thead> parts.
  64. thead_pattern = "<thead>(.*?)</thead>"
  65. if re.search(thead_pattern, result_token) is None:
  66. return result_token
  67. thead_part = re.search(thead_pattern, result_token).group()
  68. origin_thead_part = copy.deepcopy(thead_part)
  69. # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
  70. span_pattern = (
  71. '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan'
  72. '="(\d)+">|<td colspan="(\d)+">'
  73. )
  74. span_iter = re.finditer(span_pattern, thead_part)
  75. span_list = [s.group() for s in span_iter]
  76. has_span_in_head = True if len(span_list) > 0 else False
  77. if not has_span_in_head:
  78. # <thead></thead> not include "rowspan" or "colspan" branch 1.
  79. # 1. replace <td> to <td><b>, and </td> to </b></td>
  80. # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
  81. # so we replace <b><b> to <b>, and </b></b> to </b>
  82. thead_part = (
  83. thead_part.replace("<td>", "<td><b>")
  84. .replace("</td>", "</b></td>")
  85. .replace("<b><b>", "<b>")
  86. .replace("</b></b>", "</b>")
  87. )
  88. else:
  89. # <thead></thead> include "rowspan" or "colspan" branch 2.
  90. # Firstly, we deal rowspan or colspan cases.
  91. # 1. replace > to ><b>
  92. # 2. replace </td> to </b></td>
  93. # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
  94. # so we replace <b><b> to <b>, and </b><b> to </b>
  95. # Secondly, deal ordinary cases like branch 1
  96. # replace ">" to "<b>"
  97. replaced_span_list = []
  98. for sp in span_list:
  99. replaced_span_list.append(sp.replace(">", "><b>"))
  100. for sp, rsp in zip(span_list, replaced_span_list):
  101. thead_part = thead_part.replace(sp, rsp)
  102. # replace "</td>" to "</b></td>"
  103. thead_part = thead_part.replace("</td>", "</b></td>")
  104. # remove duplicated <b> by re.sub
  105. mb_pattern = "(<b>)+"
  106. single_b_string = "<b>"
  107. thead_part = re.sub(mb_pattern, single_b_string, thead_part)
  108. mgb_pattern = "(</b>)+"
  109. single_gb_string = "</b>"
  110. thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
  111. # ordinary cases like branch 1
  112. thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
  113. # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
  114. # but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
  115. thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
  116. # deal with duplicated <b></b>
  117. thead_part = deal_duplicate_bb(thead_part)
  118. # deal with isolate span tokens, which causes by wrong predict by structure prediction.
  119. # eg.PMC5994107_011_00.png
  120. thead_part = deal_isolate_span(thead_part)
  121. # replace original result with new thead part.
  122. result_token = result_token.replace(origin_thead_part, thead_part)
  123. return result_token
  124. def deal_isolate_span(thead_part):
  125. """
  126. Deal with isolate span cases in this function.
  127. It causes by wrong prediction in structure recognition model.
  128. eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
  129. :param thead_part:
  130. :return:
  131. """
  132. # 1. find out isolate span tokens.
  133. isolate_pattern = (
  134. '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
  135. '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
  136. '<td></td> rowspan="(\d)+"></b></td>|'
  137. '<td></td> colspan="(\d)+"></b></td>'
  138. )
  139. isolate_iter = re.finditer(isolate_pattern, thead_part)
  140. isolate_list = [i.group() for i in isolate_iter]
  141. # 2. find out span number, by step 1 results.
  142. span_pattern = (
  143. ' rowspan="(\d)+" colspan="(\d)+"|'
  144. ' colspan="(\d)+" rowspan="(\d)+"|'
  145. ' rowspan="(\d)+"|'
  146. ' colspan="(\d)+"'
  147. )
  148. corrected_list = []
  149. for isolate_item in isolate_list:
  150. span_part = re.search(span_pattern, isolate_item)
  151. spanStr_in_isolateItem = span_part.group()
  152. # 3. merge the span number into the span token format string.
  153. if spanStr_in_isolateItem is not None:
  154. corrected_item = "<td{}></td>".format(spanStr_in_isolateItem)
  155. corrected_list.append(corrected_item)
  156. else:
  157. corrected_list.append(None)
  158. # 4. replace original isolated token.
  159. for corrected_item, isolate_item in zip(corrected_list, isolate_list):
  160. if corrected_item is not None:
  161. thead_part = thead_part.replace(isolate_item, corrected_item)
  162. else:
  163. pass
  164. return thead_part
  165. def deal_duplicate_bb(thead_part):
  166. """
  167. Deal duplicate <b> or </b> after replace.
  168. Keep one <b></b> in a <td></td> token.
  169. :param thead_part:
  170. :return:
  171. """
  172. # 1. find out <td></td> in <thead></thead>.
  173. td_pattern = (
  174. '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
  175. '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
  176. '<td rowspan="(\d)+">(.+?)</td>|'
  177. '<td colspan="(\d)+">(.+?)</td>|'
  178. "<td>(.*?)</td>"
  179. )
  180. td_iter = re.finditer(td_pattern, thead_part)
  181. td_list = [t.group() for t in td_iter]
  182. # 2. is multiply <b></b> in <td></td> or not?
  183. new_td_list = []
  184. for td_item in td_list:
  185. if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
  186. # multiply <b></b> in <td></td> case.
  187. # 1. remove all <b></b>
  188. td_item = td_item.replace("<b>", "").replace("</b>", "")
  189. # 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
  190. td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
  191. new_td_list.append(td_item)
  192. else:
  193. new_td_list.append(td_item)
  194. # 3. replace original thead part.
  195. for td_item, new_td_item in zip(td_list, new_td_list):
  196. thead_part = thead_part.replace(td_item, new_td_item)
  197. return thead_part
  198. def distance(box_1, box_2):
  199. """
  200. compute the distance between two boxes
  201. Args:
  202. box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
  203. box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
  204. Returns:
  205. int: the distance between two boxes
  206. """
  207. x1, y1, x2, y2 = box_1
  208. x3, y3, x4, y4 = box_2
  209. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  210. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  211. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  212. return dis + min(dis_2, dis_3)
  213. def compute_iou(rec1, rec2):
  214. """
  215. computing IoU
  216. Args:
  217. rec1 (list): (x1, y1, x2, y2)
  218. rec2 (list): (x1, y1, x2, y2)
  219. Returns:
  220. float: Intersection over Union
  221. """
  222. # computing area of each rectangles
  223. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  224. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  225. # computing the sum_area
  226. sum_area = S_rec1 + S_rec2
  227. # find the each edge of intersect rectangle
  228. left_line = max(rec1[0], rec2[0])
  229. right_line = min(rec1[2], rec2[2])
  230. top_line = max(rec1[1], rec2[1])
  231. bottom_line = min(rec1[3], rec2[3])
  232. # judge if there is an intersect
  233. if left_line >= right_line or top_line >= bottom_line:
  234. return 0.0
  235. else:
  236. intersect = (right_line - left_line) * (bottom_line - top_line)
  237. return (intersect / (sum_area - intersect)) * 1.0
  238. def convert_4point2rect(bbox):
  239. """
  240. Convert 4 point coordinate to rectangle coordinate
  241. Args:
  242. bbox (list): list of 4 points, eg. [x1, y1, x2, y2,...] or [[x1,y1],[x2,y2],...]
  243. """
  244. if isinstance(bbox, list):
  245. bbox = np.array(bbox)
  246. if bbox.shape[0] == 8:
  247. bbox = np.reshape(bbox, (4, 2))
  248. x1 = min(bbox[:, 0])
  249. y1 = min(bbox[:, 1])
  250. x2 = max(bbox[:, 0])
  251. y2 = max(bbox[:, 1])
  252. return [x1, y1, x2, y2]
  253. def get_ori_coordinate_for_table(x, y, table_bbox):
  254. """
  255. get the original coordinate from Cropped image to Original image.
  256. Args:
  257. x (int): x coordinate of cropped image
  258. y (int): y coordinate of cropped image
  259. table_bbox (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  260. Returns:
  261. list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
  262. """
  263. if not table_bbox:
  264. return table_bbox
  265. offset = np.array([x, y] * 4)
  266. table_bbox = np.array(table_bbox)
  267. if table_bbox.shape[-1] == 2:
  268. offset = offset.reshape(4, 2)
  269. return offset + table_bbox
  270. def is_inside(target_box, text_box):
  271. """
  272. check if text box is inside target box
  273. Args:
  274. target_box (list): target box where we want to detect, eg. [x1, y1, x2, y2]
  275. text_box (list): text box, eg. [x1, y1, x2, y2]
  276. Returns:
  277. bool: True if text box is inside target box
  278. """
  279. x1_1, y1_1, x2_1, y2_1 = target_box
  280. x1_2, y1_2, x2_2, y2_2 = text_box
  281. inter_x1 = max(x1_1, x1_2)
  282. inter_y1 = max(y1_1, y1_2)
  283. inter_x2 = min(x2_1, x2_2)
  284. inter_y2 = min(y2_1, y2_2)
  285. if inter_x1 < inter_x2 and inter_y1 < inter_y2:
  286. inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
  287. else:
  288. inter_area = 0
  289. area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
  290. area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
  291. union_area = area1 + area2 - inter_area
  292. iou = inter_area / union_area if union_area != 0 else 0
  293. return iou > 0
  294. class TableMatch(object):
  295. """
  296. match table html and ocr res
  297. """
  298. def __init__(self, filter_ocr_result=False):
  299. self.filter_ocr_result = filter_ocr_result
  300. def __call__(self, table_pred, ocr_pred):
  301. structures = table_pred["structure"]
  302. table_boxes = table_pred["bbox"]
  303. ocr_dt_ploys = ocr_pred["dt_polys"]
  304. ocr_text_res = ocr_pred["rec_text"]
  305. if self.filter_ocr_result:
  306. ocr_dt_ploys, ocr_text_res = self._filter_ocr_result(
  307. table_boxes, ocr_dt_ploys, ocr_text_res
  308. )
  309. matched_index = self.metch_table_and_ocr(table_boxes, ocr_dt_ploys)
  310. pred_html = self.get_html_result(matched_index, ocr_text_res, structures)
  311. return pred_html
  312. def metch_table_and_ocr(self, table_boxes, ocr_boxes):
  313. """
  314. match table bo
  315. Args:
  316. table_boxes (list): bbox for table, 4 points, [x1,y1,x2,y2,x3,y3,x4,y4]
  317. ocr_boxes (list): bbox for ocr, 4 points, [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
  318. Returns:
  319. dict: matched dict, key is table index, value is ocr index
  320. """
  321. matched = {}
  322. for i, ocr_box in enumerate(np.array(ocr_boxes)):
  323. ocr_box = convert_4point2rect(ocr_box)
  324. distances = []
  325. for j, table_box in enumerate(table_boxes):
  326. table_box = convert_4point2rect(table_box)
  327. distances.append(
  328. (
  329. distance(table_box, ocr_box),
  330. 1.0 - compute_iou(table_box, ocr_box),
  331. )
  332. ) # compute iou and l1 distance
  333. sorted_distances = distances.copy()
  334. # select det box by iou and l1 distance
  335. sorted_distances = sorted(
  336. sorted_distances, key=lambda item: (item[1], item[0])
  337. )
  338. if distances.index(sorted_distances[0]) not in matched.keys():
  339. matched[distances.index(sorted_distances[0])] = [i]
  340. else:
  341. matched[distances.index(sorted_distances[0])].append(i)
  342. return matched
  343. def get_html_result(self, matched_index, ocr_contents, pred_structures):
  344. pred_html = []
  345. td_index = 0
  346. head_structure = pred_structures[0:3]
  347. html = "".join(head_structure)
  348. table_structure = pred_structures[3:-3]
  349. for tag in table_structure:
  350. if "</td>" in tag:
  351. if "<td></td>" == tag:
  352. pred_html.extend("<td>")
  353. if td_index in matched_index.keys():
  354. b_with = False
  355. if (
  356. "<b>" in ocr_contents[matched_index[td_index][0]]
  357. and len(matched_index[td_index]) > 1
  358. ):
  359. b_with = True
  360. pred_html.extend("<b>")
  361. for i, td_index_index in enumerate(matched_index[td_index]):
  362. content = ocr_contents[td_index_index]
  363. if len(matched_index[td_index]) > 1:
  364. if len(content) == 0:
  365. continue
  366. if content[0] == " ":
  367. content = content[1:]
  368. if "<b>" in content:
  369. content = content[3:]
  370. if "</b>" in content:
  371. content = content[:-4]
  372. if len(content) == 0:
  373. continue
  374. if (
  375. i != len(matched_index[td_index]) - 1
  376. and " " != content[-1]
  377. ):
  378. content += " "
  379. pred_html.extend(content)
  380. if b_with:
  381. pred_html.extend("</b>")
  382. if "<td></td>" == tag:
  383. pred_html.append("</td>")
  384. else:
  385. pred_html.append(tag)
  386. td_index += 1
  387. else:
  388. pred_html.append(tag)
  389. html += "".join(pred_html)
  390. end_structure = pred_structures[-3:]
  391. html += "".join(end_structure)
  392. return html
  393. def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
  394. y1 = pred_bboxes[:, 1::2].min()
  395. new_dt_boxes = []
  396. new_rec_res = []
  397. for box, rec in zip(dt_boxes, rec_res):
  398. if np.max(box[1::2]) < y1:
  399. continue
  400. new_dt_boxes.append(box)
  401. new_rec_res.append(rec)
  402. return new_dt_boxes, new_rec_res