Эх сурвалжийг харах

fix: add logic to remove low confidence overlapping blocks in layout results

myhloli 3 сар өмнө
parent
commit
2ce4352a25
1 өөрчлөгдсөн 60 нэмэгдсэн , 63 устгасан
  1. 60 63
      mineru/utils/model_utils.py

+ 60 - 63
mineru/utils/model_utils.py

@@ -5,6 +5,7 @@ from loguru import logger
 import numpy as np
 
 from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
+from mineru.utils.span_pre_proc import remove_overlaps_low_confidence_spans
 
 try:
     import torch
@@ -131,15 +132,10 @@ def merge_high_iou_tables(table_res_list, layout_res, table_indices, iou_thresho
 
                     # Create merged table
                     merged_table = table_res_list[i].copy()
-                    merged_table['poly'][0] = union_xmin
-                    merged_table['poly'][1] = union_ymin
-                    merged_table['poly'][2] = union_xmax
-                    merged_table['poly'][3] = union_ymin
-                    merged_table['poly'][4] = union_xmax
-                    merged_table['poly'][5] = union_ymax
-                    merged_table['poly'][6] = union_xmin
-                    merged_table['poly'][7] = union_ymax
-
+                    merged_table['poly'] = [
+                        union_xmin, union_ymin, union_xmax, union_ymin,
+                        union_xmax, union_ymax, union_xmin, union_ymax
+                    ]
                     # Update layout_res
                     to_remove = [table_indices[j], table_indices[i]]
                     for idx in sorted(to_remove, reverse=True):
@@ -253,6 +249,59 @@ def remove_overlaps_min_blocks(res_list):
     return res_list, need_remove
 
 
+def remove_overlaps_low_confidence_blocks(combined_res_list):
+    # 计算每个block的坐标和面积
+    block_info = []
+    for block in combined_res_list:
+        xmin, ymin = int(block['poly'][0]), int(block['poly'][1])
+        xmax, ymax = int(block['poly'][4]), int(block['poly'][5])
+        area = (xmax - xmin) * (ymax - ymin)
+        score = block.get('score', 0.5)  # 如果没有score字段,默认为0.5
+        block_info.append((xmin, ymin, xmax, ymax, area, score, block))
+
+    blocks_to_remove = []
+
+    # 检查每个block内部是否有3个及以上的小block
+    for i, (xmin, ymin, xmax, ymax, area, score, block) in enumerate(block_info):
+        # 查找内部的小block
+        blocks_inside = [(j, j_score, j_block) for j, (xj_min, yj_min, xj_max, yj_max, j_area, j_score, j_block) in
+                         enumerate(block_info)
+                         if i != j and is_inside(block_info[j], block_info[i])]
+
+        # 如果内部有3个及以上的小block
+        if len(blocks_inside) >= 3:
+            # 计算小block的平均分数
+            avg_score = sum(s for _, s, _ in blocks_inside) / len(blocks_inside)
+
+            # 比较大block的分数和小block的平均分数
+            if score > avg_score:
+                # 保留大block,扩展其边界
+                # 首先将所有小block标记为要删除
+                for j, _, j_block in blocks_inside:
+                    if j_block not in blocks_to_remove:
+                        blocks_to_remove.append(j_block)
+
+                # 扩展大block的边界以包含所有小block
+                new_xmin, new_ymin, new_xmax, new_ymax = xmin, ymin, xmax, ymax
+                for _, _, j_block in blocks_inside:
+                    j_xmin, j_ymin = int(j_block['poly'][0]), int(j_block['poly'][1])
+                    j_xmax, j_ymax = int(j_block['poly'][4]), int(j_block['poly'][5])
+                    new_xmin = min(new_xmin, j_xmin)
+                    new_ymin = min(new_ymin, j_ymin)
+                    new_xmax = max(new_xmax, j_xmax)
+                    new_ymax = max(new_ymax, j_ymax)
+
+                # 更新大block的边界
+                block['poly'][0] = block['poly'][6] = new_xmin
+                block['poly'][1] = block['poly'][3] = new_ymin
+                block['poly'][2] = block['poly'][4] = new_xmax
+                block['poly'][5] = block['poly'][7] = new_ymax
+            else:
+                # 保留小blocks,删除大block
+                blocks_to_remove.append(block)
+    return blocks_to_remove
+
+
 def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
     """Extract OCR, table and other regions from layout results."""
     ocr_res_list = []
@@ -311,67 +360,15 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
             del res['bbox']
             layout_res.remove(res)
 
-    # 新增:检测大block内部是否包含多个小block
-    # 合并ocr和table列表进行检测
+    # 检测大block内部是否包含多个小block, 合并ocr和table列表进行检测
     combined_res_list = ocr_res_list + filtered_table_res_list
-
-    # 计算每个block的坐标和面积
-    block_info = []
-    for block in combined_res_list:
-        xmin, ymin = int(block['poly'][0]), int(block['poly'][1])
-        xmax, ymax = int(block['poly'][4]), int(block['poly'][5])
-        area = (xmax - xmin) * (ymax - ymin)
-        score = block.get('score', 0.5)  # 如果没有score字段,默认为0.5
-        block_info.append((xmin, ymin, xmax, ymax, area, score, block))
-
-    blocks_to_remove = []
-
-    # 检查每个block内部是否有3个及以上的小block
-    for i, (xmin, ymin, xmax, ymax, area, score, block) in enumerate(block_info):
-        # 查找内部的小block
-        blocks_inside = [(j, j_score, j_block) for j, (xj_min, yj_min, xj_max, yj_max, j_area, j_score, j_block) in
-                         enumerate(block_info)
-                         if i != j and is_inside(block_info[j], block_info[i])]
-
-        # 如果内部有3个及以上的小block
-        if len(blocks_inside) >= 3:
-            # 计算小block的平均分数
-            avg_score = sum(s for _, s, _ in blocks_inside) / len(blocks_inside)
-
-            # 比较大block的分数和小block的平均分数
-            if score > avg_score:
-                # 保留大block,扩展其边界
-                # 首先将所有小block标记为要删除
-                for j, _, j_block in blocks_inside:
-                    if j_block not in blocks_to_remove:
-                        blocks_to_remove.append(j_block)
-
-                # 扩展大block的边界以包含所有小block
-                new_xmin, new_ymin, new_xmax, new_ymax = xmin, ymin, xmax, ymax
-                for _, _, j_block in blocks_inside:
-                    j_xmin, j_ymin = int(j_block['poly'][0]), int(j_block['poly'][1])
-                    j_xmax, j_ymax = int(j_block['poly'][4]), int(j_block['poly'][5])
-                    new_xmin = min(new_xmin, j_xmin)
-                    new_ymin = min(new_ymin, j_ymin)
-                    new_xmax = max(new_xmax, j_xmax)
-                    new_ymax = max(new_ymax, j_ymax)
-
-                # 更新大block的边界
-                block['poly'][0] = block['poly'][6] = new_xmin
-                block['poly'][1] = block['poly'][3] = new_ymin
-                block['poly'][2] = block['poly'][4] = new_xmax
-                block['poly'][5] = block['poly'][7] = new_ymax
-            else:
-                # 保留小blocks,删除大block
-                blocks_to_remove.append(block)
-
+    blocks_to_remove = remove_overlaps_low_confidence_blocks(combined_res_list)
     # 移除需要删除的blocks
     for block in blocks_to_remove:
         if block in ocr_res_list:
             ocr_res_list.remove(block)
         elif block in filtered_table_res_list:
             filtered_table_res_list.remove(block)
-
         # 同时从layout_res中删除
         if block in layout_res:
             layout_res.remove(block)