Browse Source

Merge branch 'opendatalab:dev' into dev

Xiaomeng Zhao 9 months ago
parent
commit
773743434c

+ 30 - 18
magic_pdf/model/magic_model.py

@@ -488,46 +488,58 @@ class MagicModel:
 
         OBJ_IDX_OFFSET = 10000
         SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
-        
+
         all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
         seen_idx = set()
         seen_sub_idx = set()
-        
+
         while N > len(seen_sub_idx):
-            candidates = [] 
+            candidates = []
             for idx, kind, x0, y0 in all_boxes_with_idx:
                 if idx in seen_idx:
-                    continue 
+                    continue
                 candidates.append((idx, kind, x0, y0))
-            
+
             if len(candidates) == 0:
                 break
             left_x = min([v[2] for v in candidates])
             top_y =  min([v[3] for v in candidates])
-            
+
             candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
 
-            
+
             fst_idx, fst_kind, left_x, top_y = candidates[0]
             candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
             nxt = None
-            
+
             for i in range(1, len(candidates)):
                 if candidates[i][1] ^ fst_kind == 1:
                     nxt = candidates[i]
-                    break 
+                    break
             if nxt is None:
                 break
-            
-            seen_idx.add(fst_idx)
-            seen_idx.add(nxt[0])
+
             if fst_kind == SUB_BIT_KIND:
-                seen_sub_idx.add(fst_idx)
                 sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
-                
+
             else:
-                seen_sub_idx.add(nxt[0])
                 sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
+
+            pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
+            nearest_dis = float('inf')
+            for i in range(N):
+                if i in seen_idx:continue
+                nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
+
+            if pair_dis >= 3*nearest_dis:
+                seen_idx.add(sub_idx)
+                continue
+
+
+            seen_idx.add(sub_idx)
+            seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
+            seen_sub_idx.add(sub_idx)
+
             ret.append(
                 {
                     'sub_bbox': {
@@ -543,7 +555,7 @@ class MagicModel:
 
         for i in range(len(subjects)):
             if i in seen_sub_idx:
-                continue 
+                continue
             ret.append(
                 {
                     'sub_bbox': {
@@ -554,8 +566,8 @@ class MagicModel:
                     'sub_idx': i,
                 }
             )
-        
-        
+
+
         return ret
 
 

+ 2 - 2
tests/unittest/test_integrations/test_rag/test_utils.py

@@ -24,7 +24,7 @@ def test_convert_middle_json_to_layout_elements():
     assert len(res[0].layout_dets) > 0
     assert res[0].layout_dets[0].anno_id == 0
     assert res[0].layout_dets[0].category_type == CategoryType.text
-    assert len(res[0].extra.element_relation) >= 3
+    assert len(res[0].extra.element_relation) >= 2
 
     # teardown
     shutil.rmtree(temp_output_dir)
@@ -51,7 +51,7 @@ def test_inference():
     assert len(res[0].layout_dets) > 0
     assert res[0].layout_dets[0].anno_id == 0
     assert res[0].layout_dets[0].category_type == CategoryType.text
-    assert len(res[0].extra.element_relation) >= 3
+    assert len(res[0].extra.element_relation) >= 2
 
     # teardown
     shutil.rmtree(temp_output_dir)