|
|
@@ -7,6 +7,9 @@ from typing import List, Dict, Any, Tuple, Optional
|
|
|
import bisect
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
+import os
|
|
|
+import re
|
|
|
+from pathlib import Path
|
|
|
from loguru import logger
|
|
|
|
|
|
from ocr_tools.universal_doc_parser.core.coordinate_utils import CoordinateUtils
|
|
|
@@ -25,7 +28,14 @@ class TextFiller:
|
|
|
"""
|
|
|
self.ocr_engine = ocr_engine
|
|
|
self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
|
|
|
- self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.5)
|
|
|
+ self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.8) # 单元格 OCR 置信度阈值
|
|
|
+
|
|
|
+ # 跨单元格检测配置参数
|
|
|
+ self.overlap_threshold_horizontal: float = config.get("overlap_threshold_horizontal", 0.2)
|
|
|
+ self.overlap_threshold_vertical: float = config.get("overlap_threshold_vertical", 0.5)
|
|
|
+ self.min_overlap_area: float = config.get("min_overlap_area", 50.0)
|
|
|
+ self.center_cell_ratio: float = config.get("center_cell_ratio", 0.5)
|
|
|
+ self.other_cell_max_ratio: float = config.get("other_cell_max_ratio", 0.3)
|
|
|
|
|
|
@staticmethod
|
|
|
def calculate_overlap_ratio(ocr_bbox: List[float], cell_bbox: List[float]) -> float:
|
|
|
@@ -61,6 +71,7 @@ class TextFiller:
|
|
|
|
|
|
return inter_area / ocr_area
|
|
|
|
|
|
+
|
|
|
def fill_text_by_center_point(
|
|
|
self,
|
|
|
bboxes: List[List[float]],
|
|
|
@@ -190,6 +201,32 @@ class TextFiller:
|
|
|
matched_boxes_list[idx] = [box for _, _, _, _, _, box in matched]
|
|
|
else:
|
|
|
scores[idx] = 0.0 # 无匹配文本,置信度为0
|
|
|
+
|
|
|
+ # 在外层统一检测 OCR box 是否跨多个单元格或过大(避免重复检测)
|
|
|
+ processed_ocr_indices = set() # 记录已处理的 OCR box 索引,避免重复检测
|
|
|
+ for ocr_idx, ocr_item in enumerate(ocr_items):
|
|
|
+ if ocr_idx in processed_ocr_indices:
|
|
|
+ continue
|
|
|
+
|
|
|
+ ocr_bbox = ocr_item["bbox"]
|
|
|
+ center_point = (ocr_item["center_x"], ocr_item["center_y"])
|
|
|
+
|
|
|
+ # 检测是否跨多个单元格(使用方向感知检测)
|
|
|
+ overlapping_cells = self.detect_ocr_box_spanning_cells(
|
|
|
+ ocr_bbox,
|
|
|
+ bboxes,
|
|
|
+ overlap_threshold=None, # 使用配置的方向感知阈值
|
|
|
+ center_point=center_point
|
|
|
+ )
|
|
|
+
|
|
|
+ if len(overlapping_cells) >= 2:
|
|
|
+ # OCR box 跨多个单元格,标记所有相关单元格需要二次 OCR
|
|
|
+ for cell_idx in overlapping_cells:
|
|
|
+ if cell_idx not in need_reocr_indices:
|
|
|
+ need_reocr_indices.append(cell_idx)
|
|
|
+ logger.debug(f"检测到 OCR box 跨 {len(overlapping_cells)} 个单元格[{', '.join(map(str, overlapping_cells))}]: {ocr_item['text'][:20]}...")
|
|
|
+
|
|
|
+ processed_ocr_indices.add(ocr_idx)
|
|
|
|
|
|
return texts, scores, matched_boxes_list, need_reocr_indices
|
|
|
|
|
|
@@ -221,50 +258,100 @@ class TextFiller:
|
|
|
y2 = max(c[3] for c in coords_list)
|
|
|
return [float(x1), float(y1), float(x2), float(y2)]
|
|
|
|
|
|
- @staticmethod
|
|
|
def detect_ocr_box_spanning_cells(
    self,
    ocr_bbox: List[float],
    cell_bboxes: List[List[float]],
    overlap_threshold: Optional[float] = None,
    center_point: Optional[Tuple[float, float]] = None
) -> List[int]:
    """
    Detect whether an OCR box spans several table cells (direction-aware).

    Args:
        ocr_bbox: OCR box coordinates [x1, y1, x2, y2]. Only the first four
            values are used; any trailing values are ignored.
        cell_bboxes: Cell coordinates [[x1, y1, x2, y2], ...]. Entries that
            are empty or shorter than four values are skipped.
        overlap_threshold: Uniform threshold on (intersection area / OCR box
            area). When None, the direction-aware rule is used instead, based
            on self.overlap_threshold_horizontal,
            self.overlap_threshold_vertical and self.min_overlap_area.
        center_point: Optional OCR box centre (cx, cy). When given, the first
            cell containing it (borders inclusive) is treated as the
            "centre cell" for the dominance rule below.

    Returns:
        Indices of the cells the OCR box overlaps. Two or more indices mean
        the box spans cells. An empty list is returned for an invalid or
        zero-area OCR box, and also when the centre cell dominates: its area
        ratio is >= self.center_cell_ratio and no other overlapping cell
        reaches self.other_cell_max_ratio.
    """
    if not ocr_bbox or len(ocr_bbox) < 4:
        return []

    # Slice before unpacking: the guard above accepts boxes with more than
    # four values, and a plain 4-way unpack would raise ValueError on them.
    ocr_x1, ocr_y1, ocr_x2, ocr_y2 = ocr_bbox[:4]
    ocr_width = ocr_x2 - ocr_x1
    ocr_height = ocr_y2 - ocr_y1
    ocr_area = ocr_width * ocr_height

    if ocr_area <= 0:
        return []

    if center_point is not None:
        cx, cy = center_point

    center_cell_idx = None
    # (cell index, intersection area / OCR box area) for every cell that
    # passes the overlap test.
    cell_overlaps: List[Tuple[int, float]] = []

    for idx, cell_bbox in enumerate(cell_bboxes):
        if not cell_bbox or len(cell_bbox) < 4:
            continue

        cell_x1, cell_y1, cell_x2, cell_y2 = cell_bbox[:4]

        # Remember the first cell that contains the OCR box centre.
        if (
            center_point is not None
            and center_cell_idx is None
            and cell_x1 <= cx <= cell_x2
            and cell_y1 <= cy <= cell_y2
        ):
            center_cell_idx = idx

        # Intersection rectangle between the OCR box and the cell.
        inter_x1 = max(ocr_x1, cell_x1)
        inter_y1 = max(ocr_y1, cell_y1)
        inter_x2 = min(ocr_x2, cell_x2)
        inter_y2 = min(ocr_y2, cell_y2)

        if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
            continue

        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
        overlap_ratio = inter_area / ocr_area

        if overlap_threshold is None:
            # Direction-aware rule: the horizontal and vertical overlap
            # fractions must each exceed their own threshold, and the
            # intersection must reach the minimum absolute area.
            h_overlap_ratio = (inter_x2 - inter_x1) / ocr_width if ocr_width > 0 else 0
            v_overlap_ratio = (inter_y2 - inter_y1) / ocr_height if ocr_height > 0 else 0
            is_overlapping = (
                h_overlap_ratio > self.overlap_threshold_horizontal
                and v_overlap_ratio > self.overlap_threshold_vertical
                and inter_area >= self.min_overlap_area
            )
        else:
            # Uniform rule: compare the area ratio with the given threshold.
            is_overlapping = overlap_ratio > overlap_threshold

        if is_overlapping:
            cell_overlaps.append((idx, overlap_ratio))

    # Dominance rule: if the centre cell holds enough of the box and no other
    # cell reaches the secondary ratio, the box is not reported as spanning.
    if center_cell_idx is not None and cell_overlaps:
        center_overlap = next(
            (ratio for idx, ratio in cell_overlaps if idx == center_cell_idx),
            None,
        )
        if center_overlap is not None and center_overlap >= self.center_cell_ratio:
            has_strong_neighbour = any(
                idx != center_cell_idx and ratio >= self.other_cell_max_ratio
                for idx, ratio in cell_overlaps
            )
            if not has_strong_neighbour:
                return []

    # Every cell that passed the overlap test, in cell order.
    return [idx for idx, _ in cell_overlaps]
|
|
|
|
|
|
def second_pass_ocr_fill(
|
|
|
self,
|
|
|
@@ -274,6 +361,7 @@ class TextFiller:
|
|
|
scores: Optional[List[float]] = None,
|
|
|
need_reocr_indices: Optional[List[int]] = None,
|
|
|
force_all: bool = False,
|
|
|
+ output_dir: Optional[str] = None,
|
|
|
) -> List[str]:
|
|
|
"""
|
|
|
二次OCR统一封装:
|
|
|
@@ -282,6 +370,7 @@ class TextFiller:
|
|
|
- 对竖排单元格(高宽比大)进行旋转后识别
|
|
|
- 对 OCR 误合并的单元格进行重识别(OCR box 跨多个单元格或过大)
|
|
|
- [New] force_all=True: 强制对所有单元格进行裁剪识别 (Full-page OCR 作为 fallback)
|
|
|
+ - [New] output_dir: 输出目录,如果提供则保存单元格OCR图片用于调试
|
|
|
|
|
|
Args:
|
|
|
table_image: 表格图像
|
|
|
@@ -290,6 +379,7 @@ class TextFiller:
|
|
|
scores: 当前置信度列表
|
|
|
need_reocr_indices: 需要二次 OCR 的单元格索引列表(OCR 误合并检测结果)
|
|
|
force_all: 是否强制对所有单元格进行 OCR (Default: False)
|
|
|
+ output_dir: 输出目录,如果提供则保存单元格OCR图片到 {output_dir}/tablecell_ocr/ 目录
|
|
|
"""
|
|
|
try:
|
|
|
if not self.ocr_engine:
|
|
|
@@ -303,6 +393,12 @@ class TextFiller:
|
|
|
if need_reocr_indices is None:
|
|
|
need_reocr_indices = []
|
|
|
|
|
|
+ # 如果提供了输出目录,创建 tablecell_ocr 子目录
|
|
|
+ cell_ocr_dir = None
|
|
|
+ if output_dir:
|
|
|
+ cell_ocr_dir = os.path.join(output_dir, "tablecell_ocr")
|
|
|
+ os.makedirs(cell_ocr_dir, exist_ok=True)
|
|
|
+
|
|
|
h_img, w_img = table_image.shape[:2]
|
|
|
margin = self.cell_crop_margin
|
|
|
|
|
|
@@ -369,6 +465,7 @@ class TextFiller:
|
|
|
if ch < 64 or cw < 64:
|
|
|
cell_img = cv2.resize(cell_img, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
|
|
|
ch, cw = cell_img.shape[:2]
|
|
|
+ logger.debug(f"单元格({texts[i] if i < len(texts) and len(texts[i]) else 'empty'}) {i} 裁剪后图像过小,放大至 {cw}x{ch} 像素")
|
|
|
|
|
|
# 竖排文本旋转为横排
|
|
|
if ch > cw * 2.0:
|
|
|
@@ -381,7 +478,7 @@ class TextFiller:
|
|
|
return texts
|
|
|
|
|
|
logger.info(f"触发二次OCR: {len(crop_list)} 个单元格 (总数 {len(texts)})")
|
|
|
-
|
|
|
+
|
|
|
# 先批量检测文本块,再批量识别(提高效率)
|
|
|
# Step 1: 批量检测
|
|
|
det_results = []
|
|
|
@@ -426,9 +523,6 @@ class TextFiller:
|
|
|
if x2 > x1 and y2 > y1:
|
|
|
cropped = cell_img[y1:y2, x1:x2]
|
|
|
ch, cw = cropped.shape[:2]
|
|
|
- # 小图放大
|
|
|
- if ch < 64 or cw < 64:
|
|
|
- cropped = cv2.resize(cropped, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
|
|
|
if cropped.size > 0:
|
|
|
rec_img_list.append(cropped)
|
|
|
rec_indices.append((cell_idx, box_idx))
|
|
|
@@ -503,10 +597,42 @@ class TextFiller:
|
|
|
n = min(len(results) if isinstance(results, list) else 0, len(crop_list), len(crop_indices))
|
|
|
conf_th = self.ocr_conf_threshold
|
|
|
|
|
|
+ # 辅助函数:清理文件名中的非法字符
|
|
|
def sanitize_filename(text: str, max_length: int = 50) -> str:
    """Turn recognised text into a safe file-name fragment.

    Characters that are illegal in Windows/Linux file names
    (/ \\ : * ? " < > |) and ASCII control characters (newlines, tabs,
    NUL, ...) are replaced with '_'. The result is truncated to
    ``max_length`` characters, then leading and trailing whitespace and
    underscores are removed.

    Args:
        text: Raw text, typically an OCR result.
        max_length: Maximum number of characters kept before trimming.

    Returns:
        The cleaned fragment, or "empty" when nothing usable remains.
    """
    if not text:
        return "empty"
    # Control characters are included because OCR text can contain newlines,
    # and a NUL byte makes path APIs raise.
    illegal_chars = r'[/\\:*?"<>|\x00-\x1f]'
    sanitized = re.sub(illegal_chars, '_', text)
    # Truncate before trimming so a cut that lands on a space or '_' is
    # cleaned up as well.
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length]
    # Strip whitespace and underscores together; chaining strip('_') and
    # strip() leaves underscores behind when they sit inside whitespace.
    sanitized = re.sub(r'^[\s_]+|[\s_]+$', '', sanitized)
    return sanitized if sanitized else "empty"
|
|
|
+
|
|
|
for k in range(n):
|
|
|
text_k, score_k = _parse_item(results[k])
|
|
|
+ cell_idx = crop_indices[k]
|
|
|
+ cell_img = crop_list[k]
|
|
|
+
|
|
|
+ # 保存单元格OCR图片用于调试
|
|
|
+ if cell_ocr_dir and cell_img is not None:
|
|
|
+ try:
|
|
|
+ # 生成文件名:序号_识别内容
|
|
|
+ sanitized_text = sanitize_filename(text_k)
|
|
|
+ filename = f"{cell_idx:03d}_{sanitized_text}.png"
|
|
|
+ filepath = os.path.join(cell_ocr_dir, filename)
|
|
|
+ cv2.imwrite(filepath, cell_img)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"保存单元格OCR图片失败 (cell {cell_idx}): {e}")
|
|
|
+
|
|
|
if text_k and score_k >= conf_th:
|
|
|
- texts[crop_indices[k]] = text_k
|
|
|
+ texts[cell_idx] = text_k
|
|
|
+ elif text_k:
|
|
|
+ logger.debug(f"单元格 {cell_idx} 二次OCR结果置信度({score_k:.2f})低于阈值({conf_th}): (文本: '{text_k[:30]}...')")
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.warning(f"二次OCR失败: {e}")
|