|
@@ -57,11 +57,22 @@ class CellFusionEngine:
|
|
|
self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
|
|
self.rtdetr_conf_threshold = self.config.get('rtdetr_conf_threshold', 0.5)
|
|
|
self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
|
|
self.enable_ocr_compensation = self.config.get('enable_ocr_compensation', True)
|
|
|
self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
|
|
self.enable_boundary_noise_filter = self.config.get('enable_boundary_noise_filter', True)
|
|
|
|
|
+ self.unet_split_min_count = self.config.get('unet_split_min_count', 2)
|
|
|
|
|
+ self.rtdetr_split_cover_threshold = self.config.get('rtdetr_split_cover_threshold', 0.5)
|
|
|
|
|
+ self.unet_split_cover_threshold = self.config.get('unet_split_cover_threshold', 0.5)
|
|
|
|
|
+ self.unet_split_rtdetr_score_threshold = self.config.get(
|
|
|
|
|
+ 'unet_split_rtdetr_score_threshold',
|
|
|
|
|
+ self.rtdetr_conf_threshold
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
logger.info(f"🔧 CellFusionEngine initialized: "
|
|
logger.info(f"🔧 CellFusionEngine initialized: "
|
|
|
- f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
|
|
|
|
|
- f"iou_merge={self.iou_merge_threshold}, ocr_comp={self.enable_ocr_compensation}, "
|
|
|
|
|
- f"boundary_filter={self.enable_boundary_noise_filter}")
|
|
|
|
|
|
|
+ f"unet_w={self.unet_weight}, rtdetr_w={self.rtdetr_weight}, "
|
|
|
|
|
+ f"iou_merge={self.iou_merge_threshold}, ocr_comp={self.enable_ocr_compensation}, "
|
|
|
|
|
+ f"boundary_filter={self.enable_boundary_noise_filter}, "
|
|
|
|
|
+ f"unet_split_min={self.unet_split_min_count}, "
|
|
|
|
|
+ f"unet_split_cover={self.unet_split_cover_threshold}, "
|
|
|
|
|
+ f"unet_split_score={self.unet_split_rtdetr_score_threshold}, "
|
|
|
|
|
+ f"rtdetr_split_cover={self.rtdetr_split_cover_threshold}")
|
|
|
|
|
|
|
|
def should_use_rtdetr(
|
|
def should_use_rtdetr(
|
|
|
self,
|
|
self,
|
|
@@ -99,6 +110,7 @@ class CellFusionEngine:
|
|
|
table_image: np.ndarray,
|
|
table_image: np.ndarray,
|
|
|
unet_cells: List[List[float]],
|
|
unet_cells: List[List[float]],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
|
|
+ ocr_text_pixel_tolerance: float = 10.0,
|
|
|
pdf_type: str = 'ocr',
|
|
pdf_type: str = 'ocr',
|
|
|
debug_dir: Optional[str] = None,
|
|
debug_dir: Optional[str] = None,
|
|
|
debug_prefix: str = "fusion"
|
|
debug_prefix: str = "fusion"
|
|
@@ -110,6 +122,7 @@ class CellFusionEngine:
|
|
|
table_image: 表格图像(原图坐标系)
|
|
table_image: 表格图像(原图坐标系)
|
|
|
unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
|
|
unet_cells: UNet检测的单元格列表 [[x1,y1,x2,y2], ...](原图坐标系)
|
|
|
ocr_boxes: OCR结果列表
|
|
ocr_boxes: OCR结果列表
|
|
|
|
|
+ ocr_text_pixel_tolerance: OCR文本容差(原图坐标系,默认10.0)
|
|
|
pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
pdf_type: PDF类型 ('txt' 或 'ocr')
|
|
|
debug_dir: 调试输出目录(可选)
|
|
debug_dir: 调试输出目录(可选)
|
|
|
debug_prefix: 调试文件前缀
|
|
debug_prefix: 调试文件前缀
|
|
@@ -126,7 +139,7 @@ class CellFusionEngine:
|
|
|
max(unet_cells, key=lambda box: box[2])[2], \
|
|
max(unet_cells, key=lambda box: box[2])[2], \
|
|
|
max(unet_cells, key=lambda box: box[3])[3]
|
|
max(unet_cells, key=lambda box: box[3])[3]
|
|
|
] if unet_cells else [0,0,0,0]
|
|
] if unet_cells else [0,0,0,0]
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
# 决策:是否使用 RT-DETR
|
|
# 决策:是否使用 RT-DETR
|
|
|
use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
|
|
use_rtdetr = self.should_use_rtdetr(pdf_type, len(unet_cells), (w, h))
|
|
|
|
|
|
|
@@ -165,8 +178,8 @@ class CellFusionEngine:
|
|
|
table_image,
|
|
table_image,
|
|
|
conf_threshold=self.rtdetr_conf_threshold
|
|
conf_threshold=self.rtdetr_conf_threshold
|
|
|
)
|
|
)
|
|
|
- # rtdetr_result从上到下,从左到右排序
|
|
|
|
|
- rtdetr_results.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
|
|
|
|
|
|
|
+ # rtdetr_result从上到下,从左到右排序, 排序按取整后,容差为10
|
|
|
|
|
+ rtdetr_results = sorted(rtdetr_results, key=lambda x: (round(x['bbox'][1] / 10), round(x['bbox'][0])))
|
|
|
rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
rtdetr_cells = [res['bbox'] for res in rtdetr_results]
|
|
|
rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
rtdetr_scores = [res['score'] for res in rtdetr_results]
|
|
|
fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
fusion_stats['rtdetr_count'] = len(rtdetr_cells)
|
|
@@ -179,7 +192,7 @@ class CellFusionEngine:
|
|
|
|
|
|
|
|
# Phase 2: 智能融合
|
|
# Phase 2: 智能融合
|
|
|
# 使用稳健边界估计(避免单个超大单元格撑开边界)
|
|
# 使用稳健边界估计(避免单个超大单元格撑开边界)
|
|
|
- table_bbox = self._estimate_robust_table_bbox(rtdetr_cells)
|
|
|
|
|
|
|
+ table_bbox = self._estimate_robust_table_bbox(rtdetr_cells, ocr_text_pixel_tolerance)
|
|
|
|
|
|
|
|
# 将所有单元格的边界限制在表格边界内
|
|
# 将所有单元格的边界限制在表格边界内
|
|
|
# rtdetr_cells = self._clip_cells_to_bbox(rtdetr_cells, table_bbox)
|
|
# rtdetr_cells = self._clip_cells_to_bbox(rtdetr_cells, table_bbox)
|
|
@@ -190,6 +203,7 @@ class CellFusionEngine:
|
|
|
fusion_stats['merged_count'] = merge_stats['merged']
|
|
fusion_stats['merged_count'] = merge_stats['merged']
|
|
|
fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
|
|
fusion_stats['merged_cells_count'] = merge_stats['merged_cells']
|
|
|
fusion_stats['added_count'] = merge_stats['added']
|
|
fusion_stats['added_count'] = merge_stats['added']
|
|
|
|
|
+ fusion_stats['split_count'] = merge_stats.get('split', 0)
|
|
|
|
|
|
|
|
# Phase 3: NMS 去重
|
|
# Phase 3: NMS 去重
|
|
|
fused_cells, suppressed = self._nms_filter(fused_cells, self.iou_nms_threshold)
|
|
fused_cells, suppressed = self._nms_filter(fused_cells, self.iou_nms_threshold)
|
|
@@ -199,7 +213,8 @@ class CellFusionEngine:
|
|
|
# Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
|
|
# Phase 4: 边界噪声过滤(过滤掉边界的 unet_only 噪声单元格)
|
|
|
if self.enable_boundary_noise_filter:
|
|
if self.enable_boundary_noise_filter:
|
|
|
fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
|
|
fused_cells, cell_labels, noise_filtered = self._filter_boundary_noise(
|
|
|
- fused_cells, cell_labels, ocr_boxes, table_bbox
|
|
|
|
|
|
|
+ fused_cells, cell_labels, ocr_boxes, table_bbox,
|
|
|
|
|
+ boundary_tolerance=ocr_text_pixel_tolerance
|
|
|
)
|
|
)
|
|
|
fusion_stats['noise_filtered_count'] = noise_filtered
|
|
fusion_stats['noise_filtered_count'] = noise_filtered
|
|
|
else:
|
|
else:
|
|
@@ -220,7 +235,7 @@ class CellFusionEngine:
|
|
|
logger.info(
|
|
logger.info(
|
|
|
f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
f"📊 Fusion (UNet+RT-DETR): UNet={len(unet_cells)}, RT-DETR={len(rtdetr_cells)}, "
|
|
|
f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
|
|
f"1:1Merged={merge_stats['merged']}, MergedCells={merge_stats['merged_cells']}, "
|
|
|
- f"Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, "
|
|
|
|
|
|
|
+ f"Split={merge_stats.get('split', 0)}, Added={merge_stats['added']}, NoiseFiltered={noise_filtered}, "
|
|
|
f"OCRCompensated={fusion_stats.get('ocr_compensated_count', 0)}, Final={len(fused_cells)}"
|
|
f"OCRCompensated={fusion_stats.get('ocr_compensated_count', 0)}, Final={len(fused_cells)}"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -243,13 +258,14 @@ class CellFusionEngine:
|
|
|
"""
|
|
"""
|
|
|
融合 UNet 和 RT-DETR 检测结果(增强版:支持合并单元格检测)
|
|
融合 UNet 和 RT-DETR 检测结果(增强版:支持合并单元格检测)
|
|
|
|
|
|
|
|
- 融合规则:
|
|
|
|
|
- 1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
|
|
|
|
|
- - 判断RT-DETR单元格包含多少个UNet单元格
|
|
|
|
|
- - 使用中心点+包含率判断(而非IoU)
|
|
|
|
|
- 2. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
|
|
|
|
|
- 3. RT-DETR 独有 + 高置信度 (>0.7) → 补充
|
|
|
|
|
- 4. UNet 独有 → 保留
|
|
|
|
|
|
|
+ 融合规则:
|
|
|
|
|
+ 1. 检测RT-DETR的合并单元格(一对多匹配,基于包含关系)
|
|
|
|
|
+ - 判断RT-DETR单元格包含多少个UNet单元格
|
|
|
|
|
+ - 使用中心点+包含率判断(而非IoU)
|
|
|
|
|
+ 2. 检测UNet过度合并(一个UNet包含多个RT-DETR)并拆分
|
|
|
|
|
+ 3. UNet + RT-DETR 高IoU (>threshold) → 加权平均合并(一对一)
|
|
|
|
|
+ 4. RT-DETR 独有 + 高置信度 (>0.7) → 补充
|
|
|
|
|
+ 5. UNet 独有 → 保留
|
|
|
|
|
|
|
|
包含关系判断逻辑:
|
|
包含关系判断逻辑:
|
|
|
- UNet单元格的中心点在RT-DETR内
|
|
- UNet单元格的中心点在RT-DETR内
|
|
@@ -267,14 +283,14 @@ class CellFusionEngine:
|
|
|
(fused_cells, stats, cell_labels)
|
|
(fused_cells, stats, cell_labels)
|
|
|
- fused_cells: 融合后的单元格
|
|
- fused_cells: 融合后的单元格
|
|
|
- stats: {'merged': int, 'added': int, 'merged_cells': int}
|
|
- stats: {'merged': int, 'added': int, 'merged_cells': int}
|
|
|
- - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'new']
|
|
|
|
|
|
|
+ - cell_labels: 每个单元格的来源标签列表 ['merged_span', 'merged_1to1', 'unet_only', 'rtdetr_only', 'split_rtdetr', 'new']
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
fused_cells = []
|
|
fused_cells = []
|
|
|
cell_labels = [] # 记录每个单元格的来源标签
|
|
cell_labels = [] # 记录每个单元格的来源标签
|
|
|
unet_matched = [False] * len(unet_cells)
|
|
unet_matched = [False] * len(unet_cells)
|
|
|
rtdetr_matched = [False] * len(rtdetr_cells)
|
|
rtdetr_matched = [False] * len(rtdetr_cells)
|
|
|
- stats = {'merged': 0, 'added': 0, 'merged_cells': 0}
|
|
|
|
|
|
|
+ stats = {'merged': 0, 'added': 0, 'merged_cells': 0, 'split': 0}
|
|
|
|
|
|
|
|
# Step 1: 检测RT-DETR的合并单元格(一对多匹配)
|
|
# Step 1: 检测RT-DETR的合并单元格(一对多匹配)
|
|
|
# 遍历RT-DETR单元格,查找被包含的多个UNet单元格
|
|
# 遍历RT-DETR单元格,查找被包含的多个UNet单元格
|
|
@@ -332,7 +348,7 @@ class CellFusionEngine:
|
|
|
coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
|
|
coverage = min(total_unet_area / rtdetr_area, 1.0) if rtdetr_area > 0 else 0
|
|
|
|
|
|
|
|
# 如果覆盖率>50%,说明这是一个真实的合并单元格
|
|
# 如果覆盖率>50%,说明这是一个真实的合并单元格
|
|
|
- if coverage > 0.5:
|
|
|
|
|
|
|
+ if coverage > self.rtdetr_split_cover_threshold:
|
|
|
# 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过table_bbox范围
|
|
# 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过table_bbox范围
|
|
|
fused_cell = [
|
|
fused_cell = [
|
|
|
min(bounding_x1, rtdetr_cell[0]),
|
|
min(bounding_x1, rtdetr_cell[0]),
|
|
@@ -342,9 +358,9 @@ class CellFusionEngine:
|
|
|
]
|
|
]
|
|
|
# x限制在table_bbox范围内
|
|
# x限制在table_bbox范围内
|
|
|
fused_cell[0] = max(fused_cell[0], table_bbox[0])
|
|
fused_cell[0] = max(fused_cell[0], table_bbox[0])
|
|
|
- # fused_cell[1] = max(fused_cell[1], table_bbox[1])
|
|
|
|
|
|
|
+ fused_cell[1] = max(fused_cell[1], table_bbox[1])
|
|
|
fused_cell[2] = min(fused_cell[2], table_bbox[2])
|
|
fused_cell[2] = min(fused_cell[2], table_bbox[2])
|
|
|
- # fused_cell[3] = min(fused_cell[3], table_bbox[3])
|
|
|
|
|
|
|
+ fused_cell[3] = min(fused_cell[3], table_bbox[3])
|
|
|
fused_cells.append(fused_cell)
|
|
fused_cells.append(fused_cell)
|
|
|
cell_labels.append('merged_span') # 标记为合并单元格
|
|
cell_labels.append('merged_span') # 标记为合并单元格
|
|
|
rtdetr_matched[rt_idx] = True
|
|
rtdetr_matched[rt_idx] = True
|
|
@@ -357,6 +373,80 @@ class CellFusionEngine:
|
|
|
f"(coverage={coverage:.2f}, score={rtdetr_scores[rt_idx]:.2f})"
|
|
f"(coverage={coverage:.2f}, score={rtdetr_scores[rt_idx]:.2f})"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ # Step 1.5: 检测UNet过度合并(一个UNet包含多个RT-DETR)并拆分
|
|
|
|
|
+ for u_idx, unet_cell in enumerate(unet_cells):
|
|
|
|
|
+ if unet_matched[u_idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ unet_area = self._calc_bbox_area(unet_cell)
|
|
|
|
|
+ if unet_area <= 0:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ contained_rtdetr = []
|
|
|
|
|
+ contained_intersects = []
|
|
|
|
|
+
|
|
|
|
|
+ for rt_idx, rtdetr_cell in enumerate(rtdetr_cells):
|
|
|
|
|
+ if rtdetr_matched[rt_idx]:
|
|
|
|
|
+ continue
|
|
|
|
|
+ if rtdetr_scores[rt_idx] < self.unet_split_rtdetr_score_threshold:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ rt_cx = (rtdetr_cell[0] + rtdetr_cell[2]) / 2
|
|
|
|
|
+ rt_cy = (rtdetr_cell[1] + rtdetr_cell[3]) / 2
|
|
|
|
|
+ if not (unet_cell[0] <= rt_cx <= unet_cell[2] and
|
|
|
|
|
+ unet_cell[1] <= rt_cy <= unet_cell[3]):
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ intersect_x1 = max(unet_cell[0], rtdetr_cell[0])
|
|
|
|
|
+ intersect_y1 = max(unet_cell[1], rtdetr_cell[1])
|
|
|
|
|
+ intersect_x2 = min(unet_cell[2], rtdetr_cell[2])
|
|
|
|
|
+ intersect_y2 = min(unet_cell[3], rtdetr_cell[3])
|
|
|
|
|
+ if intersect_x2 <= intersect_x1 or intersect_y2 <= intersect_y1:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ intersect_area = (intersect_x2 - intersect_x1) * (intersect_y2 - intersect_y1)
|
|
|
|
|
+ rtdetr_area = self._calc_bbox_area(rtdetr_cell)
|
|
|
|
|
+ contain_ratio = intersect_area / rtdetr_area if rtdetr_area > 0 else 0
|
|
|
|
|
+ if contain_ratio > 0.5:
|
|
|
|
|
+ contained_rtdetr.append(rt_idx)
|
|
|
|
|
+ contained_intersects.append(intersect_area)
|
|
|
|
|
+
|
|
|
|
|
+ if len(contained_rtdetr) >= self.unet_split_min_count:
|
|
|
|
|
+ # 计算总包含率:使用所有被包含RT-DETR单元格的外接矩形面积 vs UNet面积
|
|
|
|
|
+ # 与RT-DETR合并逻辑保持一致,避免相邻框重复/间隙导致覆盖率失真
|
|
|
|
|
+ rt_indices = contained_rtdetr
|
|
|
|
|
+ bounding_x1 = min(rtdetr_cells[i][0] for i in rt_indices)
|
|
|
|
|
+ bounding_y1 = min(rtdetr_cells[i][1] for i in rt_indices)
|
|
|
|
|
+ bounding_x2 = max(rtdetr_cells[i][2] for i in rt_indices)
|
|
|
|
|
+ bounding_y2 = max(rtdetr_cells[i][3] for i in rt_indices)
|
|
|
|
|
+ total_rtdetr_area = (bounding_x2 - bounding_x1) * (bounding_y2 - bounding_y1)
|
|
|
|
|
+ coverage = min(total_rtdetr_area / unet_area, 1.0)
|
|
|
|
|
+ if coverage >= self.unet_split_cover_threshold:
|
|
|
|
|
+ # 认定为合并单元格,取bounding与RT-DETR的最大范围, 且不能超过table_bbox范围
|
|
|
|
|
+ split_cell = [
|
|
|
|
|
+ min(bounding_x1, unet_cell[0]),
|
|
|
|
|
+ min(bounding_y1, unet_cell[1]),
|
|
|
|
|
+ max(bounding_x2, unet_cell[2]),
|
|
|
|
|
+ max(bounding_y2, unet_cell[3])
|
|
|
|
|
+ ]
|
|
|
|
|
+ split_cell = [
|
|
|
|
|
+ max(split_cell[0], table_bbox[0]),
|
|
|
|
|
+ max(split_cell[1], table_bbox[1]),
|
|
|
|
|
+ min(split_cell[2], table_bbox[2]),
|
|
|
|
|
+ min(split_cell[3], table_bbox[3])
|
|
|
|
|
+ ]
|
|
|
|
|
+ fused_cells.append(split_cell)
|
|
|
|
|
+ cell_labels.append('split_rtdetr')
|
|
|
|
|
+ for rt_idx in contained_rtdetr:
|
|
|
|
|
+ rtdetr_matched[rt_idx] = True
|
|
|
|
|
+
|
|
|
|
|
+ unet_matched[u_idx] = True
|
|
|
|
|
+ stats['split'] += len(contained_rtdetr)
|
|
|
|
|
+ logger.debug(
|
|
|
|
|
+ f"🧩 UNet过度合并拆分: UNet[{u_idx}] -> {len(contained_rtdetr)} RT-DETR "
|
|
|
|
|
+ f"(coverage={coverage:.2f})"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
# Step 2: 一对一匹配(处理剩余的单元格)
|
|
# Step 2: 一对一匹配(处理剩余的单元格)
|
|
|
for u_idx, unet_cell in enumerate(unet_cells):
|
|
for u_idx, unet_cell in enumerate(unet_cells):
|
|
|
if unet_matched[u_idx]:
|
|
if unet_matched[u_idx]:
|
|
@@ -401,9 +491,9 @@ class CellFusionEngine:
|
|
|
if not rtdetr_matched[idx] and score > 0.7:
|
|
if not rtdetr_matched[idx] and score > 0.7:
|
|
|
# rtdetr_cell不能超出table_bbox范围, x方向分别限制
|
|
# rtdetr_cell不能超出table_bbox范围, x方向分别限制
|
|
|
rtdetr_cell[0] = max(rtdetr_cell[0], table_bbox[0])
|
|
rtdetr_cell[0] = max(rtdetr_cell[0], table_bbox[0])
|
|
|
- # rtdetr_cell[1] = max(rtdetr_cell[1], table_bbox[1])
|
|
|
|
|
|
|
+ rtdetr_cell[1] = max(rtdetr_cell[1], table_bbox[1])
|
|
|
rtdetr_cell[2] = min(rtdetr_cell[2], table_bbox[2])
|
|
rtdetr_cell[2] = min(rtdetr_cell[2], table_bbox[2])
|
|
|
- # rtdetr_cell[3] = min(rtdetr_cell[3], table_bbox[3])
|
|
|
|
|
|
|
+ rtdetr_cell[3] = min(rtdetr_cell[3], table_bbox[3])
|
|
|
fused_cells.append(rtdetr_cell)
|
|
fused_cells.append(rtdetr_cell)
|
|
|
cell_labels.append('rtdetr_only') # 标记为RT-DETR独有
|
|
cell_labels.append('rtdetr_only') # 标记为RT-DETR独有
|
|
|
stats['added'] += 1
|
|
stats['added'] += 1
|
|
@@ -418,18 +508,16 @@ class CellFusionEngine:
|
|
|
"""
|
|
"""
|
|
|
稳健的表格边界估计
|
|
稳健的表格边界估计
|
|
|
|
|
|
|
|
- 使用聚类方法找到"主流"的左右边界,避免单个超大单元格撑开边界。
|
|
|
|
|
|
|
+ 使用聚类方法找到"主流"的边界,避免单个超大单元格撑开边界。
|
|
|
|
|
|
|
|
算法:
|
|
算法:
|
|
|
- 1. 收集所有单元格的左边界x1和右边界x2
|
|
|
|
|
- 2. 对x1聚类,选择支持度最高的聚类中心作为表格左边界
|
|
|
|
|
- 3. 对x2聚类,选择支持度最高的聚类中心作为表格右边界
|
|
|
|
|
- 4. y方向使用简单的min/max(行高变化大,不适合聚类)
|
|
|
|
|
|
|
+ 1. 收集所有单元格的边界
|
|
|
|
|
+ 2. 聚类,选择支持度最高的聚类中心作为表格边界
|
|
|
|
|
+ 3. 通过容差向内调整边界,过滤掉过于宽松的边界(可能包含噪声单元格)
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
rtdetr_cells: RT-DETR单元格列表
|
|
rtdetr_cells: RT-DETR单元格列表
|
|
|
cluster_tolerance: 聚类容差(像素)
|
|
cluster_tolerance: 聚类容差(像素)
|
|
|
-
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
table_bbox: [x1, y1, x2, y2]
|
|
table_bbox: [x1, y1, x2, y2]
|
|
|
"""
|
|
"""
|
|
@@ -448,11 +536,13 @@ class CellFusionEngine:
|
|
|
# 对x2聚类,找主流右边界
|
|
# 对x2聚类,找主流右边界
|
|
|
robust_x2 = self._find_dominant_boundary(x2_coords, cluster_tolerance, mode='max')
|
|
robust_x2 = self._find_dominant_boundary(x2_coords, cluster_tolerance, mode='max')
|
|
|
# y方向直接取极值
|
|
# y方向直接取极值
|
|
|
- robust_y1 = min(y1_coords)
|
|
|
|
|
- robust_y2 = max(y2_coords)
|
|
|
|
|
|
|
+ robust_y1 = self._find_dominant_boundary(y1_coords, cluster_tolerance, mode='min')
|
|
|
|
|
+ robust_y2 = self._find_dominant_boundary(y2_coords, cluster_tolerance, mode='max')
|
|
|
|
|
|
|
|
logger.debug(f"📐 稳健边界估计: x=[{robust_x1:.1f}, {robust_x2:.1f}], "
|
|
logger.debug(f"📐 稳健边界估计: x=[{robust_x1:.1f}, {robust_x2:.1f}], "
|
|
|
- f"原始x范围=[{min(x1_coords):.1f}, {max(x2_coords):.1f}]")
|
|
|
|
|
|
|
+ f"原始x范围=[{min(x1_coords):.1f}, {max(x2_coords):.1f}]"
|
|
|
|
|
+ f" | y=[{robust_y1:.1f}, {robust_y2:.1f}], "
|
|
|
|
|
+ f"原始y范围=[{min(y1_coords):.1f}, {max(y2_coords):.1f}]")
|
|
|
|
|
|
|
|
return [robust_x1, robust_y1, robust_x2, robust_y2]
|
|
return [robust_x1, robust_y1, robust_x2, robust_y2]
|
|
|
|
|
|
|
@@ -624,7 +714,8 @@ class CellFusionEngine:
|
|
|
cells: List[List[float]],
|
|
cells: List[List[float]],
|
|
|
cell_labels: List[str],
|
|
cell_labels: List[str],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
- rtdetr_bbox: List[float]
|
|
|
|
|
|
|
+ rtdetr_bbox: List[float],
|
|
|
|
|
+ boundary_tolerance: float = 0.0
|
|
|
) -> Tuple[List[List[float]], List[str], int]:
|
|
) -> Tuple[List[List[float]], List[str], int]:
|
|
|
"""
|
|
"""
|
|
|
过滤边界噪声单元格
|
|
过滤边界噪声单元格
|
|
@@ -639,6 +730,7 @@ class CellFusionEngine:
|
|
|
cell_labels: 单元格标签列表
|
|
cell_labels: 单元格标签列表
|
|
|
ocr_boxes: OCR结果列表
|
|
ocr_boxes: OCR结果列表
|
|
|
rtdetr_bbox: RT-DETR单元格的边界框 [x1, y1, x2, y2]
|
|
rtdetr_bbox: RT-DETR单元格的边界框 [x1, y1, x2, y2]
|
|
|
|
|
+ boundary_tolerance: 边界判定容忍范围(像素,原图坐标系)
|
|
|
Returns:
|
|
Returns:
|
|
|
(filtered_cells, filtered_labels, filtered_count)
|
|
(filtered_cells, filtered_labels, filtered_count)
|
|
|
"""
|
|
"""
|
|
@@ -646,6 +738,8 @@ class CellFusionEngine:
|
|
|
filtered_labels = []
|
|
filtered_labels = []
|
|
|
filtered_count = 0
|
|
filtered_count = 0
|
|
|
|
|
|
|
|
|
|
+ tol = max(0.0, boundary_tolerance)
|
|
|
|
|
+
|
|
|
for cell, label in zip(cells, cell_labels):
|
|
for cell, label in zip(cells, cell_labels):
|
|
|
# # 只过滤 unet_only 标记的单元格
|
|
# # 只过滤 unet_only 标记的单元格
|
|
|
# if label != 'unet_only':
|
|
# if label != 'unet_only':
|
|
@@ -655,9 +749,9 @@ class CellFusionEngine:
|
|
|
|
|
|
|
|
x1, y1, x2, y2 = cell
|
|
x1, y1, x2, y2 = cell
|
|
|
|
|
|
|
|
- # 检查是否在边界
|
|
|
|
|
- is_left_boundary = x1 <= rtdetr_bbox[0]
|
|
|
|
|
- is_right_boundary = x2 >= rtdetr_bbox[2]
|
|
|
|
|
|
|
+ # 检查是否在边界(加入容忍范围,避免贴边被误判)
|
|
|
|
|
+ is_left_boundary = x1 <= (rtdetr_bbox[0] - tol)
|
|
|
|
|
+ is_right_boundary = x2 >= (rtdetr_bbox[2] + tol)
|
|
|
is_on_boundary = is_left_boundary or is_right_boundary
|
|
is_on_boundary = is_left_boundary or is_right_boundary
|
|
|
|
|
|
|
|
if not is_on_boundary:
|
|
if not is_on_boundary:
|
|
@@ -906,6 +1000,7 @@ class CellFusionEngine:
|
|
|
merged_cells_1to1 = [] # 1:1融合单元格(黄色)
|
|
merged_cells_1to1 = [] # 1:1融合单元格(黄色)
|
|
|
merged_cells_span = [] # 合并单元格(品红色,RT-DETR检测的跨格单元格)
|
|
merged_cells_span = [] # 合并单元格(品红色,RT-DETR检测的跨格单元格)
|
|
|
new_cells = [] # 新增单元格(紫色)
|
|
new_cells = [] # 新增单元格(紫色)
|
|
|
|
|
+ split_cells = [] # UNet拆分得到的RT-DETR单元格(青色)
|
|
|
ocr_compensated = [] # OCR补偿单元格(橙色)
|
|
ocr_compensated = [] # OCR补偿单元格(橙色)
|
|
|
|
|
|
|
|
for fused_cell, label in zip(fused_cells, cell_labels):
|
|
for fused_cell, label in zip(fused_cells, cell_labels):
|
|
@@ -919,6 +1014,8 @@ class CellFusionEngine:
|
|
|
merged_cells_span.append(fused_cell)
|
|
merged_cells_span.append(fused_cell)
|
|
|
elif label == 'new':
|
|
elif label == 'new':
|
|
|
new_cells.append(fused_cell)
|
|
new_cells.append(fused_cell)
|
|
|
|
|
+ elif label == 'split_rtdetr':
|
|
|
|
|
+ split_cells.append(fused_cell)
|
|
|
elif label == 'ocr_compensated':
|
|
elif label == 'ocr_compensated':
|
|
|
ocr_compensated.append(fused_cell)
|
|
ocr_compensated.append(fused_cell)
|
|
|
|
|
|
|
@@ -942,6 +1039,10 @@ class CellFusionEngine:
|
|
|
for cell in new_cells:
|
|
for cell in new_cells:
|
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2) # 紫色 - 新增
|
|
cv2.rectangle(img3, (x1, y1), (x2, y2), (128, 0, 128), 2) # 紫色 - 新增
|
|
|
|
|
+
|
|
|
|
|
+ for cell in split_cells:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(v) for v in cell]
|
|
|
|
|
+ cv2.rectangle(img3, (x1, y1), (x2, y2), (255, 255, 0), 3) # 青色 - UNet拆分
|
|
|
|
|
|
|
|
for cell in ocr_compensated:
|
|
for cell in ocr_compensated:
|
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
x1, y1, x2, y2 = [int(v) for v in cell]
|
|
@@ -967,6 +1068,10 @@ class CellFusionEngine:
|
|
|
legend_y += 30
|
|
legend_y += 30
|
|
|
cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
|
|
cv2.putText(img3, f"Purple: New ({len(new_cells)})", (10, legend_y),
|
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (128, 0, 128), 2)
|
|
|
|
|
+ if split_cells:
|
|
|
|
|
+ legend_y += 30
|
|
|
|
|
+ cv2.putText(img3, f"Cyan: Split ({len(split_cells)})", (10, legend_y),
|
|
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
|
|
|
if ocr_compensated:
|
|
if ocr_compensated:
|
|
|
legend_y += 30
|
|
legend_y += 30
|
|
|
cv2.putText(img3, f"Orange: OCR Compensated ({len(ocr_compensated)})", (10, legend_y),
|
|
cv2.putText(img3, f"Orange: OCR Compensated ({len(ocr_compensated)})", (10, legend_y),
|