|
@@ -2,6 +2,10 @@ from abc import ABC, abstractmethod
|
|
|
from typing import Dict, Any, List, Union, Optional, Tuple
|
|
from typing import Dict, Any, List, Union, Optional, Tuple
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
from PIL import Image
|
|
|
|
|
+from loguru import logger
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+import cv2
|
|
|
|
|
+import json
|
|
|
|
|
|
|
|
class BaseAdapter(ABC):
|
|
class BaseAdapter(ABC):
|
|
|
"""基础适配器接口"""
|
|
"""基础适配器接口"""
|
|
@@ -44,6 +48,18 @@ class BasePreprocessor(BaseAdapter):
|
|
|
class BaseLayoutDetector(BaseAdapter):
|
|
class BaseLayoutDetector(BaseAdapter):
|
|
|
"""版式检测器基类"""
|
|
"""版式检测器基类"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: Dict[str, Any]):
    """Initialise the layout detector.

    Args:
        config: Configuration dictionary, forwarded to the parent initialiser.
    """
    super().__init__(config)
    # Debug attributes start unset (None). They may be assigned at runtime;
    # detect() falls back to the configuration for any that is still None.
    self.debug_mode = self.output_dir = self.page_name = None
|
|
|
|
|
+
|
|
|
def detect(
|
|
def detect(
|
|
|
self,
|
|
self,
|
|
|
image: Union[np.ndarray, Image.Image],
|
|
image: Union[np.ndarray, Image.Image],
|
|
@@ -66,6 +82,58 @@ class BaseLayoutDetector(BaseAdapter):
|
|
|
# 调用子类实现的原始检测方法
|
|
# 调用子类实现的原始检测方法
|
|
|
layout_results = self._detect_raw(image, ocr_spans)
|
|
layout_results = self._detect_raw(image, ocr_spans)
|
|
|
|
|
|
|
|
|
|
+ # Debug 模式:打印和可视化后处理前的检测结果
|
|
|
|
|
+ # 优先从实例属性读取(如果存在),否则从配置读取
|
|
|
|
|
+ # 支持两种配置方式:debug_mode 或 debug_options.enabled
|
|
|
|
|
+ debug_mode = getattr(self, 'debug_mode', None)
|
|
|
|
|
+ if debug_mode is None:
|
|
|
|
|
+ if hasattr(self, 'config'):
|
|
|
|
|
+ # 优先从 debug_mode 读取
|
|
|
|
|
+ debug_mode = self.config.get('debug_mode', False)
|
|
|
|
|
+ # 如果没有 debug_mode,尝试从 debug_options.enabled 读取
|
|
|
|
|
+ if not debug_mode:
|
|
|
|
|
+ debug_options = self.config.get('debug_options', {})
|
|
|
|
|
+ if isinstance(debug_options, dict):
|
|
|
|
|
+ debug_mode = debug_options.get('enabled', False)
|
|
|
|
|
+ else:
|
|
|
|
|
+ debug_mode = False
|
|
|
|
|
+
|
|
|
|
|
+ if debug_mode:
|
|
|
|
|
+ logger.debug(f"🔍 Layout detection raw results (before post-processing): {len(layout_results)} elements")
|
|
|
|
|
+ # logger.debug(f"Raw layout_results: {layout_results}")
|
|
|
|
|
+ # 可视化 layout 结果
|
|
|
|
|
+ output_dir = getattr(self, 'output_dir', None)
|
|
|
|
|
+ if output_dir is None:
|
|
|
|
|
+ if hasattr(self, 'config'):
|
|
|
|
|
+ # 优先从 output_dir 读取
|
|
|
|
|
+ output_dir = self.config.get('output_dir', None)
|
|
|
|
|
+ # 如果没有 output_dir,尝试从 debug_options.output_dir 读取
|
|
|
|
|
+ if output_dir is None:
|
|
|
|
|
+ debug_options = self.config.get('debug_options', {})
|
|
|
|
|
+ if isinstance(debug_options, dict):
|
|
|
|
|
+ output_dir = debug_options.get('output_dir', None)
|
|
|
|
|
+ else:
|
|
|
|
|
+ output_dir = None
|
|
|
|
|
+
|
|
|
|
|
+ page_name = getattr(self, 'page_name', None)
|
|
|
|
|
+ if page_name is None:
|
|
|
|
|
+ if hasattr(self, 'config'):
|
|
|
|
|
+ # 优先从 page_name 读取
|
|
|
|
|
+ page_name = self.config.get('page_name', None)
|
|
|
|
|
+ # 如果没有 page_name,尝试从 debug_options.prefix 读取
|
|
|
|
|
+ if page_name is None:
|
|
|
|
|
+ debug_options = self.config.get('debug_options', {})
|
|
|
|
|
+ if isinstance(debug_options, dict):
|
|
|
|
|
+ prefix = debug_options.get('prefix', '')
|
|
|
|
|
+ page_name = prefix if prefix else 'layout_detection'
|
|
|
|
|
+ if page_name is None:
|
|
|
|
|
+ page_name = 'layout_detection'
|
|
|
|
|
+ else:
|
|
|
|
|
+ page_name = 'layout_detection'
|
|
|
|
|
+
|
|
|
|
|
+ if output_dir:
|
|
|
|
|
+ self._visualize_layout_results(image, layout_results, output_dir, page_name, suffix='raw')
|
|
|
|
|
+
|
|
|
# 自动执行后处理
|
|
# 自动执行后处理
|
|
|
if layout_results:
|
|
if layout_results:
|
|
|
layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
|
|
layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
|
|
@@ -132,7 +200,7 @@ class BaseLayoutDetector(BaseAdapter):
|
|
|
return layout_results
|
|
return layout_results
|
|
|
|
|
|
|
|
# 1. 去除重叠框
|
|
# 1. 去除重叠框
|
|
|
- layout_results = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
|
|
|
|
|
|
|
+ layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)
|
|
|
|
|
|
|
|
# 2. 将大面积文本块转换为表格(如果配置启用)
|
|
# 2. 将大面积文本块转换为表格(如果配置启用)
|
|
|
layout_config = config if config is not None else {}
|
|
layout_config = config if config is not None else {}
|
|
@@ -143,94 +211,15 @@ class BaseLayoutDetector(BaseAdapter):
|
|
|
else:
|
|
else:
|
|
|
h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
|
|
h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
|
|
|
|
|
|
|
|
- layout_results = self._convert_large_text_to_table(
|
|
|
|
|
- layout_results,
|
|
|
|
|
|
|
+ layout_results_converted_large_text = self._convert_large_text_to_table(
|
|
|
|
|
+ layout_results_removed_overlapping,
|
|
|
(h, w),
|
|
(h, w),
|
|
|
min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
|
|
min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
|
|
|
min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
|
|
min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
|
|
|
min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
|
|
min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- return layout_results
|
|
|
|
|
-
|
|
|
|
|
- def _remove_overlapping_boxes(
|
|
|
|
|
- self,
|
|
|
|
|
- layout_results: List[Dict[str, Any]],
|
|
|
|
|
- coordinate_utils: Any,
|
|
|
|
|
- iou_threshold: float = 0.8,
|
|
|
|
|
- overlap_ratio_threshold: float = 0.8
|
|
|
|
|
- ) -> List[Dict[str, Any]]:
|
|
|
|
|
- """
|
|
|
|
|
- 处理重叠的布局框(参考 MinerU 的去重策略)
|
|
|
|
|
-
|
|
|
|
|
- 策略:
|
|
|
|
|
- 1. 高 IoU 重叠:保留置信度高的框
|
|
|
|
|
- 2. 包含关系:小框被大框高度包含时,保留大框并扩展边界
|
|
|
|
|
- """
|
|
|
|
|
- if not layout_results or len(layout_results) <= 1:
|
|
|
|
|
- return layout_results
|
|
|
|
|
-
|
|
|
|
|
- # 复制列表避免修改原数据
|
|
|
|
|
- results = [item.copy() for item in layout_results]
|
|
|
|
|
- need_remove = set()
|
|
|
|
|
-
|
|
|
|
|
- for i in range(len(results)):
|
|
|
|
|
- if i in need_remove:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- for j in range(i + 1, len(results)):
|
|
|
|
|
- if j in need_remove:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- bbox1 = results[i].get('bbox', [0, 0, 0, 0])
|
|
|
|
|
- bbox2 = results[j].get('bbox', [0, 0, 0, 0])
|
|
|
|
|
-
|
|
|
|
|
- if len(bbox1) < 4 or len(bbox2) < 4:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
- # 计算 IoU
|
|
|
|
|
- iou = coordinate_utils.calculate_iou(bbox1, bbox2)
|
|
|
|
|
-
|
|
|
|
|
- if iou > iou_threshold:
|
|
|
|
|
- # 高度重叠,保留置信度高的
|
|
|
|
|
- score1 = results[i].get('confidence', results[i].get('score', 0))
|
|
|
|
|
- score2 = results[j].get('confidence', results[j].get('score', 0))
|
|
|
|
|
-
|
|
|
|
|
- if score1 >= score2:
|
|
|
|
|
- need_remove.add(j)
|
|
|
|
|
- else:
|
|
|
|
|
- need_remove.add(i)
|
|
|
|
|
- break # i 被移除,跳出内层循环
|
|
|
|
|
- else:
|
|
|
|
|
- # 检查包含关系
|
|
|
|
|
- overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
|
|
|
|
|
-
|
|
|
|
|
- if overlap_ratio > overlap_ratio_threshold:
|
|
|
|
|
- # 小框被大框高度包含
|
|
|
|
|
- area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
|
|
|
|
|
- area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
|
|
|
|
|
-
|
|
|
|
|
- if area1 <= area2:
|
|
|
|
|
- small_idx, large_idx = i, j
|
|
|
|
|
- else:
|
|
|
|
|
- small_idx, large_idx = j, i
|
|
|
|
|
-
|
|
|
|
|
- # 扩展大框的边界
|
|
|
|
|
- small_bbox = results[small_idx]['bbox']
|
|
|
|
|
- large_bbox = results[large_idx]['bbox']
|
|
|
|
|
- results[large_idx]['bbox'] = [
|
|
|
|
|
- min(small_bbox[0], large_bbox[0]),
|
|
|
|
|
- min(small_bbox[1], large_bbox[1]),
|
|
|
|
|
- max(small_bbox[2], large_bbox[2]),
|
|
|
|
|
- max(small_bbox[3], large_bbox[3])
|
|
|
|
|
- ]
|
|
|
|
|
- need_remove.add(small_idx)
|
|
|
|
|
-
|
|
|
|
|
- if small_idx == i:
|
|
|
|
|
- break # i 被移除,跳出内层循环
|
|
|
|
|
-
|
|
|
|
|
- # 返回去重后的结果
|
|
|
|
|
- return [results[i] for i in range(len(results)) if i not in need_remove]
|
|
|
|
|
|
|
+ return layout_results_converted_large_text
|
|
|
|
|
|
|
|
def _convert_large_text_to_table(
|
|
def _convert_large_text_to_table(
|
|
|
self,
|
|
self,
|
|
@@ -324,6 +313,281 @@ class BaseLayoutDetector(BaseAdapter):
|
|
|
101: 'image_footnote'
|
|
101: 'image_footnote'
|
|
|
}
|
|
}
|
|
|
return category_map.get(category_id, f'unknown_{category_id}')
|
|
return category_map.get(category_id, f'unknown_{category_id}')
|
|
|
|
|
+
|
|
|
|
|
def _visualize_layout_results(
    self,
    image: Union[np.ndarray, Image.Image],
    layout_results: List[Dict[str, Any]],
    output_dir: str,
    page_name: str,
    suffix: str = 'raw'
) -> None:
    """
    Draw layout detection boxes on the image and save debug artifacts.

    Two files are written under
    ``<output_dir>/debug_comparison/layout_detection``:
    ``<page_name>_layout_<suffix>.jpg`` (the annotated image) and
    ``<page_name>_layout_<suffix>.json`` (category, bbox and confidence of
    every element). The method is best-effort: any failure is logged as a
    warning and never propagated to the caller.

    Args:
        image: Input image (numpy array or PIL image).
        layout_results: Layout detection results; each item is read for
            'bbox', 'category' and 'confidence' (falling back to 'score').
        output_dir: Base output directory.
        page_name: Page name used as the file name prefix.
        suffix: File name suffix (e.g. 'raw', 'postprocessed').
    """
    if not layout_results:
        return

    def to_builtin(value: Any) -> Any:
        """Convert numpy arrays/scalars (also nested in sequences) to plain Python values."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (list, tuple)):
            return [to_builtin(v) for v in value]
        return value

    try:
        # Work on a numpy copy so the caller's image is never drawn on.
        if isinstance(image, Image.Image):
            vis_image = np.array(image)
        else:
            vis_image = image.copy()
        # NOTE(review): 3-channel input is assumed to be RGB for both PIL and
        # numpy inputs and is swapped to OpenCV's BGR order - confirm that
        # numpy callers never pass BGR arrays.
        if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
            vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)

        # Category -> colour mapping (BGR order)
        category_colors = {
            'table_body': (0, 0, 255),      # red
            'table_caption': (0, 0, 200),   # dark red
            'table_footnote': (0, 0, 150),  # darker red
            'text': (255, 0, 0),            # blue
            'title': (0, 255, 255),         # yellow
            'header': (255, 0, 255),        # purple
            'footer': (0, 165, 255),        # orange
            'image_body': (0, 255, 0),      # green
            'image_caption': (0, 200, 0),   # dark green
            'image_footnote': (0, 150, 0),  # darker green
            'abandon': (128, 128, 128),     # grey
        }
        default_color = (128, 128, 128)  # grey for unknown categories

        # Drawing parameters are loop-invariant.
        thickness = 2
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.4
        text_thickness = 1

        # Draw the detection boxes
        for result in layout_results:
            bbox = result.get('bbox', [])
            # Explicit None/len check: `not bbox` raises for numpy arrays.
            if bbox is None or len(bbox) < 4:
                continue

            category = result.get('category', 'unknown')
            color = category_colors.get(category, default_color)

            x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
            cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, thickness)

            # Label: category, plus the confidence when it is non-zero
            label = f"{category}"
            confidence = result.get('confidence', result.get('score', 0))
            if confidence:
                label += f":{confidence:.2f}"

            (text_width, text_height), baseline = cv2.getTextSize(
                label, font, font_scale, text_thickness
            )

            # Filled label background above the box, clamped so it stays
            # inside the image when the box touches the top edge.
            text_y = max(y1 - baseline - 1, text_height + baseline)
            cv2.rectangle(vis_image, (x1, text_y - text_height - baseline - 2),
                          (x1 + text_width, text_y), color, -1)
            cv2.putText(vis_image, label, (x1, text_y - baseline - 1),
                        font, font_scale, (255, 255, 255), text_thickness)

        # Save the annotated image
        debug_dir = Path(output_dir) / "debug_comparison" / "layout_detection"
        debug_dir.mkdir(parents=True, exist_ok=True)
        output_path = debug_dir / f"{page_name}_layout_{suffix}.jpg"
        # cv2.imwrite reports failure through its return value, not by raising.
        if cv2.imwrite(str(output_path), vis_image):
            logger.info(f"📊 Saved layout detection image ({suffix}): {output_path}")
        else:
            logger.warning(f"⚠️ Failed to write layout detection image ({suffix}): {output_path}")

        # Save the JSON data; numpy values are converted first because the
        # json module cannot serialise numpy scalars or arrays.
        json_data = {
            'page_name': page_name,
            'suffix': suffix,
            'count': len(layout_results),
            'results': [
                {
                    'category': r.get('category'),
                    'bbox': to_builtin(r.get('bbox')),
                    'confidence': to_builtin(r.get('confidence', r.get('score', 0.0)))
                }
                for r in layout_results
            ]
        }
        json_path = debug_dir / f"{page_name}_layout_{suffix}.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)
        logger.info(f"📊 Saved layout detection JSON ({suffix}): {json_path}")

    except Exception as e:
        # Debug output must never break detection, so only warn.
        logger.warning(f"⚠️ Failed to visualize layout results: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ def _remove_overlapping_boxes(
|
|
|
|
|
+ self,
|
|
|
|
|
+ layout_results: List[Dict[str, Any]],
|
|
|
|
|
+ coordinate_utils: Any,
|
|
|
|
|
+ iou_threshold: float = 0.8,
|
|
|
|
|
+ overlap_ratio_threshold: float = 0.8
|
|
|
|
|
+ ) -> List[Dict[str, Any]]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 改进版重叠框处理算法(基于优先级和决策规则的清晰算法)
|
|
|
|
|
+
|
|
|
|
|
+ 策略:
|
|
|
|
|
+ 1. 定义类别优先级(abandon < text/image < table_body)
|
|
|
|
|
+ 2. 使用统一的决策规则
|
|
|
|
|
+ 3. 按综合评分排序处理,优先保留大的聚合框
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ layout_results: 布局检测结果
|
|
|
|
|
+ coordinate_utils: 坐标工具类
|
|
|
|
|
+ iou_threshold: IoU阈值(默认0.8)
|
|
|
|
|
+ overlap_ratio_threshold: 重叠比例阈值(默认0.8)
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 去重后的布局结果
|
|
|
|
|
+ """
|
|
|
|
|
+ if not layout_results or len(layout_results) <= 1:
|
|
|
|
|
+ return layout_results
|
|
|
|
|
+
|
|
|
|
|
+ # 常量定义
|
|
|
|
|
+ CATEGORY_PRIORITY = {
|
|
|
|
|
+ 'abandon': 0,
|
|
|
|
|
+ 'text': 1,
|
|
|
|
|
+ 'image_body': 1,
|
|
|
|
|
+ 'title': 2,
|
|
|
|
|
+ 'footer': 2,
|
|
|
|
|
+ 'header': 2,
|
|
|
|
|
+ 'table_body': 3,
|
|
|
|
|
+ }
|
|
|
|
|
+ AGGREGATE_LABELS = {'key-value region', 'form'}
|
|
|
|
|
+ MAX_AREA = 4000000.0 # 用于面积归一化
|
|
|
|
|
+ AREA_WEIGHT = 0.5
|
|
|
|
|
+ CONFIDENCE_WEIGHT = 0.5
|
|
|
|
|
+ AGGREGATE_BONUS = 0.1
|
|
|
|
|
+ AREA_RATIO_THRESHOLD = 3.0 # 大框面积需大于小框的倍数
|
|
|
|
|
+
|
|
|
|
|
+ def get_bbox_area(bbox: List[float]) -> float:
|
|
|
|
|
+ """计算bbox面积"""
|
|
|
|
|
+ if len(bbox) < 4:
|
|
|
|
|
+ return 0.0
|
|
|
|
|
+ return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
|
|
|
|
+
|
|
|
|
|
+ def is_aggregate_type(box: Dict[str, Any]) -> bool:
|
|
|
|
|
+ """检查是否是聚合类型"""
|
|
|
|
|
+ original_label = box.get('raw', {}).get('original_label', '').lower()
|
|
|
|
|
+ return original_label in AGGREGATE_LABELS
|
|
|
|
|
+
|
|
|
|
|
+ def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
|
|
|
|
|
+ """检查inner是否完全包含在outer内"""
|
|
|
|
|
+ if len(inner) < 4 or len(outer) < 4:
|
|
|
|
|
+ return False
|
|
|
|
|
+ return (inner[0] >= outer[0] and inner[1] >= outer[1] and
|
|
|
|
|
+ inner[2] <= outer[2] and inner[3] <= outer[3])
|
|
|
|
|
+
|
|
|
|
|
+ def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
|
|
|
|
|
+ """计算text类型的综合评分(面积+置信度)"""
|
|
|
|
|
+ if box.get('category') != 'text':
|
|
|
|
|
+ return box.get('confidence', box.get('score', 0))
|
|
|
|
|
+
|
|
|
|
|
+ normalized_area = min(area / MAX_AREA, 1.0)
|
|
|
|
|
+ area_score = (normalized_area ** 0.5) * AREA_WEIGHT
|
|
|
|
|
+ confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
|
|
|
|
|
+ bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
|
|
|
|
|
+ return area_score + confidence_score + bonus
|
|
|
|
|
+
|
|
|
|
|
+ def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
|
|
|
|
|
+ iou: float, overlap_ratio: float,
|
|
|
|
|
+ contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
|
|
|
|
|
+ """判断是否应该保留box1"""
|
|
|
|
|
+ # 提取基本信息
|
|
|
|
|
+ cat1, cat2 = box1.get('category', ''), box2.get('category', '')
|
|
|
|
|
+ score1 = box1.get('confidence', box1.get('score', 0))
|
|
|
|
|
+ score2 = box2.get('confidence', box2.get('score', 0))
|
|
|
|
|
+ bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
|
|
|
|
|
+ area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
|
|
|
|
|
+ is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)
|
|
|
|
|
+
|
|
|
|
|
+ # 规则1: 类别优先级
|
|
|
|
|
+ priority1 = CATEGORY_PRIORITY.get(cat1, 1)
|
|
|
|
|
+ priority2 = CATEGORY_PRIORITY.get(cat2, 1)
|
|
|
|
|
+ if priority1 != priority2:
|
|
|
|
|
+ return priority1 > priority2
|
|
|
|
|
+
|
|
|
|
|
+ # 规则2: 包含关系 + 聚合类型优先
|
|
|
|
|
+ if contained_2_in_1 and is_agg1 and not is_agg2:
|
|
|
|
|
+ return True
|
|
|
|
|
+ if contained_1_in_2 and is_agg2 and not is_agg1:
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ # 规则3: 包含关系 + 面积比例
|
|
|
|
|
+ if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
|
|
|
|
|
+ return True
|
|
|
|
|
+ if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ # 规则4: text类型使用综合评分
|
|
|
|
|
+ if cat1 == 'text' or cat2 == 'text':
|
|
|
|
|
+ comp_score1 = calculate_composite_score(box1, area1)
|
|
|
|
|
+ comp_score2 = calculate_composite_score(box2, area2)
|
|
|
|
|
+ if abs(comp_score1 - comp_score2) > 0.05:
|
|
|
|
|
+ return comp_score1 > comp_score2
|
|
|
|
|
+
|
|
|
|
|
+ # 规则5: 置信度比较
|
|
|
|
|
+ if abs(score1 - score2) > 0.1:
|
|
|
|
|
+ return score1 > score2
|
|
|
|
|
+
|
|
|
|
|
+ # 规则6: 面积比较
|
|
|
|
|
+ return area1 >= area2
|
|
|
|
|
+
|
|
|
|
|
+ # 主处理逻辑
|
|
|
|
|
+ results = [item.copy() for item in layout_results]
|
|
|
|
|
+ need_remove = set()
|
|
|
|
|
+
|
|
|
|
|
+ # 按综合评分排序(高分优先)
|
|
|
|
|
+ def get_sort_key(i: int) -> float:
|
|
|
|
|
+ item = results[i]
|
|
|
|
|
+ if item.get('category') == 'text':
|
|
|
|
|
+ return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
|
|
|
|
|
+ return -item.get('confidence', item.get('score', 0))
|
|
|
|
|
+
|
|
|
|
|
+ sorted_indices = sorted(range(len(results)), key=get_sort_key)
|
|
|
|
|
+
|
|
|
|
|
+ # 比较每对框
|
|
|
|
|
+ for idx_i, i in enumerate(sorted_indices):
|
|
|
|
|
+ if i in need_remove:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ for idx_j, j in enumerate(sorted_indices):
|
|
|
|
|
+ if j == i or j in need_remove or idx_j >= idx_i:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
|
|
|
|
|
+ if len(bbox1) < 4 or len(bbox2) < 4:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 计算重叠指标
|
|
|
|
|
+ iou = coordinate_utils.calculate_iou(bbox1, bbox2)
|
|
|
|
|
+ overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
|
|
|
|
|
+ contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
|
|
|
|
|
+ contained_2_in_1 = is_bbox_inside(bbox2, bbox1)
|
|
|
|
|
+
|
|
|
|
|
+ # 检查是否有显著重叠
|
|
|
|
|
+ if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
|
|
|
|
|
+ contained_1_in_2 or contained_2_in_1):
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 应用决策规则
|
|
|
|
|
+ if should_keep_box1(results[i], results[j], iou, overlap_ratio,
|
|
|
|
|
+ contained_1_in_2, contained_2_in_1):
|
|
|
|
|
+ need_remove.add(j)
|
|
|
|
|
+ else:
|
|
|
|
|
+ need_remove.add(i)
|
|
|
|
|
+ break
|
|
|
|
|
+
|
|
|
|
|
+ return [results[i] for i in range(len(results)) if i not in need_remove]
|
|
|
|
|
|
|
|
class BaseVLRecognizer(BaseAdapter):
|
|
class BaseVLRecognizer(BaseAdapter):
|
|
|
"""VL识别器基类"""
|
|
"""VL识别器基类"""
|