# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import re
import copy
# Public names re-exported by ``from <this module> import *``.
__all__ = [
    "TableMatch", "convert_4point2rect", "get_ori_coordinate_for_table", "is_inside",
]
def deal_eb_token(master_token):
    """
    Expand the empty-bbox tokens ``<eb></eb>``, ``<eb1></eb1>``, ... into cells.

    Each ``<ebN></ebN>`` token stands for an empty (or whitespace-only) cell
    whose recognized content tokens were:

        <eb></eb>     -> []
        <eb1></eb1>   -> [' ']
        <eb2></eb2>   -> ['<b>', ' ', '</b>']
        <eb3></eb3>   -> ['U+2028', 'U+2028']
        <eb4></eb4>   -> ['<sup>', ' ', '</sup>']
        <eb5></eb5>   -> ['<b>', '</b>']
        <eb6></eb6>   -> ['<i>', ' ', '</i>']
        <eb7></eb7>   -> ['<b>', '<i>', '</i>', '</b>']
        <eb8></eb8>   -> ['<b>', '<i>', ' ', '</i>', '</b>']
        <eb9></eb9>   -> ['<i>', '</i>']
        <eb10></eb10> -> ['<b>', ' ', 'U+2028', ' ', 'U+2028', ' ', '</b>']

    (U+2028 is the Unicode LINE SEPARATOR character.)

    :param master_token: predicted table structure as an HTML string.
    :return: ``master_token`` with every ``<ebN></ebN>`` token replaced by the
        equivalent ``<td>...</td>`` cell.
    """
    # Applied in this order. No key is a substring of another, e.g.
    # "<eb1></eb1>" does not occur inside "<eb10></eb10>".
    eb_to_td = (
        ("<eb></eb>", "<td></td>"),
        ("<eb1></eb1>", "<td> </td>"),
        ("<eb2></eb2>", "<td><b> </b></td>"),
        ("<eb3></eb3>", "<td>\u2028\u2028</td>"),
        ("<eb4></eb4>", "<td><sup> </sup></td>"),
        ("<eb5></eb5>", "<td><b></b></td>"),
        ("<eb6></eb6>", "<td><i> </i></td>"),
        ("<eb7></eb7>", "<td><b><i></i></b></td>"),
        ("<eb8></eb8>", "<td><b><i> </i></b></td>"),
        ("<eb9></eb9>", "<td><i></i></td>"),
        ("<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"),
    )
    for eb_token, td_token in eb_to_td:
        master_token = master_token.replace(eb_token, td_token)
    return master_token
