|
|
@@ -90,10 +90,21 @@ class MagicModel:
|
|
|
ret = []
|
|
|
MAX_DIS_OF_POINT = 10**9 + 7
|
|
|
|
|
|
def expand_bbox(bbox1, bbox2):
    """Combine two boxes into one four-element list.

    Both arguments are indexed as [x0, y0, x1, y1]; only indices 0-3 are read.
    Slots 0 and 1 of the result hold the element-wise minimum, slots 2 and 3
    the element-wise maximum. Assuming x0 <= x1 and y0 <= y1 in each input,
    this is the box that encloses both of them.
    """
    lower = [min(bbox1[k], bbox2[k]) for k in (0, 1)]
    upper = [max(bbox1[k], bbox2[k]) for k in (2, 3)]
    return lower + upper
|
|
|
+
|
|
|
def get_bbox_area(bbox):
    """Return the area of a box indexed as [x0, y0, x1, y1].

    Each side length is taken as an absolute difference, so the result is
    non-negative even if the two corners are supplied in reverse order.
    """
    width = abs(bbox[2] - bbox[0])
    height = abs(bbox[3] - bbox[1])
    return width * height
|
|
|
+
|
|
|
# The bboxes of the subject and the object are merged into one large bbox (named: merged bbox). Select all subjects that overlap the merged bbox and whose overlap area is larger than the object's area.
|
|
|
# Then compute the shortest distance between the selected subjects and the object!
|
|
|
def may_find_other_nearest_bbox(subject_idx, object_idx):
|
|
|
ret = float("inf")
|
|
|
+
|
|
|
x0 = min(
|
|
|
all_bboxes[subject_idx]["bbox"][0], all_bboxes[object_idx]["bbox"][0]
|
|
|
)
|
|
|
@@ -112,6 +123,7 @@ class MagicModel:
|
|
|
) * abs(
|
|
|
all_bboxes[object_idx]["bbox"][3] - all_bboxes[object_idx]["bbox"][1]
|
|
|
)
|
|
|
+
|
|
|
for i in range(len(all_bboxes)):
|
|
|
if (
|
|
|
i == subject_idx
|
|
|
@@ -121,11 +133,13 @@ class MagicModel:
|
|
|
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]["bbox"]) or _is_in(
|
|
|
all_bboxes[i]["bbox"], [x0, y0, x1, y1]
|
|
|
):
|
|
|
+
|
|
|
i_area = abs(
|
|
|
all_bboxes[i]["bbox"][2] - all_bboxes[i]["bbox"][0]
|
|
|
) * abs(all_bboxes[i]["bbox"][3] - all_bboxes[i]["bbox"][1])
|
|
|
if i_area >= object_area:
|
|
|
ret = min(float("inf"), dis[i][object_idx])
|
|
|
+
|
|
|
return ret
|
|
|
|
|
|
subjects = self.__reduct_overlap(
|
|
|
@@ -224,7 +238,7 @@ class MagicModel:
|
|
|
arr.sort(key=lambda x: x[0])
|
|
|
if len(arr) > 0:
|
|
|
# bug: the object nearest to this subject may lie beyond some other subject, e.g. [this subject] [some subject] [the nearest object of subject]
|
|
|
- if may_find_other_nearest_bbox(i, j) >= arr[0][0]:
|
|
|
+ if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
|
|
|
candidates.append(arr[0][1])
|
|
|
seen.add(arr[0][1])
|
|
|
|
|
|
@@ -470,6 +484,12 @@ class MagicModel:
|
|
|
return text_spans
|
|
|
|
|
|
def get_all_spans(self, page_no: int) -> list:
|
|
|
def remove_duplicate_spans(spans):
    """Return a new list holding the first occurrence of each distinct span.

    Spans are compared with ``==`` against every span already kept, so the
    input order is preserved and the elements do not need to be hashable.
    """
    unique_spans = []
    for candidate in spans:
        for kept in unique_spans:
            if candidate == kept:
                # An equal span was already kept; drop this one.
                break
        else:
            unique_spans.append(candidate)
    return unique_spans
|
|
|
all_spans = []
|
|
|
model_page_info = self.__model_list[page_no]
|
|
|
layout_dets = model_page_info["layout_dets"]
|
|
|
@@ -483,7 +503,10 @@ class MagicModel:
|
|
|
for layout_det in layout_dets:
|
|
|
category_id = layout_det["category_id"]
|
|
|
if category_id in allow_category_id_list:
|
|
|
- span = {"bbox": layout_det["bbox"]}
|
|
|
+ span = {
|
|
|
+ "bbox": layout_det["bbox"],
|
|
|
+ "score": layout_det["score"]
|
|
|
+ }
|
|
|
if category_id == 3:
|
|
|
span["type"] = ContentType.Image
|
|
|
elif category_id == 5:
|
|
|
@@ -498,7 +521,7 @@ class MagicModel:
|
|
|
span["content"] = layout_det["text"]
|
|
|
span["type"] = ContentType.Text
|
|
|
all_spans.append(span)
|
|
|
- return all_spans
|
|
|
+ return remove_duplicate_spans(all_spans)
|
|
|
|
|
|
def get_page_size(self, page_no: int): # 获取页面宽高
|
|
|
# 获取当前页的page对象
|
|
|
@@ -533,6 +556,7 @@ class MagicModel:
|
|
|
return self.__model_list[page_no]
|
|
|
|
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
drw = DiskReaderWriter(r"D:/project/20231108code-clean")
|
|
|
if 0:
|