|
|
@@ -242,6 +242,15 @@ class CellFusionEngine:
|
|
|
- stats: {'merged': int, 'added': int, 'merged_cells': int}
|
|
|
- cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
|
|
|
"""
|
|
|
+
|
|
|
+ # 计算unet_cells的边界框bbox[x1,y1,x2,y2]
|
|
|
+ unet_bbox = [
|
|
|
+ min(unet_cells, key=lambda box: box[0])[0], \
|
|
|
+ min(unet_cells, key=lambda box: box[1])[1], \
|
|
|
+ max(unet_cells, key=lambda box: box[2])[2], \
|
|
|
+ max(unet_cells, key=lambda box: box[3])[3]
|
|
|
+ ] if unet_cells else [0,0,0,0]
|
|
|
+
|
|
|
fused_cells = []
|
|
|
cell_labels = [] # 记录每个单元格的来源标签
|
|
|
unet_matched = [False] * len(unet_cells)
|
|
|
@@ -303,16 +312,20 @@ class CellFusionEngine:
|
|
|
rtdetr_area = self._calc_bbox_area(rtdetr_cell)
|
|
|
coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
|
|
|
|
|
|
- # 如果覆盖率>40%,说明这是一个真实的合并单元格
|
|
|
- # 降低阈值从0.5到0.4,因为合并单元格可能包含很多空白区域
|
|
|
- if coverage > 0.4:
|
|
|
- # 认定为合并单元格,取bounding与RT-DETR的最大范围
|
|
|
+ # 如果覆盖率>50%,说明这是一个真实的合并单元格
|
|
|
+ if coverage > 0.5:
|
|
|
+ # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且x方向不能超过unet_bbox范围(y方向不做限制)
|
|
|
fused_cell = [
|
|
|
min(bounding_x1, rtdetr_cell[0]),
|
|
|
min(bounding_y1, rtdetr_cell[1]),
|
|
|
max(bounding_x2, rtdetr_cell[2]),
|
|
|
max(bounding_y2, rtdetr_cell[3])
|
|
|
]
|
|
|
+ # x限制在unet_bbox范围内
|
|
|
+ fused_cell[0] = max(fused_cell[0], unet_bbox[0])
|
|
|
+ # fused_cell[1] = max(fused_cell[1], unet_bbox[1])
|
|
|
+ fused_cell[2] = min(fused_cell[2], unet_bbox[2])
|
|
|
+ # fused_cell[3] = min(fused_cell[3], unet_bbox[3])
|
|
|
fused_cells.append(fused_cell)
|
|
|
cell_labels.append('merged_span') # 标记为合并单元格
|
|
|
rtdetr_matched[rt_idx] = True
|
|
|
@@ -366,6 +379,11 @@ class CellFusionEngine:
|
|
|
# Step 3: 补充 RT-DETR 独有的高置信度单元格
|
|
|
for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
|
|
|
if not rtdetr_matched[idx] and score > 0.7:
|
|
|
+ # rtdetr_cell不能超出unet_bbox范围, 仅在x方向限制(y方向不做限制)
|
|
|
+ rtdetr_cell[0] = max(rtdetr_cell[0], unet_bbox[0])
|
|
|
+ # rtdetr_cell[1] = max(rtdetr_cell[1], unet_bbox[1])
|
|
|
+ rtdetr_cell[2] = min(rtdetr_cell[2], unet_bbox[2])
|
|
|
+ # rtdetr_cell[3] = min(rtdetr_cell[3], unet_bbox[3])
|
|
|
fused_cells.append(rtdetr_cell)
|
|
|
cell_labels.append('rtdetr_only') # 标记为RT-DETR独有
|
|
|
stats['added'] += 1
|