|
|
@@ -17,7 +17,7 @@ import cv2
|
|
|
import numpy as np
|
|
|
import threading
|
|
|
from pathlib import Path
|
|
|
-from typing import Dict, List, Union, Any, Optional
|
|
|
+from typing import Dict, List, Union, Any, Optional, Tuple
|
|
|
from PIL import Image
|
|
|
|
|
|
try:
|
|
|
@@ -87,23 +87,47 @@ class LayoutUtils:
|
|
|
|
|
|
return intersection / min_area
|
|
|
|
|
|
+ # 不允许合并的类别组合
|
|
|
+ FORBIDDEN_MERGE = {
|
|
|
+ 'image_body': ['text', 'title', 'table_body', 'table'],
|
|
|
+ 'figure': ['text', 'title', 'table_body', 'table'],
|
|
|
+ }
|
|
|
+
|
|
|
+ # 类别优先级(数字越大优先级越高)
|
|
|
+ CATEGORY_PRIORITY = {
|
|
|
+ 'text': 3,
|
|
|
+ 'title': 3,
|
|
|
+ 'table_body': 3,
|
|
|
+ 'table': 3,
|
|
|
+ 'image_body': 1,
|
|
|
+ 'figure': 1,
|
|
|
+ }
|
|
|
+
|
|
|
@staticmethod
|
|
|
def remove_overlapping_boxes(
|
|
|
layout_results: List[Dict[str, Any]],
|
|
|
iou_threshold: float = 0.8,
|
|
|
- overlap_ratio_threshold: float = 0.8
|
|
|
+ overlap_ratio_threshold: float = 0.8,
|
|
|
+ image_size: Optional[Tuple[int, int]] = None,
|
|
|
+ max_area_ratio: float = 0.8,
|
|
|
+ enable_category_restriction: bool = True,
|
|
|
+ enable_category_priority: bool = True
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
处理重叠的布局框(参考 MinerU 的去重策略)
|
|
|
|
|
|
策略:
|
|
|
- 1. 高 IoU 重叠:保留置信度高的框
|
|
|
- 2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
|
|
|
+ 1. 高 IoU 重叠:保留置信度高的框(考虑类别优先级)
|
|
|
+ 2. 包含关系:小框被大框高度包含时,检查类别限制和面积限制后决定是否合并
|
|
|
|
|
|
Args:
|
|
|
layout_results: Layout 检测结果列表
|
|
|
iou_threshold: IoU 阈值,超过此值认为高度重叠
|
|
|
overlap_ratio_threshold: 重叠面积占小框面积的比例阈值
|
|
|
+ image_size: 图像尺寸 (width, height),用于计算面积限制
|
|
|
+ max_area_ratio: 最大面积比例,合并后的框超过此比例则拒绝合并(默认0.8)
|
|
|
+ enable_category_restriction: 是否启用类别限制(默认True)
|
|
|
+ enable_category_priority: 是否启用类别优先级(默认True)
|
|
|
|
|
|
Returns:
|
|
|
去重后的布局结果列表
|
|
|
@@ -115,6 +139,44 @@ class LayoutUtils:
|
|
|
results = [item.copy() for item in layout_results]
|
|
|
need_remove = set()
|
|
|
|
|
|
+ # 计算图像总面积(如果提供了图像尺寸)
|
|
|
+ img_area = None
|
|
|
+ if image_size is not None:
|
|
|
+ img_width, img_height = image_size
|
|
|
+ img_area = img_width * img_height
|
|
|
+
|
|
|
+ def can_merge(cat1: str, cat2: str) -> bool:
|
|
|
+ """检查两个类别是否允许合并"""
|
|
|
+ if not enable_category_restriction:
|
|
|
+ return True
|
|
|
+
|
|
|
+ # 检查是否在禁止合并列表中
|
|
|
+ forbidden1 = LayoutUtils.FORBIDDEN_MERGE.get(cat1, [])
|
|
|
+ if cat2 in forbidden1:
|
|
|
+ return False
|
|
|
+
|
|
|
+ forbidden2 = LayoutUtils.FORBIDDEN_MERGE.get(cat2, [])
|
|
|
+ if cat1 in forbidden2:
|
|
|
+ return False
|
|
|
+
|
|
|
+ return True
|
|
|
+
|
|
|
+ def get_priority(category: str) -> int:
|
|
|
+ """获取类别优先级"""
|
|
|
+ if not enable_category_priority:
|
|
|
+ return 0
|
|
|
+ return LayoutUtils.CATEGORY_PRIORITY.get(category, 0)
|
|
|
+
|
|
|
+ def check_area_limit(merged_bbox: List[float]) -> bool:
|
|
|
+ """检查合并后的框是否超过面积限制"""
|
|
|
+ if img_area is None:
|
|
|
+ return True # 如果没有提供图像尺寸,不检查
|
|
|
+
|
|
|
+ merged_area = (merged_bbox[2] - merged_bbox[0]) * (merged_bbox[3] - merged_bbox[1])
|
|
|
+ area_ratio = merged_area / img_area if img_area > 0 else 0
|
|
|
+
|
|
|
+ return area_ratio <= max_area_ratio
|
|
|
+
|
|
|
for i in range(len(results)):
|
|
|
if i in need_remove:
|
|
|
continue
|
|
|
@@ -129,24 +191,41 @@ class LayoutUtils:
|
|
|
if len(bbox1) < 4 or len(bbox2) < 4:
|
|
|
continue
|
|
|
|
|
|
+ cat1 = results[i].get('category', 'unknown')
|
|
|
+ cat2 = results[j].get('category', 'unknown')
|
|
|
+
|
|
|
# 计算 IoU
|
|
|
iou = LayoutUtils.calculate_iou(bbox1, bbox2)
|
|
|
|
|
|
if iou > iou_threshold:
|
|
|
- # 高度重叠,保留置信度高的
|
|
|
+ # 高度重叠,保留置信度高的框(考虑类别优先级)
|
|
|
score1 = results[i].get('confidence', results[i].get('score', 0))
|
|
|
score2 = results[j].get('confidence', results[j].get('score', 0))
|
|
|
+ priority1 = get_priority(cat1)
|
|
|
+ priority2 = get_priority(cat2)
|
|
|
|
|
|
- if score1 >= score2:
|
|
|
+ # 如果类别优先级不同,优先保留高优先级
|
|
|
+ if priority1 != priority2:
|
|
|
+ if priority1 > priority2:
|
|
|
+ need_remove.add(j)
|
|
|
+ else:
|
|
|
+ need_remove.add(i)
|
|
|
+ break
|
|
|
+ # 如果类别优先级相同,保留置信度高的
|
|
|
+ elif score1 >= score2:
|
|
|
need_remove.add(j)
|
|
|
else:
|
|
|
need_remove.add(i)
|
|
|
- break # i 被移除,跳出内层循环
|
|
|
+ break
|
|
|
else:
|
|
|
# 检查包含关系
|
|
|
overlap_ratio = LayoutUtils.calculate_overlap_ratio(bbox1, bbox2)
|
|
|
|
|
|
if overlap_ratio > overlap_ratio_threshold:
|
|
|
+ # 检查类别是否允许合并
|
|
|
+ if not can_merge(cat1, cat2):
|
|
|
+ continue # 不允许合并,跳过
|
|
|
+
|
|
|
# 小框被大框高度包含
|
|
|
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
|
|
|
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
|
|
|
@@ -156,15 +235,31 @@ class LayoutUtils:
|
|
|
else:
|
|
|
small_idx, large_idx = j, i
|
|
|
|
|
|
- # 扩展大框的边界
|
|
|
+ # 计算合并后的框
|
|
|
small_bbox = results[small_idx]['bbox']
|
|
|
large_bbox = results[large_idx]['bbox']
|
|
|
- results[large_idx]['bbox'] = [
|
|
|
+ merged_bbox = [
|
|
|
min(small_bbox[0], large_bbox[0]),
|
|
|
min(small_bbox[1], large_bbox[1]),
|
|
|
max(small_bbox[2], large_bbox[2]),
|
|
|
max(small_bbox[3], large_bbox[3])
|
|
|
]
|
|
|
+
|
|
|
+ # 检查合并后的面积是否超过限制
|
|
|
+ if not check_area_limit(merged_bbox):
|
|
|
+ continue # 超过面积限制,拒绝合并
|
|
|
+
|
|
|
+ # 检查类别优先级:如果小框优先级更高,不应该被大框合并
|
|
|
+ small_cat = results[small_idx].get('category', 'unknown')
|
|
|
+ large_cat = results[large_idx].get('category', 'unknown')
|
|
|
+ small_priority = get_priority(small_cat)
|
|
|
+ large_priority = get_priority(large_cat)
|
|
|
+
|
|
|
+ if small_priority > large_priority:
|
|
|
+ continue # 小框优先级更高,不应该被合并
|
|
|
+
|
|
|
+ # 执行合并:扩展大框的边界
|
|
|
+ results[large_idx]['bbox'] = merged_bbox
|
|
|
need_remove.add(small_idx)
|
|
|
|
|
|
if small_idx == i:
|
|
|
@@ -172,6 +267,90 @@ class LayoutUtils:
|
|
|
|
|
|
# 返回去重后的结果
|
|
|
return [results[i] for i in range(len(results)) if i not in need_remove]
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def filter_false_positive_images(
|
|
|
+ layout_results: List[Dict[str, Any]],
|
|
|
+ min_text_area_ratio: float = 0.3
|
|
|
+ ) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 过滤误检的图片框:如果图片框内包含的其他类型(如text/title/table)的面积总和
|
|
|
+ 与图片框的面积比大于阈值,则认为该图片框是误检,应该移除。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ layout_results: Layout 检测结果列表
|
|
|
+ min_text_area_ratio: 最小文本面积比例阈值,如果图片框内文本面积占比超过此值则移除(默认0.3)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 过滤后的布局结果列表
|
|
|
+ """
|
|
|
+ if not layout_results:
|
|
|
+ return layout_results
|
|
|
+
|
|
|
+ # 需要移除的图片框索引
|
|
|
+ need_remove = set()
|
|
|
+
|
|
|
+ # 找出所有图片框
|
|
|
+ image_boxes = []
|
|
|
+ other_boxes = []
|
|
|
+
|
|
|
+ for i, result in enumerate(layout_results):
|
|
|
+ category = result.get('category', 'unknown')
|
|
|
+ if category in ['image_body', 'figure']:
|
|
|
+ image_boxes.append((i, result))
|
|
|
+ else:
|
|
|
+ other_boxes.append((i, result))
|
|
|
+
|
|
|
+ # 对每个图片框,检查其内部包含的其他类型框的面积
|
|
|
+ for img_idx, img_result in image_boxes:
|
|
|
+ img_bbox = img_result.get('bbox', [0, 0, 0, 0])
|
|
|
+ if len(img_bbox) < 4:
|
|
|
+ continue
|
|
|
+
|
|
|
+ img_area = (img_bbox[2] - img_bbox[0]) * (img_bbox[3] - img_bbox[1])
|
|
|
+ if img_area == 0:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 计算图片框内包含的其他类型框的总面积
|
|
|
+ total_contained_area = 0.0
|
|
|
+
|
|
|
+ for other_idx, other_result in other_boxes:
|
|
|
+ if other_idx in need_remove:
|
|
|
+ continue
|
|
|
+
|
|
|
+ other_bbox = other_result.get('bbox', [0, 0, 0, 0])
|
|
|
+ if len(other_bbox) < 4:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 检查其他框是否被图片框包含
|
|
|
+ # 使用 IoU 或包含关系判断
|
|
|
+ overlap_ratio = LayoutUtils.calculate_overlap_ratio(other_bbox, img_bbox)
|
|
|
+
|
|
|
+ # 如果其他框的大部分(>50%)都在图片框内,认为被包含
|
|
|
+ if overlap_ratio > 0.5:
|
|
|
+ other_area = (other_bbox[2] - other_bbox[0]) * (other_bbox[3] - other_bbox[1])
|
|
|
+ # 计算实际包含的面积(交集)
|
|
|
+ x1_i = max(img_bbox[0], other_bbox[0])
|
|
|
+ y1_i = max(img_bbox[1], other_bbox[1])
|
|
|
+ x2_i = min(img_bbox[2], other_bbox[2])
|
|
|
+ y2_i = min(img_bbox[3], other_bbox[3])
|
|
|
+
|
|
|
+ if x2_i > x1_i and y2_i > y1_i:
|
|
|
+ intersection_area = (x2_i - x1_i) * (y2_i - y1_i)
|
|
|
+ total_contained_area += intersection_area
|
|
|
+
|
|
|
+ # 计算文本面积占比
|
|
|
+ text_area_ratio = total_contained_area / img_area if img_area > 0 else 0.0
|
|
|
+
|
|
|
+ # 如果文本面积占比超过阈值,移除该图片框
|
|
|
+ if text_area_ratio > min_text_area_ratio:
|
|
|
+ need_remove.add(img_idx)
|
|
|
+ # 可选:打印调试信息
|
|
|
+ # print(f"🔄 Removed false positive image box: category={img_result.get('category')}, "
|
|
|
+ # f"bbox={img_bbox}, text_area_ratio={text_area_ratio:.2f} > {min_text_area_ratio}")
|
|
|
+
|
|
|
+ # 返回过滤后的结果
|
|
|
+ return [result for i, result in enumerate(layout_results) if i not in need_remove]
|
|
|
|
|
|
|
|
|
class DitLayoutDetector(BaseLayoutDetector):
|
|
|
@@ -214,6 +393,11 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
- remove_overlap: 是否启用重叠框处理 (默认 True)
|
|
|
- iou_threshold: IoU 阈值 (默认 0.8)
|
|
|
- overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
|
|
|
+ - max_area_ratio: 最大面积比例 (默认 0.8)
|
|
|
+ - enable_category_restriction: 是否启用类别限制 (默认 True)
|
|
|
+ - enable_category_priority: 是否启用类别优先级 (默认 True)
|
|
|
+ - filter_false_positive_images: 是否过滤误检的图片框 (默认 True)
|
|
|
+ - min_text_area_ratio: 最小文本面积比例阈值,图片框内文本面积占比超过此值则移除 (默认 0.3)
|
|
|
"""
|
|
|
super().__init__(config)
|
|
|
self.predictor = None
|
|
|
@@ -223,6 +407,11 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
self._remove_overlap = True
|
|
|
self._iou_threshold = 0.8
|
|
|
self._overlap_ratio_threshold = 0.8
|
|
|
+ self._max_area_ratio = 0.8 # 最大面积比例,合并后的框超过此比例则拒绝合并(默认0.8)
|
|
|
+ self._enable_category_restriction = True
|
|
|
+ self._enable_category_priority = True
|
|
|
+ self._filter_false_positive_images = True
|
|
|
+ self._min_text_area_ratio = 0.3
|
|
|
|
|
|
def initialize(self):
|
|
|
"""初始化模型"""
|
|
|
@@ -251,17 +440,16 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
|
|
|
# 添加 dit_support 路径(适配到 universal_doc_parser)
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
- dit_support_path = os.path.join(current_dir, '..', 'dit_support')
|
|
|
+ dit_support_path = Path(__file__).parents[2] / 'dit_support'
|
|
|
if dit_support_path not in sys.path:
|
|
|
- sys.path.insert(0, dit_support_path)
|
|
|
+ sys.path.insert(0, str(dit_support_path))
|
|
|
|
|
|
from ditod import add_vit_config
|
|
|
|
|
|
# 获取配置参数
|
|
|
config_file = self.config.get(
|
|
|
'config_file',
|
|
|
- os.path.join(current_dir, '..', 'dit_support', 'configs',
|
|
|
- 'cascade', 'cascade_dit_large.yaml')
|
|
|
+ dit_support_path / 'configs' / 'cascade' / 'cascade_dit_large.yaml'
|
|
|
)
|
|
|
model_weights = self.config.get(
|
|
|
'model_weights',
|
|
|
@@ -272,6 +460,11 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
self._remove_overlap = self.config.get('remove_overlap', True)
|
|
|
self._iou_threshold = self.config.get('iou_threshold', 0.8)
|
|
|
self._overlap_ratio_threshold = self.config.get('overlap_ratio_threshold', 0.8)
|
|
|
+ self._max_area_ratio = self.config.get('max_area_ratio', 0.8)
|
|
|
+ self._enable_category_restriction = self.config.get('enable_category_restriction', True)
|
|
|
+ self._enable_category_priority = self.config.get('enable_category_priority', True)
|
|
|
+ self._filter_false_positive_images = self.config.get('filter_false_positive_images', True)
|
|
|
+ self._min_text_area_ratio = self.config.get('min_text_area_ratio', 0.3)
|
|
|
|
|
|
# 设置设备
|
|
|
self._device = torch.device(device)
|
|
|
@@ -393,7 +586,7 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
# 过滤面积异常大的框
|
|
|
area = width * height
|
|
|
img_area = orig_w * orig_h
|
|
|
- if area > img_area * 0.95:
|
|
|
+ if area > img_area:
|
|
|
continue
|
|
|
|
|
|
# 生成多边形坐标
|
|
|
@@ -422,8 +615,23 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
formatted_results = LayoutUtils.remove_overlapping_boxes(
|
|
|
formatted_results,
|
|
|
iou_threshold=self._iou_threshold,
|
|
|
- overlap_ratio_threshold=self._overlap_ratio_threshold
|
|
|
+ overlap_ratio_threshold=self._overlap_ratio_threshold,
|
|
|
+ image_size=(orig_w, orig_h),
|
|
|
+ max_area_ratio=self._max_area_ratio,
|
|
|
+ enable_category_restriction=self._enable_category_restriction,
|
|
|
+ enable_category_priority=self._enable_category_priority
|
|
|
+ )
|
|
|
+
|
|
|
+ # 过滤误检的图片框(包含过多文本内容的图片框)
|
|
|
+ if self._filter_false_positive_images and len(formatted_results) > 1:
|
|
|
+ before_count = len(formatted_results)
|
|
|
+ formatted_results = LayoutUtils.filter_false_positive_images(
|
|
|
+ formatted_results,
|
|
|
+ min_text_area_ratio=self._min_text_area_ratio
|
|
|
)
|
|
|
+ removed_count = before_count - len(formatted_results)
|
|
|
+ if removed_count > 0:
|
|
|
+ print(f"🔄 Filtered {removed_count} false positive image boxes")
|
|
|
|
|
|
return formatted_results
|
|
|
|
|
|
@@ -457,7 +665,7 @@ class DitLayoutDetector(BaseLayoutDetector):
|
|
|
self,
|
|
|
img: np.ndarray,
|
|
|
results: List[Dict],
|
|
|
- output_path: str = None,
|
|
|
+ output_path: Optional[str] = None,
|
|
|
show_confidence: bool = True,
|
|
|
min_confidence: float = 0.0
|
|
|
) -> np.ndarray:
|