|
|
@@ -450,11 +450,120 @@ class MagicModel:
|
|
|
)
|
|
|
return ret
|
|
|
|
|
|
+
|
|
|
def __tie_up_category_by_distance_v3(
    self,
    page_no: int,
    subject_category_id: int,
    object_category_id: int,
    priority_pos: PosRelationEnum,
) -> list:
    """Greedily pair subject boxes with object boxes on one page by distance.

    Detections of ``subject_category_id`` and ``object_category_id`` are read
    from ``layout_dets`` of page ``page_no``, passed through
    ``__reduct_overlap`` and then sorted by the squared magnitude of their
    ``(bbox[0], bbox[1])`` corner (presumably the top-left corner -- confirm
    against the bbox convention used by the model output).

    Pairing loop, repeated until every subject is paired or no pair exists:

    1. Among the still-unused boxes (subjects and objects together), pick
       the one whose corner is closest to ``(min x, min y)`` of that set.
    2. Pick the closest unused box of the *other* kind to it.
    3. Record the two as a subject/object pair and mark both as used.

    Consequently each subject receives at most one object, and objects left
    over once the loop stops are not returned at all.

    Args:
        page_no: Index into the model list of the page to process.
        subject_category_id: ``category_id`` of the subject detections.
        object_category_id: ``category_id`` of the object detections.
        priority_pos: Currently unused by this implementation; kept so the
            existing call sites that pass it keep working.

    Returns:
        A list of dicts, one per subject, each with keys ``'sub_bbox'``
        (``{'bbox', 'score'}``), ``'obj_bboxes'`` (a list holding zero or
        one ``{'score', 'bbox'}`` dict) and ``'sub_idx'`` (index of the
        subject in the sorted, overlap-reduced subject list). Paired
        subjects come first, in pairing order, followed by unpaired ones.
    """
    layout_dets = self.__model_list[page_no]['layout_dets']

    def collect(category_id):
        # Keep only bbox/score of the requested category, drop overlapping
        # boxes, then order by squared corner magnitude.
        boxes = self.__reduct_overlap(
            [
                {'bbox': det['bbox'], 'score': det['score']}
                for det in layout_dets
                if det['category_id'] == category_id
            ]
        )
        boxes.sort(key=lambda box: box['bbox'][0] ** 2 + box['bbox'][1] ** 2)
        return boxes

    subjects = collect(subject_category_id)
    objects = collect(object_category_id)

    SUBJECT, OBJECT = 0, 1

    # One entry per box: (kind, index within its own list, x0, y0).
    # Identifying a box by (kind, index) keeps subject and object indices in
    # separate spaces, so no numeric offset is needed to tell them apart.
    all_boxes = [
        (SUBJECT, i, sub['bbox'][0], sub['bbox'][1])
        for i, sub in enumerate(subjects)
    ] + [
        (OBJECT, i, obj['bbox'][0], obj['bbox'][1])
        for i, obj in enumerate(objects)
    ]

    ret = []
    used = set()  # (kind, index) of every box already placed in a pair
    paired_sub_idx = set()  # subject indices that already own an object

    while len(paired_sub_idx) < len(subjects):
        candidates = [box for box in all_boxes if box[:2] not in used]
        if not candidates:
            break

        # Anchor: the unused box closest to the corner of the bounding
        # region spanned by all unused boxes.
        left_x = min(box[2] for box in candidates)
        top_y = min(box[3] for box in candidates)
        candidates.sort(
            key=lambda box: (box[2] - left_x) ** 2 + (box[3] - top_y) ** 2
        )
        first = candidates[0]
        fst_kind, fst_idx, fst_x, fst_y = first

        # Re-rank by distance to the anchor. list.sort is stable, so the
        # anchor (distance 0, already in front) stays at position 0 and
        # ties keep the order established by the previous sort.
        candidates.sort(
            key=lambda box: (box[2] - fst_x) ** 2 + (box[3] - fst_y) ** 2
        )

        # Nearest remaining box of the opposite kind, if any.
        partner = next(
            (box for box in candidates[1:] if box[0] != fst_kind), None
        )
        if partner is None:
            # Only boxes of a single kind are left: nothing more to pair.
            break

        used.add(first[:2])
        used.add(partner[:2])

        if fst_kind == SUBJECT:
            sub_idx, obj_idx = fst_idx, partner[1]
        else:
            sub_idx, obj_idx = partner[1], fst_idx
        paired_sub_idx.add(sub_idx)

        ret.append(
            {
                'sub_bbox': {
                    'bbox': subjects[sub_idx]['bbox'],
                    'score': subjects[sub_idx]['score'],
                },
                'obj_bboxes': [
                    {
                        'score': objects[obj_idx]['score'],
                        'bbox': objects[obj_idx]['bbox'],
                    }
                ],
                'sub_idx': sub_idx,
            }
        )

    # Subjects that never found a partner are still reported, without objects.
    for i, subject in enumerate(subjects):
        if i in paired_sub_idx:
            continue
        ret.append(
            {
                'sub_bbox': {
                    'bbox': subject['bbox'],
                    'score': subject['score'],
                },
                'obj_bboxes': [],
                'sub_idx': i,
            }
        )

    return ret
|
|
|
+
|
|
|
+
|
|
|
def get_imgs_v2(self, page_no: int):
|
|
|
- with_captions = self.__tie_up_category_by_distance_v2(
|
|
|
+ with_captions = self.__tie_up_category_by_distance_v3(
|
|
|
page_no, 3, 4, PosRelationEnum.BOTTOM
|
|
|
)
|
|
|
- with_footnotes = self.__tie_up_category_by_distance_v2(
|
|
|
+ with_footnotes = self.__tie_up_category_by_distance_v3(
|
|
|
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
|
|
|
)
|
|
|
ret = []
|
|
|
@@ -470,10 +579,10 @@ class MagicModel:
|
|
|
return ret
|
|
|
|
|
|
def get_tables_v2(self, page_no: int) -> list:
|
|
|
- with_captions = self.__tie_up_category_by_distance_v2(
|
|
|
+ with_captions = self.__tie_up_category_by_distance_v3(
|
|
|
page_no, 5, 6, PosRelationEnum.UP
|
|
|
)
|
|
|
- with_footnotes = self.__tie_up_category_by_distance_v2(
|
|
|
+ with_footnotes = self.__tie_up_category_by_distance_v3(
|
|
|
page_no, 5, 7, PosRelationEnum.ALL
|
|
|
)
|
|
|
ret = []
|