# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import re import copy __all__ = [ "TableMatch", "convert_4point2rect", "get_ori_coordinate_for_table", "is_inside", ] def deal_eb_token(master_token): """ post process with , , ... emptyBboxTokenDict = { "[]": '', "[' ']": '', "['', ' ', '']": '', "['\\u2028', '\\u2028']": '', "['', ' ', '']": '', "['', '']": '', "['', ' ', '']": '', "['', '', '', '']": '', "['', '', ' ', '', '']": '', "['', '']": '', "['', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": '', } :param master_token: :return: """ master_token = master_token.replace("", "") master_token = master_token.replace("", " ") master_token = master_token.replace("", " ") master_token = master_token.replace("", "\u2028\u2028") master_token = master_token.replace("", " ") master_token = master_token.replace("", "") master_token = master_token.replace("", " ") master_token = master_token.replace("", "") master_token = master_token.replace("", " ") master_token = master_token.replace("", "") master_token = master_token.replace( "", " \u2028 \u2028 " ) return master_token def deal_bb(result_token): """ In our opinion, always occurs in text's context. This function will find out all tokens in and insert by manual. :param result_token: :return: """ # find out parts. thead_pattern = "(.*?)" if re.search(thead_pattern, result_token) is None: return result_token thead_part = re.search(thead_pattern, result_token).group() origin_thead_part = copy.deepcopy(thead_part) # check "rowspan" or "colspan" occur in parts or not . span_pattern = ( '|||' ) span_iter = re.finditer(span_pattern, thead_part) span_list = [s.group() for s in span_iter] has_span_in_head = True if len(span_list) > 0 else False if not has_span_in_head: # not include "rowspan" or "colspan" branch 1. # 1. replace to , and to # 2. it is possible to predict text include or by Text-line recognition, # so we replace to , and to thead_part = ( thead_part.replace("", "") .replace("", "") .replace("", "") .replace("", "") ) else: # include "rowspan" or "colspan" branch 2. # Firstly, we deal rowspan or colspan cases. # 1. replace > to > # 2. replace to # 3. it is possible to predict text include or by Text-line recognition, # so we replace to , and to # Secondly, deal ordinary cases like branch 1 # replace ">" to "" replaced_span_list = [] for sp in span_list: replaced_span_list.append(sp.replace(">", ">")) for sp, rsp in zip(span_list, replaced_span_list): thead_part = thead_part.replace(sp, rsp) # replace "" to "" thead_part = thead_part.replace("", "") # remove duplicated by re.sub mb_pattern = "()+" single_b_string = "" thead_part = re.sub(mb_pattern, single_b_string, thead_part) mgb_pattern = "()+" single_gb_string = "" thead_part = re.sub(mgb_pattern, single_gb_string, thead_part) # ordinary cases like branch 1 thead_part = thead_part.replace("", "").replace("", "") # convert back to , empty cell has no . # but space cell( ) is suitable for thead_part = thead_part.replace("", "") # deal with duplicated thead_part = deal_duplicate_bb(thead_part) # deal with isolate span tokens, which causes by wrong predict by structure prediction. # eg.PMC5994107_011_00.png thead_part = deal_isolate_span(thead_part) # replace original result with new thead part. result_token = result_token.replace(origin_thead_part, thead_part) return result_token def deal_isolate_span(thead_part): """ Deal with isolate span cases in this function. It causes by wrong prediction in structure recognition model. eg. predict to rowspan="2">. :param thead_part: :return: """ # 1. find out isolate span tokens. isolate_pattern = ( ' rowspan="(\d)+" colspan="(\d)+">|' ' colspan="(\d)+" rowspan="(\d)+">|' ' rowspan="(\d)+">|' ' colspan="(\d)+">' ) isolate_iter = re.finditer(isolate_pattern, thead_part) isolate_list = [i.group() for i in isolate_iter] # 2. find out span number, by step 1 results. span_pattern = ( ' rowspan="(\d)+" colspan="(\d)+"|' ' colspan="(\d)+" rowspan="(\d)+"|' ' rowspan="(\d)+"|' ' colspan="(\d)+"' ) corrected_list = [] for isolate_item in isolate_list: span_part = re.search(span_pattern, isolate_item) spanStr_in_isolateItem = span_part.group() # 3. merge the span number into the span token format string. if spanStr_in_isolateItem is not None: corrected_item = "".format(spanStr_in_isolateItem) corrected_list.append(corrected_item) else: corrected_list.append(None) # 4. replace original isolated token. for corrected_item, isolate_item in zip(corrected_list, isolate_list): if corrected_item is not None: thead_part = thead_part.replace(isolate_item, corrected_item) else: pass return thead_part def deal_duplicate_bb(thead_part): """ Deal duplicate or after replace. Keep one in a token. :param thead_part: :return: """ # 1. find out in . td_pattern = ( '(.+?)|' '(.+?)|' '(.+?)|' '(.+?)|' "(.*?)" ) td_iter = re.finditer(td_pattern, thead_part) td_list = [t.group() for t in td_iter] # 2. is multiply in or not? new_td_list = [] for td_item in td_list: if td_item.count("") > 1 or td_item.count("") > 1: # multiply in case. # 1. remove all td_item = td_item.replace("", "").replace("", "") # 2. replace -> , -> . td_item = td_item.replace("", "").replace("", "") new_td_list.append(td_item) else: new_td_list.append(td_item) # 3. replace original thead part. for td_item, new_td_item in zip(td_list, new_td_list): thead_part = thead_part.replace(td_item, new_td_item) return thead_part def distance(box_1, box_2): """ compute the distance between two boxes Args: box_1 (list): first rectangle box,eg.(x1, y1, x2, y2) box_2 (list): second rectangle box,eg.(x1, y1, x2, y2) Returns: int: the distance between two boxes """ x1, y1, x2, y2 = box_1 x3, y3, x4, y4 = box_2 dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2) dis_2 = abs(x3 - x1) + abs(y3 - y1) dis_3 = abs(x4 - x2) + abs(y4 - y2) return dis + min(dis_2, dis_3) def compute_iou(rec1, rec2): """ computing IoU Args: rec1 (list): (x1, y1, x2, y2) rec2 (list): (x1, y1, x2, y2) Returns: float: Intersection over Union """ # computing area of each rectangles S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) # computing the sum_area sum_area = S_rec1 + S_rec2 # find the each edge of intersect rectangle left_line = max(rec1[0], rec2[0]) right_line = min(rec1[2], rec2[2]) top_line = max(rec1[1], rec2[1]) bottom_line = min(rec1[3], rec2[3]) # judge if there is an intersect if left_line >= right_line or top_line >= bottom_line: return 0.0 else: intersect = (right_line - left_line) * (bottom_line - top_line) return (intersect / (sum_area - intersect)) * 1.0 def convert_4point2rect(bbox): """ Convert 4 point coordinate to rectangle coordinate Args: bbox (list): list of 4 points, eg. [x1, y1, x2, y2,...] or [[x1,y1],[x2,y2],...] """ if isinstance(bbox, list): bbox = np.array(bbox) if bbox.shape[0] == 8: bbox = np.reshape(bbox, (4, 2)) x1 = min(bbox[:, 0]) y1 = min(bbox[:, 1]) x2 = max(bbox[:, 0]) y2 = max(bbox[:, 1]) return [x1, y1, x2, y2] def get_ori_coordinate_for_table(x, y, table_bbox): """ get the original coordinate from Cropped image to Original image. Args: x (int): x coordinate of cropped image y (int): y coordinate of cropped image table_bbox (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]] Returns: list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]] """ if not table_bbox: return table_bbox offset = np.array([x, y] * 4) table_bbox = np.array(table_bbox) if table_bbox.shape[-1] == 2: offset = offset.reshape(4, 2) return offset + table_bbox def is_inside(target_box, text_box): """ check if text box is inside target box Args: target_box (list): target box where we want to detect, eg. [x1, y1, x2, y2] text_box (list): text box, eg. [x1, y1, x2, y2] Returns: bool: True if text box is inside target box """ x1_1, y1_1, x2_1, y2_1 = target_box x1_2, y1_2, x2_2, y2_2 = text_box inter_x1 = max(x1_1, x1_2) inter_y1 = max(y1_1, y1_2) inter_x2 = min(x2_1, x2_2) inter_y2 = min(y2_1, y2_2) if inter_x1 < inter_x2 and inter_y1 < inter_y2: inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1) else: inter_area = 0 area1 = (x2_1 - x1_1) * (y2_1 - y1_1) area2 = (x2_2 - x1_2) * (y2_2 - y1_2) union_area = area1 + area2 - inter_area iou = inter_area / union_area if union_area != 0 else 0 return iou > 0 class TableMatch(object): """ match table html and ocr res """ def __init__(self, filter_ocr_result=False): self.filter_ocr_result = filter_ocr_result def __call__(self, table_pred, ocr_pred): structures = table_pred["structure"] table_boxes = table_pred["bbox"] ocr_dt_ploys = ocr_pred["dt_polys"] ocr_text_res = ocr_pred["rec_text"] if self.filter_ocr_result: ocr_dt_ploys, ocr_text_res = self._filter_ocr_result( table_boxes, ocr_dt_ploys, ocr_text_res ) matched_index = self.metch_table_and_ocr(table_boxes, ocr_dt_ploys) pred_html = self.get_html_result(matched_index, ocr_text_res, structures) return pred_html def metch_table_and_ocr(self, table_boxes, ocr_boxes): """ match table bo Args: table_boxes (list): bbox for table, 4 points, [x1,y1,x2,y2,x3,y3,x4,y4] ocr_boxes (list): bbox for ocr, 4 points, [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] Returns: dict: matched dict, key is table index, value is ocr index """ matched = {} for i, ocr_box in enumerate(np.array(ocr_boxes)): ocr_box = convert_4point2rect(ocr_box) distances = [] for j, table_box in enumerate(table_boxes): table_box = convert_4point2rect(table_box) distances.append( ( distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box), ) ) # compute iou and l1 distance sorted_distances = distances.copy() # select det box by iou and l1 distance sorted_distances = sorted( sorted_distances, key=lambda item: (item[1], item[0]) ) if distances.index(sorted_distances[0]) not in matched.keys(): matched[distances.index(sorted_distances[0])] = [i] else: matched[distances.index(sorted_distances[0])].append(i) return matched def get_html_result(self, matched_index, ocr_contents, pred_structures): pred_html = [] td_index = 0 head_structure = pred_structures[0:3] html = "".join(head_structure) table_structure = pred_structures[3:-3] for tag in table_structure: if "" in tag: if "" == tag: pred_html.extend("") if td_index in matched_index.keys(): b_with = False if ( "" in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1 ): b_with = True pred_html.extend("") for i, td_index_index in enumerate(matched_index[td_index]): content = ocr_contents[td_index_index] if len(matched_index[td_index]) > 1: if len(content) == 0: continue if content[0] == " ": content = content[1:] if "" in content: content = content[3:] if "" in content: content = content[:-4] if len(content) == 0: continue if ( i != len(matched_index[td_index]) - 1 and " " != content[-1] ): content += " " pred_html.extend(content) if b_with: pred_html.extend("") if "" == tag: pred_html.append("") else: pred_html.append(tag) td_index += 1 else: pred_html.append(tag) html += "".join(pred_html) end_structure = pred_structures[-3:] html += "".join(end_structure) return html def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): y1 = pred_bboxes[:, 1::2].min() new_dt_boxes = [] new_rec_res = [] for box, rec in zip(dt_boxes, rec_res): if np.max(box[1::2]) < y1: continue new_dt_boxes.append(box) new_rec_res.append(rec) return new_dt_boxes, new_rec_res