def deal_bb(result_token):
    """
    Wrap the text of every ``<thead>`` cell in ``<b></b>``.

    In our opinion, ``<b></b>`` always occurs in the context of
    ``<thead></thead>`` text. This function therefore finds the first
    ``<thead>...</thead>`` section and inserts ``<b></b>`` into its cells.
    Cells carrying ``rowspan``/``colspan`` attributes are handled separately
    from plain ``<td>`` cells.

    :param result_token: predicted table structure as an HTML string.
    :return: ``result_token`` with its thead section rewritten, or the input
        unchanged when it has no ``<thead>...</thead>`` section.
    """
    # Find the <thead></thead> part. Note that "." does not cross newlines.
    thead_match = re.search(r"<thead>(.*?)</thead>", result_token)
    if thead_match is None:
        return result_token
    thead_part = thead_match.group()
    origin_thead_part = thead_part  # str is immutable, so no copy is needed

    # Check whether "rowspan" or "colspan" occurs in the <thead></thead> part.
    span_pattern = (
        r'<td rowspan="(\d)+" colspan="(\d)+">|'
        r'<td colspan="(\d)+" rowspan="(\d)+">|'
        r'<td rowspan="(\d)+">|'
        r'<td colspan="(\d)+">'
    )
    span_list = [s.group() for s in re.finditer(span_pattern, thead_part)]

    if not span_list:
        # Branch 1: no "rowspan"/"colspan" in <thead></thead>.
        # 1. Replace <td> with <td><b>, and </td> with </b></td>.
        # 2. Text-line recognition may already have predicted <b> or </b>,
        #    so collapse <b><b> to <b> and </b></b> to </b>.
        thead_part = (
            thead_part.replace("<td>", "<td><b>")
            .replace("</td>", "</b></td>")
            .replace("<b><b>", "<b>")
            .replace("</b></b>", "</b>")
        )
    else:
        # Branch 2: "rowspan"/"colspan" present.
        # 1. Append <b> after the closing ">" of each span cell's opening tag.
        for sp in span_list:
            thead_part = thead_part.replace(sp, sp.replace(">", "><b>"))
        # 2. Replace </td> with </b></td>.
        thead_part = thead_part.replace("</td>", "</b></td>")
        # 3. Collapse runs of <b> / </b>. Recognized text may already carry
        #    them, and a span token that occurs twice gets <b> inserted twice.
        thead_part = re.sub(r"(<b>)+", "<b>", thead_part)
        thead_part = re.sub(r"(</b>)+", "</b>", thead_part)
        # 4. Plain <td> cells are handled as in branch 1.
        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")

    # Empty cells get no <b></b>; a space cell (<td><b> </b></td>) keeps it.
    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
    # Keep a single <b></b> per cell.
    thead_part = deal_duplicate_bb(thead_part)
    # Repair isolated span tokens caused by wrong structure prediction,
    # e.g. PMC5994107_011_00.png.
    thead_part = deal_isolate_span(thead_part)
    # Put the rewritten thead part back into the full result.
    return result_token.replace(origin_thead_part, thead_part)
def deal_isolate_span(thead_part):
    """
    Repair span attributes that the structure model emitted outside their cell.

    The structure recognition model sometimes predicts
    ``<td rowspan="2"></td>`` as ``<td></td> rowspan="2"></b></td>``. Every
    such isolated fragment is rewritten back into a single span cell.

    :param thead_part: HTML string of a ``<thead>`` section.
    :return: ``thead_part`` with every isolated span fragment replaced by
        ``<td{span attributes}></td>``.
    """
    # 1. Find the isolated span fragments.
    isolate_pattern = (
        r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
        r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
        r'<td></td> rowspan="(\d)+"></b></td>|'
        r'<td></td> colspan="(\d)+"></b></td>'
    )
    # 2. Pattern for the span attribute text inside one fragment.
    span_pattern = (
        r' rowspan="(\d)+" colspan="(\d)+"|'
        r' colspan="(\d)+" rowspan="(\d)+"|'
        r' rowspan="(\d)+"|'
        r' colspan="(\d)+"'
    )
    isolate_list = [m.group() for m in re.finditer(isolate_pattern, thead_part)]

    for isolate_item in isolate_list:
        # Check the match object itself; calling .group() on None would raise.
        span_match = re.search(span_pattern, isolate_item)
        if span_match is None:
            continue
        # 3. Merge the span attributes into a proper cell token.
        corrected_item = "<td{}></td>".format(span_match.group())
        # 4. Replace the original isolated fragment.
        thead_part = thead_part.replace(isolate_item, corrected_item)
    return thead_part
def deal_duplicate_bb(thead_part):
    """
    Remove duplicate ``<b>`` / ``</b>`` left over by the bold-insertion pass.

    Each ``<td ...>...</td>`` cell that contains more than one ``<b>`` or
    ``</b>`` is stripped of all bold tags and re-wrapped so that exactly one
    ``<b></b>`` pair remains. Cells with ``rowspan``/``colspan`` attributes
    get the ``<b>`` right after their opening tag as well.

    :param thead_part: HTML string of a ``<thead>`` section.
    :return: ``thead_part`` with one ``<b></b>`` pair per affected cell.
    """
    # 1. Find the <td></td> cells in <thead></thead>.
    td_pattern = (
        r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
        r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
        r'<td rowspan="(\d)+">(.+?)</td>|'
        r'<td colspan="(\d)+">(.+?)</td>|'
        r"<td>(.*?)</td>"
    )
    td_list = [m.group() for m in re.finditer(td_pattern, thead_part)]

    for td_item in td_list:
        # 2. Only cells with multiple <b> or </b> need rewriting.
        if td_item.count("<b>") <= 1 and td_item.count("</b>") <= 1:
            continue
        # Remove every bold tag ...
        new_td_item = td_item.replace("<b>", "").replace("</b>", "")
        # ... then add one <b> after the opening tag (with or without span
        # attributes) and one </b> before </td>, keeping the pair balanced.
        new_td_item = re.sub(r"^(<td[^>]*>)", r"\1<b>", new_td_item, count=1)
        new_td_item = new_td_item.replace("</td>", "</b></td>")
        # 3. Replace the original cell in the thead part.
        thead_part = thead_part.replace(td_item, new_td_item)
    return thead_part
