|
|
@@ -3,12 +3,9 @@ import enum
|
|
|
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
|
|
|
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
|
|
|
from magic_pdf.data.dataset import Dataset
|
|
|
-from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
|
|
|
- bbox_relative_pos, box_area, calculate_iou,
|
|
|
- calculate_overlap_area_in_bbox1_area_ratio,
|
|
|
- get_overlap_area)
|
|
|
+from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
|
|
|
+ calculate_iou)
|
|
|
from magic_pdf.libs.coordinate_transform import get_scale_ratio
|
|
|
-from magic_pdf.libs.local_math import float_gt
|
|
|
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
|
|
|
|
|
|
CAPATION_OVERLAP_AREA_RATIO = 0.6
|
|
|
@@ -208,393 +205,6 @@ class MagicModel:
|
|
|
keep[i] = False
|
|
|
return [bboxes[i] for i in range(N) if keep[i]]
|
|
|
|
|
|
- def __tie_up_category_by_distance(
|
|
|
- self, page_no, subject_category_id, object_category_id
|
|
|
- ):
|
|
|
- """假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
|
|
|
- 只能属于一个 subject."""
|
|
|
- ret = []
|
|
|
- MAX_DIS_OF_POINT = 10**9 + 7
|
|
|
- """
|
|
|
- subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
|
|
|
- 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
|
|
|
- 再求出筛选出的 subjects 和 object 的最短距离
|
|
|
- """
|
|
|
-
|
|
|
- def search_overlap_between_boxes(subject_idx, object_idx):
|
|
|
- idxes = [subject_idx, object_idx]
|
|
|
- x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
|
|
|
- y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
|
|
|
- x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
|
|
|
- y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
|
|
|
-
|
|
|
- merged_bbox = [
|
|
|
- min(x0s),
|
|
|
- min(y0s),
|
|
|
- max(x1s),
|
|
|
- max(y1s),
|
|
|
- ]
|
|
|
- ratio = 0
|
|
|
-
|
|
|
- other_objects = list(
|
|
|
- map(
|
|
|
- lambda x: {'bbox': x['bbox'], 'score': x['score']},
|
|
|
- filter(
|
|
|
- lambda x: x['category_id']
|
|
|
- not in (object_category_id, subject_category_id),
|
|
|
- self.__model_list[page_no]['layout_dets'],
|
|
|
- ),
|
|
|
- )
|
|
|
- )
|
|
|
- for other_object in other_objects:
|
|
|
- ratio = max(
|
|
|
- ratio,
|
|
|
- get_overlap_area(merged_bbox, other_object['bbox'])
|
|
|
- * 1.0
|
|
|
- / box_area(all_bboxes[object_idx]['bbox']),
|
|
|
- )
|
|
|
- if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
|
|
|
- break
|
|
|
-
|
|
|
- return ratio
|
|
|
-
|
|
|
- def may_find_other_nearest_bbox(subject_idx, object_idx):
|
|
|
- ret = float('inf')
|
|
|
-
|
|
|
- x0 = min(
|
|
|
- all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
|
|
|
- )
|
|
|
- y0 = min(
|
|
|
- all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
|
|
|
- )
|
|
|
- x1 = max(
|
|
|
- all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
|
|
|
- )
|
|
|
- y1 = max(
|
|
|
- all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
|
|
|
- )
|
|
|
-
|
|
|
- object_area = abs(
|
|
|
- all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
|
|
|
- ) * abs(
|
|
|
- all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
|
|
|
- )
|
|
|
-
|
|
|
- for i in range(len(all_bboxes)):
|
|
|
- if (
|
|
|
- i == subject_idx
|
|
|
- or all_bboxes[i]['category_id'] != subject_category_id
|
|
|
- ):
|
|
|
- continue
|
|
|
- if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
|
|
|
- all_bboxes[i]['bbox'], [x0, y0, x1, y1]
|
|
|
- ):
|
|
|
-
|
|
|
- i_area = abs(
|
|
|
- all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
|
|
|
- ) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
|
|
|
- if i_area >= object_area:
|
|
|
- ret = min(float('inf'), dis[i][object_idx])
|
|
|
-
|
|
|
- return ret
|
|
|
-
|
|
|
- def expand_bbbox(idxes):
|
|
|
- x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
|
|
|
- y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
|
|
|
- x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
|
|
|
- y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
|
|
|
- return min(x0s), min(y0s), max(x1s), max(y1s)
|
|
|
-
|
|
|
- subjects = self.__reduct_overlap(
|
|
|
- list(
|
|
|
- map(
|
|
|
- lambda x: {'bbox': x['bbox'], 'score': x['score']},
|
|
|
- filter(
|
|
|
- lambda x: x['category_id'] == subject_category_id,
|
|
|
- self.__model_list[page_no]['layout_dets'],
|
|
|
- ),
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- objects = self.__reduct_overlap(
|
|
|
- list(
|
|
|
- map(
|
|
|
- lambda x: {'bbox': x['bbox'], 'score': x['score']},
|
|
|
- filter(
|
|
|
- lambda x: x['category_id'] == object_category_id,
|
|
|
- self.__model_list[page_no]['layout_dets'],
|
|
|
- ),
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- subject_object_relation_map = {}
|
|
|
-
|
|
|
- subjects.sort(
|
|
|
- key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
|
|
|
- ) # get the distance !
|
|
|
-
|
|
|
- all_bboxes = []
|
|
|
-
|
|
|
- for v in subjects:
|
|
|
- all_bboxes.append(
|
|
|
- {
|
|
|
- 'category_id': subject_category_id,
|
|
|
- 'bbox': v['bbox'],
|
|
|
- 'score': v['score'],
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- for v in objects:
|
|
|
- all_bboxes.append(
|
|
|
- {
|
|
|
- 'category_id': object_category_id,
|
|
|
- 'bbox': v['bbox'],
|
|
|
- 'score': v['score'],
|
|
|
- }
|
|
|
- )
|
|
|
-
|
|
|
- N = len(all_bboxes)
|
|
|
- dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
|
|
|
-
|
|
|
- for i in range(N):
|
|
|
- for j in range(i):
|
|
|
- if (
|
|
|
- all_bboxes[i]['category_id'] == subject_category_id
|
|
|
- and all_bboxes[j]['category_id'] == subject_category_id
|
|
|
- ):
|
|
|
- continue
|
|
|
-
|
|
|
- subject_idx, object_idx = i, j
|
|
|
- if all_bboxes[j]['category_id'] == subject_category_id:
|
|
|
- subject_idx, object_idx = j, i
|
|
|
-
|
|
|
- if (
|
|
|
- search_overlap_between_boxes(subject_idx, object_idx)
|
|
|
- >= MERGE_BOX_OVERLAP_AREA_RATIO
|
|
|
- ):
|
|
|
- dis[i][j] = float('inf')
|
|
|
- dis[j][i] = dis[i][j]
|
|
|
- continue
|
|
|
-
|
|
|
- dis[i][j] = self._bbox_distance(
|
|
|
- all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
|
|
|
- )
|
|
|
- dis[j][i] = dis[i][j]
|
|
|
-
|
|
|
- used = set()
|
|
|
- for i in range(N):
|
|
|
- # 求第 i 个 subject 所关联的 object
|
|
|
- if all_bboxes[i]['category_id'] != subject_category_id:
|
|
|
- continue
|
|
|
- seen = set()
|
|
|
- candidates = []
|
|
|
- arr = []
|
|
|
- for j in range(N):
|
|
|
-
|
|
|
- pos_flag_count = sum(
|
|
|
- list(
|
|
|
- map(
|
|
|
- lambda x: 1 if x else 0,
|
|
|
- bbox_relative_pos(
|
|
|
- all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
|
|
|
- ),
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- if pos_flag_count > 1:
|
|
|
- continue
|
|
|
- if (
|
|
|
- all_bboxes[j]['category_id'] != object_category_id
|
|
|
- or j in used
|
|
|
- or dis[i][j] == MAX_DIS_OF_POINT
|
|
|
- ):
|
|
|
- continue
|
|
|
- left, right, _, _ = bbox_relative_pos(
|
|
|
- all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
|
|
|
- ) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
|
|
|
- if left or right:
|
|
|
- one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
|
|
|
- else:
|
|
|
- one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
|
|
|
- if dis[i][j] > one_way_dis:
|
|
|
- continue
|
|
|
- arr.append((dis[i][j], j))
|
|
|
-
|
|
|
- arr.sort(key=lambda x: x[0])
|
|
|
- if len(arr) > 0:
|
|
|
- """
|
|
|
- bug: 离该subject 最近的 object 可能跨越了其它的 subject。
|
|
|
- 比如 [this subect] [some sbuject] [the nearest object of subject]
|
|
|
- """
|
|
|
- if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
|
|
|
-
|
|
|
- candidates.append(arr[0][1])
|
|
|
- seen.add(arr[0][1])
|
|
|
-
|
|
|
- # 已经获取初始种子
|
|
|
- for j in set(candidates):
|
|
|
- tmp = []
|
|
|
- for k in range(i + 1, N):
|
|
|
- pos_flag_count = sum(
|
|
|
- list(
|
|
|
- map(
|
|
|
- lambda x: 1 if x else 0,
|
|
|
- bbox_relative_pos(
|
|
|
- all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
|
|
|
- ),
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
-
|
|
|
- if pos_flag_count > 1:
|
|
|
- continue
|
|
|
-
|
|
|
- if (
|
|
|
- all_bboxes[k]['category_id'] != object_category_id
|
|
|
- or k in used
|
|
|
- or k in seen
|
|
|
- or dis[j][k] == MAX_DIS_OF_POINT
|
|
|
- or dis[j][k] > dis[i][j]
|
|
|
- ):
|
|
|
- continue
|
|
|
-
|
|
|
- is_nearest = True
|
|
|
- for ni in range(i + 1, N):
|
|
|
- if ni in (j, k) or ni in used or ni in seen:
|
|
|
- continue
|
|
|
-
|
|
|
- if not float_gt(dis[ni][k], dis[j][k]):
|
|
|
- is_nearest = False
|
|
|
- break
|
|
|
-
|
|
|
- if is_nearest:
|
|
|
- nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
|
|
|
- n_dis = bbox_distance(
|
|
|
- all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
|
|
|
- )
|
|
|
- if float_gt(dis[i][j], n_dis):
|
|
|
- continue
|
|
|
- tmp.append(k)
|
|
|
- seen.add(k)
|
|
|
-
|
|
|
- candidates = tmp
|
|
|
- if len(candidates) == 0:
|
|
|
- break
|
|
|
-
|
|
|
- # 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
|
|
|
- # 先扩一下 bbox,
|
|
|
- ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
|
|
|
- ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
|
|
|
-
|
|
|
- # 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
|
|
|
- caption_poses = [
|
|
|
- [ox0, oy0, ix0, oy1],
|
|
|
- [ox0, oy0, ox1, iy0],
|
|
|
- [ox0, iy1, ox1, oy1],
|
|
|
- [ix1, oy0, ox1, oy1],
|
|
|
- ]
|
|
|
-
|
|
|
- caption_areas = []
|
|
|
- for bbox in caption_poses:
|
|
|
- embed_arr = []
|
|
|
- for idx in seen:
|
|
|
- if (
|
|
|
- calculate_overlap_area_in_bbox1_area_ratio(
|
|
|
- all_bboxes[idx]['bbox'], bbox
|
|
|
- )
|
|
|
- > CAPATION_OVERLAP_AREA_RATIO
|
|
|
- ):
|
|
|
- embed_arr.append(idx)
|
|
|
-
|
|
|
- if len(embed_arr) > 0:
|
|
|
- embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
|
|
|
- embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
|
|
|
- embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
|
|
|
- embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
|
|
|
- caption_areas.append(
|
|
|
- int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
|
|
|
- )
|
|
|
- else:
|
|
|
- caption_areas.append(0)
|
|
|
-
|
|
|
- subject_object_relation_map[i] = []
|
|
|
- if max(caption_areas) > 0:
|
|
|
- max_area_idx = caption_areas.index(max(caption_areas))
|
|
|
- caption_bbox = caption_poses[max_area_idx]
|
|
|
-
|
|
|
- for j in seen:
|
|
|
- if (
|
|
|
- calculate_overlap_area_in_bbox1_area_ratio(
|
|
|
- all_bboxes[j]['bbox'], caption_bbox
|
|
|
- )
|
|
|
- > CAPATION_OVERLAP_AREA_RATIO
|
|
|
- ):
|
|
|
- used.add(j)
|
|
|
- subject_object_relation_map[i].append(j)
|
|
|
-
|
|
|
- for i in sorted(subject_object_relation_map.keys()):
|
|
|
- result = {
|
|
|
- 'subject_body': all_bboxes[i]['bbox'],
|
|
|
- 'all': all_bboxes[i]['bbox'],
|
|
|
- 'score': all_bboxes[i]['score'],
|
|
|
- }
|
|
|
-
|
|
|
- if len(subject_object_relation_map[i]) > 0:
|
|
|
- x0 = min(
|
|
|
- [all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
|
|
|
- )
|
|
|
- y0 = min(
|
|
|
- [all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
|
|
|
- )
|
|
|
- x1 = max(
|
|
|
- [all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
|
|
|
- )
|
|
|
- y1 = max(
|
|
|
- [all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
|
|
|
- )
|
|
|
- result['object_body'] = [x0, y0, x1, y1]
|
|
|
- result['all'] = [
|
|
|
- min(x0, all_bboxes[i]['bbox'][0]),
|
|
|
- min(y0, all_bboxes[i]['bbox'][1]),
|
|
|
- max(x1, all_bboxes[i]['bbox'][2]),
|
|
|
- max(y1, all_bboxes[i]['bbox'][3]),
|
|
|
- ]
|
|
|
- ret.append(result)
|
|
|
-
|
|
|
- total_subject_object_dis = 0
|
|
|
- # 计算已经配对的 distance 距离
|
|
|
- for i in subject_object_relation_map.keys():
|
|
|
- for j in subject_object_relation_map[i]:
|
|
|
- total_subject_object_dis += bbox_distance(
|
|
|
- all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
|
|
|
- )
|
|
|
-
|
|
|
- # 计算未匹配的 subject 和 object 的距离(非精确版)
|
|
|
- with_caption_subject = set(
|
|
|
- [
|
|
|
- key
|
|
|
- for key in subject_object_relation_map.keys()
|
|
|
- if len(subject_object_relation_map[i]) > 0
|
|
|
- ]
|
|
|
- )
|
|
|
- for i in range(N):
|
|
|
- if all_bboxes[i]['category_id'] != object_category_id or i in used:
|
|
|
- continue
|
|
|
- candidates = []
|
|
|
- for j in range(N):
|
|
|
- if (
|
|
|
- all_bboxes[j]['category_id'] != subject_category_id
|
|
|
- or j in with_caption_subject
|
|
|
- ):
|
|
|
- continue
|
|
|
- candidates.append((dis[i][j], j))
|
|
|
- if len(candidates) > 0:
|
|
|
- candidates.sort(key=lambda x: x[0])
|
|
|
- total_subject_object_dis += candidates[0][1]
|
|
|
- with_caption_subject.add(j)
|
|
|
- return ret, total_subject_object_dis
|
|
|
-
|
|
|
def __tie_up_category_by_distance_v2(
|
|
|
self,
|
|
|
page_no: int,
|
|
|
@@ -879,52 +489,12 @@ class MagicModel:
|
|
|
return ret
|
|
|
|
|
|
def get_imgs(self, page_no: int):
|
|
|
- with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
|
|
|
- with_footnotes, _ = self.__tie_up_category_by_distance(
|
|
|
- page_no, 3, CategoryId.ImageFootnote
|
|
|
- )
|
|
|
- ret = []
|
|
|
- N, M = len(with_captions), len(with_footnotes)
|
|
|
- assert N == M
|
|
|
- for i in range(N):
|
|
|
- record = {
|
|
|
- 'score': with_captions[i]['score'],
|
|
|
- 'img_caption_bbox': with_captions[i].get('object_body', None),
|
|
|
- 'img_body_bbox': with_captions[i]['subject_body'],
|
|
|
- 'img_footnote_bbox': with_footnotes[i].get('object_body', None),
|
|
|
- }
|
|
|
-
|
|
|
- x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
|
|
|
- y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
|
|
|
- x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
|
|
|
- y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
|
|
|
- record['bbox'] = [x0, y0, x1, y1]
|
|
|
- ret.append(record)
|
|
|
- return ret
|
|
|
+ return self.get_imgs_v2(page_no)
|
|
|
|
|
|
def get_tables(
|
|
|
self, page_no: int
|
|
|
) -> list: # 3个坐标, caption, table主体,table-note
|
|
|
- with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
|
|
|
- with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
|
|
|
- ret = []
|
|
|
- N, M = len(with_captions), len(with_footnotes)
|
|
|
- assert N == M
|
|
|
- for i in range(N):
|
|
|
- record = {
|
|
|
- 'score': with_captions[i]['score'],
|
|
|
- 'table_caption_bbox': with_captions[i].get('object_body', None),
|
|
|
- 'table_body_bbox': with_captions[i]['subject_body'],
|
|
|
- 'table_footnote_bbox': with_footnotes[i].get('object_body', None),
|
|
|
- }
|
|
|
-
|
|
|
- x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
|
|
|
- y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
|
|
|
- x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
|
|
|
- y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
|
|
|
- record['bbox'] = [x0, y0, x1, y1]
|
|
|
- ret.append(record)
|
|
|
- return ret
|
|
|
+ return self.get_tables_v2(page_no)
|
|
|
|
|
|
def get_equations(self, page_no: int) -> list: # 有坐标,也有字
|
|
|
inline_equations = self.__get_blocks_by_type(
|
|
|
@@ -1043,4 +613,3 @@ class MagicModel:
|
|
|
|
|
|
def get_model_list(self, page_no):
|
|
|
return self.__model_list[page_no]
|
|
|
-
|