|
@@ -140,12 +140,13 @@ class CellFusionEngine:
|
|
|
# 如果不使用RT-DETR,直接返回UNet结果
|
|
# 如果不使用RT-DETR,直接返回UNet结果
|
|
|
if not use_rtdetr:
|
|
if not use_rtdetr:
|
|
|
fused_cells = unet_cells.copy()
|
|
fused_cells = unet_cells.copy()
|
|
|
|
|
+ cell_labels = ['unet_only'] * len(fused_cells) # 所有都是UNet独有
|
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
|
|
|
|
|
|
# 可选:OCR补偿
|
|
# 可选:OCR补偿
|
|
|
if self.enable_ocr_compensation and ocr_boxes:
|
|
if self.enable_ocr_compensation and ocr_boxes:
|
|
|
- fused_cells, ocr_comp_count = self._compensate_with_ocr(
|
|
|
|
|
- fused_cells, ocr_boxes, (w, h)
|
|
|
|
|
|
|
+ fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
|
|
|
|
|
+ fused_cells, cell_labels, ocr_boxes, (w, h)
|
|
|
)
|
|
)
|
|
|
fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
fusion_stats['fused_count'] = len(fused_cells)
|
|
@@ -159,6 +160,8 @@ class CellFusionEngine:
|
|
|
table_image,
|
|
table_image,
|
|
|
conf_threshold=self.rtdetr_conf_threshold
|
|
conf_threshold=self.rtdetr_conf_threshold
|
|
|
)
|
|
)
|
|
|
|
|
+ # rtdetr_result从上到下,从左到右排序
|
|
|
|
|
+ rtdetr_results.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
|
rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
|
fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
@@ -171,10 +174,11 @@ class CellFusionEngine:
|
|
|
return fused_cells, fusion_stats
|
|
return fused_cells, fusion_stats
|
|
|
|
|
|
|
|
# Phase 2: 智能融合
|
|
# Phase 2: 智能融合
|
|
|
- fused_cells, merge_stats = self._fuse_cells(
|
|
|
|
|
|
|
+ fused_cells, merge_stats, cell_labels = self._fuse_cells(
|
|
|
unet_cells, rtdetr_cells, rtdetr_scores
|
|
unet_cells, rtdetr_cells, rtdetr_scores
|
|
|
)
|
|
)
|
|
|
fusion_stats['merged_count'] = merge_stats['merged']
|
|
fusion_stats['merged_count'] = merge_stats['merged']
|
|
|
|
|
+ fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
|
|
|
fusion_stats['added_count'] = merge_stats['added']
|
|
fusion_stats['added_count'] = merge_stats['added']
|
|
|
|
|
|
|
|
# Phase 3: NMS 去重
|
|
# Phase 3: NMS 去重
|
|
@@ -182,8 +186,8 @@ class CellFusionEngine:
|
|
|
|
|
|
|
|
# Phase 4: OCR 补偿(可选)
|
|
# Phase 4: OCR 补偿(可选)
|
|
|
if self.enable_ocr_compensation and ocr_boxes:
|
|
if self.enable_ocr_compensation and ocr_boxes:
|
|
|
- fused_cells, ocr_comp_count = self._compensate_with_ocr(
|
|
|
|
|
- fused_cells, ocr_boxes, (w, h)
|
|
|
|
|
|
|
+ fused_cells, cell_labels, ocr_comp_count = self._compensate_with_ocr(
|
|
|
|
|
+ fused_cells, cell_labels, ocr_boxes, (w, h)
|
|
|
)
|
|
)
|
|
|
fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
fusion_stats['ocr_compensated_count'] = ocr_comp_count
|
|
|
|
|
|
|
@@ -191,14 +195,14 @@ class CellFusionEngine:
|
|
|
|
|
|
|
|
logger.info(
|
|
logger.info(
|
|
|
f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
|
- f"Merged={merge_stats['merged']}, Added={merge_stats['added']}, "
|
|
|
|
|
- f"Final={len(fused_cells)}"
|
|
|
|
|
|
|
+ f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
|
|
|
|
|
+ f"Added={merge_stats['added']}, Final={len(fused_cells)}"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 可视化(调试)
|
|
# 可视化(调试)
|
|
|
if debug_dir:
|
|
if debug_dir:
|
|
|
self._visualize_fusion(
|
|
self._visualize_fusion(
|
|
|
- table_image, unet_cells, rtdetr_cells, fused_cells,
|
|
|
|
|
|
|
+ table_image, unet_cells, rtdetr_cells, fused_cells, cell_labels,
|
|
|
debug_dir, debug_prefix
|
|
debug_dir, debug_prefix
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -209,14 +213,23 @@ class CellFusionEngine:
|
|
|
unet_cells: List[List[float]],
|
|
unet_cells: List[List[float]],
|
|
|
rtdetr_cells: List[List[float]],
|
|
rtdetr_cells: List[List[float]],
|
|
|
rtdetr_scores: List[float]
|
|
rtdetr_scores: List[float]
|
|
|
- ) -> Tuple[List[List[float]], Dict[str, int]]:
|
|
|
|
|
|
|
+ ) -> Tuple[List[List[float]], Dict[str, int], List[str]]:
|
|
|
"""
|
|
"""
|
|
|
- 融合 UNet 和 RT-DETR 检测结果
|
|
|
|
|
|
|
+ 融合 UNet 和 RT-DETR 检测结果(增强版:支持合并单元格检测)
|
|
|
|
|
|
|
|
融合规则:
|
|
融合规则:
|
|
|
- 1. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并
|
|
|
|
|
- 2. RT-DETR 独有 + 高置信度 (>0.7) → 补充
|
|
|
|
|
- 3. UNet 独有 → 保留
|
|
|
|
|
|
|
+ 1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
|
|
|
|
|
+ - 判断RT-DETR单元格包含多少个UNet单元格
|
|
|
|
|
+ - 使用中心点+包含率判断(而非IoU)
|
|
|
|
|
+ 2. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
|
|
|
|
|
+ 3. RT-DETR 独有 + 高置信度 (>0.7) → 补充
|
|
|
|
|
+ 4. UNet 独有 → 保留
|
|
|
|
|
+
|
|
|
|
|
+ 包含关系判断逻辑:
|
|
|
|
|
+ - UNet单元格的中心点在RT-DETR内
|
|
|
|
|
+ - UNet单元格的50%以上面积在RT-DETR内
|
|
|
|
|
+ - RT-DETR包含≥2个UNet单元格
|
|
|
|
|
+ - 总覆盖率>40%(所有UNet面积之和 / RT-DETR面积)
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
unet_cells: UNet单元格列表
|
|
unet_cells: UNet单元格列表
|
|
@@ -224,28 +237,104 @@ class CellFusionEngine:
|
|
|
rtdetr_scores: RT-DETR置信度列表
|
|
rtdetr_scores: RT-DETR置信度列表
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- (fused_cells, stats)
|
|
|
|
|
|
|
+ (fused_cells, stats, cell_labels)
|
|
|
- fused_cells: 融合后的单元格
|
|
- fused_cells: 融合后的单元格
|
|
|
- - stats: {'merged': int, 'added': int}
|
|
|
|
|
|
|
+ - stats: {'merged': int, 'added': int, 'merged_cells': int}
|
|
|
|
|
+ - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
|
|
|
"""
|
|
"""
|
|
|
fused_cells = []
|
|
fused_cells = []
|
|
|
|
|
+ cell_labels = [] # 记录每个单元格的来源标签
|
|
|
|
|
+ unet_matched = [False] * len(unet_cells)
|
|
|
rtdetr_matched = [False] * len(rtdetr_cells)
|
|
rtdetr_matched = [False] * len(rtdetr_cells)
|
|
|
- stats = {'merged': 0, 'added': 0}
|
|
|
|
|
|
|
+ stats = {'merged': 0, 'added': 0, 'merged_cells': 0}
|
|
|
|
|
|
|
|
- # Step 1: 遍历 UNet 单元格,尝试与 RT-DETR 匹配
|
|
|
|
|
- for unet_cell in unet_cells:
|
|
|
|
|
|
|
+ # Step 1: 检测RT-DETR的合并单元格(一对多匹配)
|
|
|
|
|
+ # 遍历RT-DETR单元格,查找被包含的多个UNet单元格
|
|
|
|
|
+ for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
|
|
+ if rtdetr_matched[rt_idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 查找所有被当前RT-DETR单元格包含(或大部分包含)的UNet单元格
|
|
|
|
|
+ contained_unet = []
|
|
|
|
|
+ for u_idx, unet_cell in enumerate(unet_cells):
|
|
|
|
|
+ if unet_matched[u_idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 判断UNet单元格是否被RT-DETR单元格包含
|
|
|
|
|
+ # 方法1: 检查UNet的中心点是否在RT-DETR内
|
|
|
|
|
+ unet_cx = (unet_cell[0] + unet_cell[2]) / 2
|
|
|
|
|
+ unet_cy = (unet_cell[1] + unet_cell[3]) / 2
|
|
|
|
|
+
|
|
|
|
|
+ if (rtdetr_cell[0] <= unet_cx <= rtdetr_cell[2] and
|
|
|
|
|
+ rtdetr_cell[1] <= unet_cy <= rtdetr_cell[3]):
|
|
|
|
|
+ # UNet中心点在RT-DETR内,计算包含程度
|
|
|
|
|
+ # 计算UNet有多少面积在RT-DETR内
|
|
|
|
|
+ intersect_x1 = max(unet_cell[0], rtdetr_cell[0])
|
|
|
|
|
+ intersect_y1 = max(unet_cell[1], rtdetr_cell[1])
|
|
|
|
|
+ intersect_x2 = min(unet_cell[2], rtdetr_cell[2])
|
|
|
|
|
+ intersect_y2 = min(unet_cell[3], rtdetr_cell[3])
|
|
|
|
|
+
|
|
|
|
|
+ if intersect_x2 > intersect_x1 and intersect_y2 > intersect_y1:
|
|
|
|
|
+ intersect_area = (intersect_x2 - intersect_x1) * (intersect_y2 - intersect_y1)
|
|
|
|
|
+ unet_area = (unet_cell[2] - unet_cell[0]) * (unet_cell[3] - unet_cell[1])
|
|
|
|
|
+ contain_ratio = intersect_area / unet_area if unet_area > 0 else 0
|
|
|
|
|
+
|
|
|
|
|
+ # 如果UNet单元格的50%以上在RT-DETR内,认为被包含
|
|
|
|
|
+ if contain_ratio > 0.5:
|
|
|
|
|
+ contained_unet.append((u_idx, contain_ratio))
|
|
|
|
|
+
|
|
|
|
|
+ # 判断是否为合并单元格(RT-DETR包含多个UNet单元格)
|
|
|
|
|
+ if len(contained_unet) >= 2:
|
|
|
|
|
+ # 合并单元格场景:优先使用RT-DETR的大框
|
|
|
|
|
+ # 条件:1) 包含2个以上UNet单元格 2) RT-DETR置信度足够高
|
|
|
|
|
+ if rtdetr_scores[rt_idx] > 0.7:
|
|
|
|
|
+ # 计算总包含率:使用所有被包含UNet单元格的外接矩形面积 vs RT-DETR面积
|
|
|
|
|
+ # 使用外接矩形更合理,因为:
|
|
|
|
|
+ # 1. 合并单元格是一个完整区域,应包括单元格间隙
|
|
|
|
|
+ # 2. 避免重复计算相邻单元格的边界
|
|
|
|
|
+ # 3. 更准确反映覆盖率(如11个连续单元格应该接近100%覆盖)
|
|
|
|
|
+ unet_indices = [u_idx for u_idx, _ in contained_unet]
|
|
|
|
|
+ bounding_x1 = min(unet_cells[i][0] for i in unet_indices)
|
|
|
|
|
+ bounding_y1 = min(unet_cells[i][1] for i in unet_indices)
|
|
|
|
|
+ bounding_x2 = max(unet_cells[i][2] for i in unet_indices)
|
|
|
|
|
+ bounding_y2 = max(unet_cells[i][3] for i in unet_indices)
|
|
|
|
|
+ total_unet_area = (bounding_x2 - bounding_x1) * (bounding_y2 - bounding_y1)
|
|
|
|
|
+
|
|
|
|
|
+ rtdetr_area = self._calc_bbox_area(rtdetr_cell)
|
|
|
|
|
+ coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
|
|
|
|
|
+
|
|
|
|
|
+ # 如果覆盖率>40%,说明这是一个真实的合并单元格
|
|
|
|
|
+ # 降低阈值从0.5到0.4,因为合并单元格可能包含很多空白区域
|
|
|
|
|
+ if coverage > 0.4:
|
|
|
|
|
+ fused_cells.append(rtdetr_cell)
|
|
|
|
|
+ cell_labels.append('merged_span') # 标记为合并单元格
|
|
|
|
|
+ rtdetr_matched[rt_idx] = True
|
|
|
|
|
+ # 标记所有被包含的UNet单元格
|
|
|
|
|
+ for u_idx, contain_ratio in contained_unet:
|
|
|
|
|
+ unet_matched[u_idx] = True
|
|
|
|
|
+ stats['merged_cells'] += 1
|
|
|
|
|
+ logger.debug(
|
|
|
|
|
+ f"🔗 检测到合并单元格: RT-DETR[{rt_idx}] 包含 {len(contained_unet)} 个UNet单元格 "
|
|
|
|
|
+ f"(coverage={coverage:.2f}, score={rtdetr_scores[rt_idx]:.2f})"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # Step 2: 一对一匹配(处理剩余的单元格)
|
|
|
|
|
+ for u_idx, unet_cell in enumerate(unet_cells):
|
|
|
|
|
+ if unet_matched[u_idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
best_match_idx = -1
|
|
best_match_idx = -1
|
|
|
best_iou = 0.0
|
|
best_iou = 0.0
|
|
|
|
|
|
|
|
# 查找最佳匹配的 RT-DETR 单元格
|
|
# 查找最佳匹配的 RT-DETR 单元格
|
|
|
- for idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
|
|
- if rtdetr_matched[idx]:
|
|
|
|
|
|
|
+ for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
|
|
+ if rtdetr_matched[rt_idx]:
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
iou = CoordinateUtils.calculate_iou(unet_cell, rtdetr_cell)
|
|
iou = CoordinateUtils.calculate_iou(unet_cell, rtdetr_cell)
|
|
|
if iou > best_iou:
|
|
if iou > best_iou:
|
|
|
best_iou = iou
|
|
best_iou = iou
|
|
|
- best_match_idx = idx
|
|
|
|
|
|
|
+ best_match_idx = rt_idx
|
|
|
|
|
|
|
|
# 判断是否合并
|
|
# 判断是否合并
|
|
|
if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
|
|
if best_match_idx >= 0 and best_iou >= self.iou_merge_threshold:
|
|
@@ -257,19 +346,28 @@ class CellFusionEngine:
|
|
|
self.rtdetr_weight
|
|
self.rtdetr_weight
|
|
|
)
|
|
)
|
|
|
fused_cells.append(merged_cell)
|
|
fused_cells.append(merged_cell)
|
|
|
|
|
+ cell_labels.append('merged_1to1') # 标记为1:1融合
|
|
|
rtdetr_matched[best_match_idx] = True
|
|
rtdetr_matched[best_match_idx] = True
|
|
|
|
|
+ unet_matched[u_idx] = True
|
|
|
stats['merged'] += 1
|
|
stats['merged'] += 1
|
|
|
else:
|
|
else:
|
|
|
# UNet 独有:保留
|
|
# UNet 独有:保留
|
|
|
fused_cells.append(unet_cell)
|
|
fused_cells.append(unet_cell)
|
|
|
|
|
+ cell_labels.append('unet_only') # 标记为UNet独有
|
|
|
|
|
+ unet_matched[u_idx] = True
|
|
|
|
|
|
|
|
- # Step 2: 补充 RT-DETR 独有的高置信度单元格
|
|
|
|
|
|
|
+ # Step 3: 补充 RT-DETR 独有的高置信度单元格
|
|
|
for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
|
|
for idx, (rtdetr_cell, score) in enumerate(zip(rtdetr_cells, rtdetr_scores)):
|
|
|
if not rtdetr_matched[idx] and score > 0.7:
|
|
if not rtdetr_matched[idx] and score > 0.7:
|
|
|
fused_cells.append(rtdetr_cell)
|
|
fused_cells.append(rtdetr_cell)
|
|
|
|
|
+ cell_labels.append('rtdetr_only') # 标记为RT-DETR独有
|
|
|
stats['added'] += 1
|
|
stats['added'] += 1
|
|
|
|
|
|
|
|
- return fused_cells, stats
|
|
|
|
|
|
|
+ return fused_cells, stats, cell_labels
|
|
|
|
|
+
|
|
|
|
|
+ def _calc_bbox_area(self, bbox: List[float]) -> float:
|
|
|
|
|
+ """计算bbox面积"""
|
|
|
|
|
+ return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
|
|
|
|
|
|
|
def _weighted_merge_bbox(
|
|
def _weighted_merge_bbox(
|
|
|
self,
|
|
self,
|
|
@@ -345,9 +443,10 @@ class CellFusionEngine:
|
|
|
def _compensate_with_ocr(
|
|
def _compensate_with_ocr(
|
|
|
self,
|
|
self,
|
|
|
cells: List[List[float]],
|
|
cells: List[List[float]],
|
|
|
|
|
+ cell_labels: List[str],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
table_size: Tuple[int, int]
|
|
table_size: Tuple[int, int]
|
|
|
- ) -> Tuple[List[List[float]], int]:
|
|
|
|
|
|
|
+ ) -> Tuple[List[List[float]], List[str], int]:
|
|
|
"""
|
|
"""
|
|
|
使用 OCR 补偿遗漏的单元格
|
|
使用 OCR 补偿遗漏的单元格
|
|
|
|
|
|
|
@@ -355,13 +454,15 @@ class CellFusionEngine:
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
cells: 现有单元格列表
|
|
cells: 现有单元格列表
|
|
|
|
|
+ cell_labels: 单元格标签列表
|
|
|
ocr_boxes: OCR结果列表
|
|
ocr_boxes: OCR结果列表
|
|
|
table_size: 表格尺寸 (width, height)
|
|
table_size: 表格尺寸 (width, height)
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- (compensated_cells, compensation_count)
|
|
|
|
|
|
|
+ (compensated_cells, compensated_labels, compensation_count)
|
|
|
"""
|
|
"""
|
|
|
compensated = cells.copy()
|
|
compensated = cells.copy()
|
|
|
|
|
+ compensated_labels = cell_labels.copy()
|
|
|
compensation_count = 0
|
|
compensation_count = 0
|
|
|
w, h = table_size
|
|
w, h = table_size
|
|
|
|
|
|
|
@@ -391,26 +492,27 @@ class CellFusionEngine:
|
|
|
# 扩展 OCR bbox 作为新单元格
|
|
# 扩展 OCR bbox 作为新单元格
|
|
|
if len(ocr_bbox) == 8:
|
|
if len(ocr_bbox) == 8:
|
|
|
new_cell = [
|
|
new_cell = [
|
|
|
- max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5),
|
|
|
|
|
- max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5),
|
|
|
|
|
- min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5),
|
|
|
|
|
- min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5)
|
|
|
|
|
|
|
+ float(max(0, min(ocr_bbox[0], ocr_bbox[6]) - 5)),
|
|
|
|
|
+ float(max(0, min(ocr_bbox[1], ocr_bbox[3]) - 5)),
|
|
|
|
|
+ float(min(w, max(ocr_bbox[2], ocr_bbox[4]) + 5)),
|
|
|
|
|
+ float(min(h, max(ocr_bbox[5], ocr_bbox[7]) + 5))
|
|
|
]
|
|
]
|
|
|
else:
|
|
else:
|
|
|
new_cell = [
|
|
new_cell = [
|
|
|
- max(0, ocr_bbox[0] - 5),
|
|
|
|
|
- max(0, ocr_bbox[1] - 5),
|
|
|
|
|
- min(w, ocr_bbox[2] + 5),
|
|
|
|
|
- min(h, ocr_bbox[3] + 5)
|
|
|
|
|
|
|
+ float(max(0, ocr_bbox[0] - 5)),
|
|
|
|
|
+ float(max(0, ocr_bbox[1] - 5)),
|
|
|
|
|
+ float(min(w, ocr_bbox[2] + 5)),
|
|
|
|
|
+ float(min(h, ocr_bbox[3] + 5))
|
|
|
]
|
|
]
|
|
|
|
|
|
|
|
compensated.append(new_cell)
|
|
compensated.append(new_cell)
|
|
|
|
|
+ compensated_labels.append('new') # 标记为新增(OCR补偿)
|
|
|
compensation_count += 1
|
|
compensation_count += 1
|
|
|
|
|
|
|
|
if compensation_count > 0:
|
|
if compensation_count > 0:
|
|
|
logger.debug(f"OCR compensation: added {compensation_count} cells")
|
|
logger.debug(f"OCR compensation: added {compensation_count} cells")
|
|
|
|
|
|
|
|
- return compensated, compensation_count
|
|
|
|
|
|
|
+ return compensated, compensated_labels, compensation_count
|
|
|
|
|
|
|
|
def _visualize_fusion(
|
|
def _visualize_fusion(
|
|
|
self,
|
|
self,
|
|
@@ -418,10 +520,11 @@ class CellFusionEngine:
|
|
|
unet_cells: List[List[float]],
|
|
unet_cells: List[List[float]],
|
|
|
rtdetr_cells: List[List[float]],
|
|
rtdetr_cells: List[List[float]],
|
|
|
fused_cells: List[List[float]],
|
|
fused_cells: List[List[float]],
|
|
|
|
|
+ cell_labels: List[str],
|
|
|
debug_dir: str,
|
|
debug_dir: str,
|
|
|
debug_prefix: str
|
|
debug_prefix: str
|
|
|
):
|
|
):
|
|
|
- """可视化融合结果(调试用)"""
|
|
|
|
|
|
|
+ """可视化融合结果(调试用)- 增强版:用颜色区分不同来源的单元格"""
|
|
|
try:
|
|
try:
|
|
|
import cv2
|
|
import cv2
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
@@ -429,35 +532,96 @@ class CellFusionEngine:
|
|
|
output_dir = Path(debug_dir)
|
|
output_dir = Path(debug_dir)
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
- # 创建三栏对比图
|
|
|
|
|
h, w = table_image.shape[:2]
|
|
h, w = table_image.shape[:2]
|
|
|
- vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
|
|
|
|
|
|
|
|
|
|
- # 左栏:UNet
|
|
|
|
|
|
|
+ # === 图1:UNet原始结果 ===
|
|
|
img1 = table_image.copy()
|
|
img1 = table_image.copy()
|
|
|
|
|
+ if len(img1.shape) == 2:
|
|
|
|
|
+ img1 = cv2.cvtColor(img1, cv2.COLOR_GRAY2BGR)
|
|
|
for cell in unet_cells:
|
|
for cell in unet_cells:
|
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
- cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
|
|
|
|
|
+ cv2.rectangle(img1, (x1, y1), (x2, y2), (0, 255, 0), 2) # 绿色
|
|
|
cv2.putText(img1, f"UNet ({len(unet_cells)})", (10, 30),
|
|
cv2.putText(img1, f"UNet ({len(unet_cells)})", (10, 30),
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
|
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
|
|
|
|
|
|
|
- # 中栏:RT-DETR
|
|
|
|
|
|
|
+ # === 图2:RT-DETR原始结果 ===
|
|
|
img2 = table_image.copy()
|
|
img2 = table_image.copy()
|
|
|
|
|
+ if len(img2.shape) == 2:
|
|
|
|
|
+ img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
|
|
|
for cell in rtdetr_cells:
|
|
for cell in rtdetr_cells:
|
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
- cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2)
|
|
|
|
|
|
|
+ cv2.rectangle(img2, (x1, y1), (x2, y2), (255, 0, 0), 2) # 蓝色
|
|
|
cv2.putText(img2, f"RT-DETR ({len(rtdetr_cells)})", (10, 30),
|
|
cv2.putText(img2, f"RT-DETR ({len(rtdetr_cells)})", (10, 30),
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
|
|
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
|
|
|
|
|
|
|
|
- # 右栏:融合结果
|
|
|
|
|
|
|
+ # === 图3:融合结果(按来源分类)===
|
|
|
img3 = table_image.copy()
|
|
img3 = table_image.copy()
|
|
|
- for cell in fused_cells:
|
|
|
|
|
|
|
+ if len(img3.shape) == 2:
|
|
|
|
|
+ img3 = cv2.cvtColor(img3, cv2.COLOR_GRAY2BGR)
|
|
|
|
|
+
|
|
|
|
|
+ # 根据标签分类单元格(使用 _fuse_cells 中记录的标签)
|
|
|
|
|
+ unet_only = [] # UNet独有(绿色)
|
|
|
|
|
+ rtdetr_only = [] # RT-DETR独有(蓝色)
|
|
|
|
|
+ merged_cells_1to1 = [] # 1:1融合单元格(黄色)
|
|
|
|
|
+ merged_cells_span = [] # 合并单元格(品红色,RT-DETR检测的跨格单元格)
|
|
|
|
|
+ new_cells = [] # 新增单元格(紫色)
|
|
|
|
|
+
|
|
|
|
|
+ for fused_cell, label in zip(fused_cells, cell_labels):
|
|
|
|
|
+ if label == 'unet_only':
|
|
|
|
|
+ unet_only.append(fused_cell)
|
|
|
|
|
+ elif label == 'rtdetr_only':
|
|
|
|
|
+ rtdetr_only.append(fused_cell)
|
|
|
|
|
+ elif label == 'merged_1to1':
|
|
|
|
|
+ merged_cells_1to1.append(fused_cell)
|
|
|
|
|
+ elif label == 'merged_span':
|
|
|
|
|
+ merged_cells_span.append(fused_cell)
|
|
|
|
|
+ elif label == 'new':
|
|
|
|
|
+ new_cells.append(fused_cell)
|
|
|
|
|
+
|
|
|
|
|
+ # 绘制不同类型的单元格
|
|
|
|
|
+ for cell in unet_only:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 0), 2) # 绿色 - UNet独有
|
|
|
|
|
+
|
|
|
|
|
+ for cell in rtdetr_only:
|
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
- cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 2)
|
|
|
|
|
- cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, 30),
|
|
|
|
|
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
|
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 0, 0), 2) # 蓝色 - RT-DETR独有
|
|
|
|
|
|
|
|
- # 拼接
|
|
|
|
|
|
|
+ for cell in merged_cells_1to1:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (0, 255, 255), 3) # 黄色 - 1:1融合(加粗)
|
|
|
|
|
+
|
|
|
|
|
+ for cell in merged_cells_span:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 0, 255), 4) # 品红色 - 合并单元格(加粗)
|
|
|
|
|
+
|
|
|
|
|
+ for cell in new_cells:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2) # 紫色 - 新增
|
|
|
|
|
+
|
|
|
|
|
+ # 添加图例
|
|
|
|
|
+ legend_y = 30
|
|
|
|
|
+ cv2.putText(img3, f"Fused ({len(fused_cells)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
|
|
|
|
|
+ legend_y += 35
|
|
|
|
|
+ cv2.putText(img3, f"Green: UNet-only ({len(unet_only)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
|
|
|
|
+ legend_y += 30
|
|
|
|
|
+ cv2.putText(img3, f"Blue: RTDETR-only ({len(rtdetr_only)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
|
|
|
|
|
+ legend_y += 30
|
|
|
|
|
+ cv2.putText(img3, f"Yellow: 1:1 Merged ({len(merged_cells_1to1)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
|
|
|
|
+ legend_y += 30
|
|
|
|
|
+ cv2.putText(img3, f"Magenta: Span Cells ({len(merged_cells_span)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 255), 2)
|
|
|
|
|
+ if new_cells:
|
|
|
|
|
+ legend_y += 30
|
|
|
|
|
+ cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 拼接三栏对比
|
|
|
|
|
+ vis_canvas = np.zeros((h, w * 3, 3), dtype=np.uint8)
|
|
|
vis_canvas[:, :w] = img1
|
|
vis_canvas[:, :w] = img1
|
|
|
vis_canvas[:, w:2*w] = img2
|
|
vis_canvas[:, w:2*w] = img2
|
|
|
vis_canvas[:, 2*w:] = img3
|
|
vis_canvas[:, 2*w:] = img3
|
|
@@ -465,7 +629,9 @@ class CellFusionEngine:
|
|
|
# 保存
|
|
# 保存
|
|
|
output_path = output_dir / f"{debug_prefix}_fusion_comparison.png"
|
|
output_path = output_dir / f"{debug_prefix}_fusion_comparison.png"
|
|
|
cv2.imwrite(str(output_path), vis_canvas)
|
|
cv2.imwrite(str(output_path), vis_canvas)
|
|
|
- logger.debug(f"💾 Fusion visualization saved: {output_path}")
|
|
|
|
|
|
|
+ logger.info(f"💾 融合可视化已保存: {output_path}")
|
|
|
|
|
+ logger.info(f" 📊 单元格分类: UNet独有={len(unet_only)}, RT-DETR独有={len(rtdetr_only)}, "
|
|
|
|
|
+ f"1:1融合={len(merged_cells_1to1)}, 合并单元格={len(merged_cells_span)}, 新增={len(new_cells)}")
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.warning(f"Failed to visualize fusion: {e}")
|
|
logger.warning(f"Failed to visualize fusion: {e}")
|