def distance(box_1, box_2):
    """
    Corner-based L1 distance between two axis-aligned boxes.

    The result is the sum of the L1 gaps between the top-left corners and
    between the bottom-right corners, plus the smaller of those two gaps.

    Args:
        box_1 (list): first rectangle, (x1, y1, x2, y2)
        box_2 (list): second rectangle, (x1, y1, x2, y2)
    Returns:
        int or float: the distance (same numeric type as the coordinates)
    """
    a_left, a_top, a_right, a_bottom = box_1
    b_left, b_top, b_right, b_bottom = box_2
    top_left_gap = abs(b_left - a_left) + abs(b_top - a_top)
    bottom_right_gap = abs(b_right - a_right) + abs(b_bottom - a_bottom)
    # Summed term by term, in the same order as the four-term total.
    total_gap = top_left_gap + abs(b_right - a_right) + abs(b_bottom - a_bottom)
    return total_gap + min(top_left_gap, bottom_right_gap)
def compute_iou(rec1, rec2):
    """
    Intersection-over-Union of two axis-aligned rectangles.

    Args:
        rec1 (list): (x1, y1, x2, y2)
        rec2 (list): (x1, y1, x2, y2)
    Returns:
        float: area(intersection) / area(union); 0.0 when the rectangles do
        not overlap with positive area (touching edges count as no overlap).
    """
    area_a = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    area_b = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # Edges of the candidate intersection rectangle.
    inter_left, inter_right = max(rec1[0], rec2[0]), min(rec1[2], rec2[2])
    inter_top, inter_bottom = max(rec1[1], rec2[1]), min(rec1[3], rec2[3])

    if inter_right <= inter_left or inter_bottom <= inter_top:
        return 0.0
    overlap = (inter_right - inter_left) * (inter_bottom - inter_top)
    return (overlap / (area_a + area_b - overlap)) * 1.0
def convert_4point2rect(bbox):
    """
    Convert a point-list box into its enclosing axis-aligned rectangle.

    Args:
        bbox (list | tuple | np.ndarray): either a flat coordinate sequence
            [x1, y1, x2, y2, ...] of even length (e.g. 4 corner points as 8
            numbers, or a plain [x1, y1, x2, y2] rectangle) or a sequence of
            points [[x1, y1], [x2, y2], ...].
    Returns:
        list: [x_min, y_min, x_max, y_max] as numpy scalars.
    """
    # asarray also accepts tuples and ndarrays (the old isinstance(list)
    # check crashed on tuples).
    bbox = np.asarray(bbox)
    if bbox.ndim == 1:
        # Flat x, y, x, y, ... sequence -> one (x, y) row per point.
        bbox = bbox.reshape(-1, 2)
    x1 = bbox[:, 0].min()
    y1 = bbox[:, 1].min()
    x2 = bbox[:, 0].max()
    y2 = bbox[:, 1].max()
    return [x1, y1, x2, y2]
def get_ori_coordinate_for_table(x, y, table_bbox):
    """
    Map coordinates from a cropped table image back to the original image.

    Args:
        x (int): x offset of the crop inside the original image
        y (int): y offset of the crop inside the original image
        table_bbox (list | np.ndarray): boxes in crop coordinates whose last
            axis alternates x and y, e.g. [[x1, y1, ..., x4, y4]], or points
            with last axis (x, y), e.g. [[[x1, y1], ..., [x4, y4]]].
    Returns:
        np.ndarray: ``table_bbox`` shifted by (x, y); an empty or None input
        is returned unchanged.
    """
    # len() instead of truthiness: `not arr` raises for multi-element arrays.
    if table_bbox is None or len(table_bbox) == 0:
        return table_bbox
    table_bbox = np.asarray(table_bbox)
    # [x, y, x, y, ...] matching the last axis; broadcasts over the others.
    offset = np.tile(np.array([x, y]), table_bbox.shape[-1] // 2)
    return offset + table_bbox
def is_inside(target_box, text_box):
    """
    Check whether a text box overlaps the target box.

    Despite the name, this is an overlap test, not a containment test. It
    returns True when the IoU of the two boxes is positive, i.e. when they
    share a region of positive area. Boxes that only touch along an edge do
    not count.

    Args:
        target_box (list): target region, [x1, y1, x2, y2]
        text_box (list): text box, [x1, y1, x2, y2]
    Returns:
        bool: True if the boxes overlap with positive area
    """
    t_x1, t_y1, t_x2, t_y2 = target_box
    b_x1, b_y1, b_x2, b_y2 = text_box

    left, top = max(t_x1, b_x1), max(t_y1, b_y1)
    right, bottom = min(t_x2, b_x2), min(t_y2, b_y2)
    overlaps = left < right and top < bottom
    overlap_area = (right - left) * (bottom - top) if overlaps else 0

    target_area = (t_x2 - t_x1) * (t_y2 - t_y1)
    text_area = (b_x2 - b_x1) * (b_y2 - b_y1)
    union_area = target_area + text_area - overlap_area

    iou = 0 if union_area == 0 else overlap_area / union_area
    return iou > 0
class TableMatch(object):
    """
    Match a predicted table structure with OCR results and build the HTML.

    Each OCR text box is assigned to the structure cell whose box ranks best
    (highest IoU first, then smallest corner distance). The recognized texts
    are then written into the predicted HTML structure tokens.
    """

    def __init__(self, filter_ocr_result=False):
        """
        Args:
            filter_ocr_result (bool): if True, OCR boxes lying entirely above
                the topmost table cell are dropped before matching.
        """
        self.filter_ocr_result = filter_ocr_result

    def __call__(self, table_pred, ocr_pred):
        """
        Args:
            table_pred (dict): structure prediction with keys ``"structure"``
                (list of HTML tokens) and ``"bbox"`` (cell boxes).
            ocr_pred (dict): OCR result with keys ``"dt_polys"`` (text boxes)
                and ``"rec_text"`` (recognized strings, in the same order).
        Returns:
            str: table HTML with the OCR text filled into the cells.
        """
        structures = table_pred["structure"]
        table_boxes = table_pred["bbox"]
        ocr_dt_ploys = ocr_pred["dt_polys"]
        ocr_text_res = ocr_pred["rec_text"]
        if self.filter_ocr_result:
            ocr_dt_ploys, ocr_text_res = self._filter_ocr_result(
                table_boxes, ocr_dt_ploys, ocr_text_res
            )
        matched_index = self.metch_table_and_ocr(table_boxes, ocr_dt_ploys)
        return self.get_html_result(matched_index, ocr_text_res, structures)

    def metch_table_and_ocr(self, table_boxes, ocr_boxes):
        """
        Assign every OCR box to its best-matching table cell.

        (The misspelled name is kept for backward compatibility.)

        Cells are ranked by ``1 - IoU`` first and by :func:`distance` second.
        Each OCR box goes to the best-ranked cell; on an exact tie, the
        lowest cell index wins.

        Args:
            table_boxes (list): cell boxes, 4 points, [x1,y1,x2,y2,x3,y3,x4,y4]
            ocr_boxes (list): OCR boxes, 4 points, [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
        Returns:
            dict: maps a table cell index to the list of OCR box indices
            assigned to it, in OCR order. Cells without OCR boxes are absent.
            The dict is empty when ``table_boxes`` is empty.
        """
        # Convert the cell boxes once, not once per OCR box.
        table_rects = [convert_4point2rect(box) for box in table_boxes]
        matched = {}
        if not table_rects:
            return matched
        for i, ocr_box in enumerate(np.array(ocr_boxes)):
            ocr_rect = convert_4point2rect(ocr_box)
            # min() returns the first index with the smallest key, which is
            # the same cell a stable sort followed by list.index() would give.
            best_cell = min(
                range(len(table_rects)),
                key=lambda j: (
                    1.0 - compute_iou(table_rects[j], ocr_rect),
                    distance(table_rects[j], ocr_rect),
                ),
            )
            matched.setdefault(best_cell, []).append(i)
        return matched

    def get_html_result(self, matched_index, ocr_contents, pred_structures):
        """
        Fill the matched OCR texts into the structure tokens and join to HTML.

        Args:
            matched_index (dict): cell index -> list of OCR indices, as
                returned by :meth:`metch_table_and_ocr`.
            ocr_contents (list): recognized text strings.
            pred_structures (list): HTML structure tokens. The first and last
                three tokens are copied verbatim (presumably
                ``<html><body><table>`` and its closing tags; TODO confirm).
        Returns:
            str: the table HTML.
        """
        html_parts = list(pred_structures[0:3])
        td_index = 0
        for tag in pred_structures[3:-3]:
            if "</td>" not in tag:
                html_parts.append(tag)
                continue
            # A self-contained empty cell token is opened here and closed
            # below, so its text lands between <td> and </td>.
            is_empty_cell = tag == "<td></td>"
            if is_empty_cell:
                html_parts.append("<td>")
            if td_index in matched_index:
                html_parts.extend(
                    self._cell_text(matched_index[td_index], ocr_contents)
                )
            html_parts.append("</td>" if is_empty_cell else tag)
            td_index += 1
        html_parts.extend(pred_structures[-3:])
        return "".join(html_parts)

    @staticmethod
    def _cell_text(ocr_ids, ocr_contents):
        """Return the text pieces for one cell built from its OCR results."""
        multi = len(ocr_ids) > 1
        # Several pieces whose first one is bold: bold the whole cell.
        wrap_bold = multi and "<b>" in ocr_contents[ocr_ids[0]]
        pieces = ["<b>"] if wrap_bold else []
        for i, ocr_id in enumerate(ocr_ids):
            content = ocr_contents[ocr_id]
            if multi:
                if not content:
                    continue
                if content[0] == " ":
                    content = content[1:]
                # Strip per-piece bold tags only where they actually are.
                if content.startswith("<b>"):
                    content = content[3:]
                if content.endswith("</b>"):
                    content = content[:-4]
                if not content:
                    continue
                # Separate consecutive pieces with a single space.
                if i != len(ocr_ids) - 1 and content[-1] != " ":
                    content += " "
            pieces.append(content)
        if wrap_bold:
            pieces.append("</b>")
        return pieces

    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
        """
        Drop OCR results whose box lies entirely above the topmost table cell.

        Args:
            pred_bboxes (list | np.ndarray): cell boxes whose coordinates
                alternate x, y (flat or as (x, y) points).
            dt_boxes (list): OCR boxes, flat [x1, y1, ...] or [[x, y], ...].
            rec_res (list): recognized texts, aligned with ``dt_boxes``.
        Returns:
            tuple: (kept boxes, kept texts) as lists. Both are returned
            unfiltered when there are no cell boxes.
        """
        cell_coords = np.asarray(pred_bboxes)
        if cell_coords.size == 0:
            return list(dt_boxes), list(rec_res)
        # Every second coordinate is a y value in either layout.
        table_top = cell_coords.reshape(-1, 2)[:, 1].min()
        new_dt_boxes = []
        new_rec_res = []
        for box, rec in zip(dt_boxes, rec_res):
            # reshape(-1, 2)[:, 1] selects y values for both flat boxes and
            # point lists; box[1::2] would pick whole points for the latter.
            box_bottom = np.asarray(box).reshape(-1, 2)[:, 1].max()
            if box_bottom < table_top:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res