|
@@ -4,6 +4,7 @@
|
|
|
提供表格单元格文本填充功能,包括OCR文本匹配和二次OCR填充。
|
|
提供表格单元格文本填充功能,包括OCR文本匹配和二次OCR填充。
|
|
|
"""
|
|
"""
|
|
|
from typing import List, Dict, Any, Tuple, Optional
|
|
from typing import List, Dict, Any, Tuple, Optional
|
|
|
|
|
+import bisect
|
|
|
import cv2
|
|
import cv2
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
@@ -26,17 +27,58 @@ class TextFiller:
|
|
|
self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
|
|
self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
|
|
|
self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.5)
|
|
self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.5)
|
|
|
|
|
|
|
|
|
|
@staticmethod
def calculate_overlap_ratio(ocr_bbox: List[float], cell_bbox: List[float]) -> float:
    """Return the fraction of the OCR box's area that lies inside the cell.

    The ratio is ``intersection area / OCR box area``. It tells how much of
    the OCR box falls within the cell, which is used to decide whether the
    OCR box mainly belongs to that cell.

    Args:
        ocr_bbox: OCR box coordinates ``[x1, y1, x2, y2]``.
        cell_bbox: Cell coordinates ``[x1, y1, x2, y2]``.

    Returns:
        The overlap ratio; ``0.0`` when either box is missing, has fewer
        than four coordinates, the boxes do not overlap with positive area,
        or the OCR box area is not positive.
    """
    boxes = (ocr_bbox, cell_bbox)
    # Emptiness is checked for both boxes before any length check, so a
    # missing box short-circuits before len() is ever called.
    if any(not box for box in boxes) or any(len(box) < 4 for box in boxes):
        return 0.0

    # Edges of the intersection rectangle.
    left = max(ocr_bbox[0], cell_bbox[0])
    top = max(ocr_bbox[1], cell_bbox[1])
    right = min(ocr_bbox[2], cell_bbox[2])
    bottom = min(ocr_bbox[3], cell_bbox[3])

    # Touching edges or disjoint boxes give an empty intersection.
    if right <= left or bottom <= top:
        return 0.0

    overlap_area = (right - left) * (bottom - top)
    ocr_width = ocr_bbox[2] - ocr_bbox[0]
    ocr_height = ocr_bbox[3] - ocr_bbox[1]
    ocr_area = ocr_width * ocr_height

    # Guard against a degenerate OCR box before dividing.
    return 0.0 if ocr_area <= 0 else overlap_area / ocr_area
|
|
|
|
|
+
|
|
|
def fill_text_by_center_point(
|
|
def fill_text_by_center_point(
|
|
|
self,
|
|
self,
|
|
|
bboxes: List[List[float]],
|
|
bboxes: List[List[float]],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
) -> Tuple[List[str], List[float], List[List[Dict[str, Any]]], List[int]]:
|
|
) -> Tuple[List[str], List[float], List[List[Dict[str, Any]]], List[int]]:
|
|
|
"""
|
|
"""
|
|
|
- 使用中心点落格策略填充文本。
|
|
|
|
|
|
|
+ 使用混合匹配策略填充文本:中心点 + 重叠比例。
|
|
|
|
|
+
|
|
|
|
|
+ 策略说明:
|
|
|
|
|
+ 1. 首先用中心点快速筛选:OCR box 的中心点在单元格内
|
|
|
|
|
+ 2. 然后检查重叠比例:OCR box 与单元格的重叠面积 / OCR box 面积 >= 0.5
|
|
|
|
|
+ (这确保 OCR box 主要属于该单元格,避免跨单元格匹配)
|
|
|
|
|
+ 3. 如果多个单元格都满足条件,选择重叠比例最高的
|
|
|
|
|
|
|
|
- 参考 fill_html_with_ocr_by_bbox:
|
|
|
|
|
- - OCR文本中心点落入单元格bbox内则匹配
|
|
|
|
|
- - 多行文本按y坐标排序拼接
|
|
|
|
|
|
|
+ 优点:
|
|
|
|
|
+ - 比纯 IOU 更宽松,能匹配到更多 OCR box
|
|
|
|
|
+ - 比纯中心点更准确,能过滤跨单元格的 OCR box
|
|
|
|
|
+ - 适合表格场景,OCR box 通常比单元格小或部分重叠
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
bboxes: 单元格坐标 [[x1,y1,x2,y2], ...]
|
|
bboxes: 单元格坐标 [[x1,y1,x2,y2], ...]
|
|
@@ -45,7 +87,7 @@ class TextFiller:
|
|
|
Returns:
|
|
Returns:
|
|
|
每个单元格的文本列表
|
|
每个单元格的文本列表
|
|
|
每个单元格的置信度列表
|
|
每个单元格的置信度列表
|
|
|
- 每个单元格匹配到的 OCR boxes 列表
|
|
|
|
|
|
|
+ 每个单元格匹配到的 OCR boxes 列表(已过滤跨单元格的 OCR box)
|
|
|
需要二次 OCR 的单元格索引列表(OCR box 跨多个单元格或过大)
|
|
需要二次 OCR 的单元格索引列表(OCR box 跨多个单元格或过大)
|
|
|
"""
|
|
"""
|
|
|
texts: List[str] = ["" for _ in bboxes]
|
|
texts: List[str] = ["" for _ in bboxes]
|
|
@@ -56,61 +98,96 @@ class TextFiller:
|
|
|
if not ocr_boxes:
|
|
if not ocr_boxes:
|
|
|
return texts, scores, matched_boxes_list, need_reocr_indices
|
|
return texts, scores, matched_boxes_list, need_reocr_indices
|
|
|
|
|
|
|
|
- # 预处理OCR结果:计算中心点
|
|
|
|
|
|
|
+ # 预处理OCR结果:转换为 bbox 格式,并计算中心点
|
|
|
ocr_items: List[Dict[str, Any]] = []
|
|
ocr_items: List[Dict[str, Any]] = []
|
|
|
for item in ocr_boxes:
|
|
for item in ocr_boxes:
|
|
|
# 使用 CoordinateUtils.poly_to_bbox() 替换 _normalize_bbox()
|
|
# 使用 CoordinateUtils.poly_to_bbox() 替换 _normalize_bbox()
|
|
|
box = CoordinateUtils.poly_to_bbox(item.get("bbox", []))
|
|
box = CoordinateUtils.poly_to_bbox(item.get("bbox", []))
|
|
|
- if not box:
|
|
|
|
|
|
|
+ if not box or len(box) < 4:
|
|
|
continue
|
|
continue
|
|
|
cx = (box[0] + box[2]) / 2
|
|
cx = (box[0] + box[2]) / 2
|
|
|
cy = (box[1] + box[3]) / 2
|
|
cy = (box[1] + box[3]) / 2
|
|
|
ocr_items.append({
|
|
ocr_items.append({
|
|
|
|
|
+ "bbox": box,
|
|
|
"center_x": cx,
|
|
"center_x": cx,
|
|
|
"center_y": cy,
|
|
"center_y": cy,
|
|
|
- "y1": box[1],
|
|
|
|
|
- "bbox": box, # 保存 bbox 用于跨单元格检测
|
|
|
|
|
"text": item.get("text", ""),
|
|
"text": item.get("text", ""),
|
|
|
"confidence": float(item.get("confidence", item.get("score", 1.0))),
|
|
"confidence": float(item.get("confidence", item.get("score", 1.0))),
|
|
|
"original_box": item, # 保存完整的 OCR box 对象
|
|
"original_box": item, # 保存完整的 OCR box 对象
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
- # 为每个单元格匹配OCR文本
|
|
|
|
|
- for idx, bbox in enumerate(bboxes):
|
|
|
|
|
- x1, y1, x2, y2 = bbox
|
|
|
|
|
- matched: List[Tuple[str, float, float, Dict[str, Any]]] = [] # (text, y1, score, original_box)
|
|
|
|
|
|
|
+ # 按 (y1, x1) 排序,便于后续二分查找和提前退出
|
|
|
|
|
+ # 排序只需要一次,对整体性能影响很小(O(n log n))
|
|
|
|
|
+ ocr_items.sort(key=lambda item: (item["bbox"][1], item["bbox"][0]))
|
|
|
|
|
+
|
|
|
|
|
+ # 重叠比例阈值:OCR box 与单元格的重叠面积必须 >= OCR box 面积的 50%
|
|
|
|
|
+ # 这确保 OCR box 主要属于该单元格
|
|
|
|
|
+ overlap_ratio_threshold = 0.5
|
|
|
|
|
+
|
|
|
|
|
+ # 为每个单元格匹配OCR文本(使用中心点 + 重叠比例)
|
|
|
|
|
+ # 优化:使用二分查找和提前退出机制,减少遍历次数
|
|
|
|
|
+ # 创建一个 y1 值的列表用于二分查找(兼容 Python < 3.10)
|
|
|
|
|
+ ocr_y1_list = [item["bbox"][1] for item in ocr_items]
|
|
|
|
|
+
|
|
|
|
|
+ for idx, cell_bbox in enumerate(bboxes):
|
|
|
|
|
+ cell_x1, cell_y1, cell_x2, cell_y2 = cell_bbox
|
|
|
|
|
+ matched: List[Tuple[str, float, float, float, float, Dict[str, Any]]] = [] # (text, y1, x1, overlap_ratio, score, original_box)
|
|
|
|
|
|
|
|
- for ocr in ocr_items:
|
|
|
|
|
- if x1 <= ocr["center_x"] <= x2 and y1 <= ocr["center_y"] <= y2:
|
|
|
|
|
- matched.append((ocr["text"], ocr["y1"], ocr["confidence"], ocr["original_box"]))
|
|
|
|
|
|
|
+ # 使用二分查找找到第一个 y1 >= cell_y1 的 OCR item
|
|
|
|
|
+ # 由于 ocr_items 已按 (y1, x1) 排序,可以使用 bisect_left
|
|
|
|
|
+ start_idx = bisect.bisect_left(ocr_y1_list, cell_y1)
|
|
|
|
|
+
|
|
|
|
|
+ # 关键优化:OCR box 的 y1 可能 < cell_y1,但 y2 >= cell_y1(跨越单元格上边界)
|
|
|
|
|
+ # 为了不遗漏这种情况,我们需要向前查找一些 items
|
|
|
|
|
+ # 向前查找的最大数量:假设 OCR box 最大高度不超过 100 像素(可根据实际情况调整)
|
|
|
|
|
+ max_lookback = 20 # 向前查找最多 20 个 items
|
|
|
|
|
+ actual_start_idx = max(0, start_idx - max_lookback)
|
|
|
|
|
+
|
|
|
|
|
+ # 从 actual_start_idx 开始遍历,当 y1 > cell_y2 时提前退出
|
|
|
|
|
+ for i in range(actual_start_idx, len(ocr_items)):
|
|
|
|
|
+ ocr_item = ocr_items[i]
|
|
|
|
|
+ ocr_bbox = ocr_item["bbox"]
|
|
|
|
|
+
|
|
|
|
|
+ # 提前退出:如果 y1 > cell_y2,后续的 items 都不可能在单元格内
|
|
|
|
|
+ if ocr_bbox[1] > cell_y2:
|
|
|
|
|
+ break
|
|
|
|
|
+
|
|
|
|
|
+ # 快速过滤:如果 OCR box 的 y2 < cell_y1,说明它完全在单元格上方,跳过
|
|
|
|
|
+ if ocr_bbox[3] < cell_y1:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ cx = ocr_item["center_x"]
|
|
|
|
|
+ cy = ocr_item["center_y"]
|
|
|
|
|
+
|
|
|
|
|
+ # 第一步:中心点必须在单元格内
|
|
|
|
|
+ if not (cell_x1 <= cx <= cell_x2 and cell_y1 <= cy <= cell_y2):
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 第二步:检查重叠比例(OCR box 有多少部分在单元格内)
|
|
|
|
|
+ overlap_ratio = self.calculate_overlap_ratio(ocr_bbox, cell_bbox)
|
|
|
|
|
+ if overlap_ratio >= overlap_ratio_threshold:
|
|
|
|
|
+ matched.append((
|
|
|
|
|
+ ocr_item["text"],
|
|
|
|
|
+ ocr_bbox[1], # y1 坐标
|
|
|
|
|
+ ocr_bbox[0], # 添加 x1 坐标
|
|
|
|
|
+ overlap_ratio,
|
|
|
|
|
+ ocr_item["confidence"],
|
|
|
|
|
+ ocr_item["original_box"]
|
|
|
|
|
+ ))
|
|
|
|
|
|
|
|
if matched:
|
|
if matched:
|
|
|
- # 按y坐标排序,确保多行文本顺序正确
|
|
|
|
|
- matched.sort(key=lambda x: x[1])
|
|
|
|
|
- texts[idx] = "".join([t for t, _, _, _ in matched])
|
|
|
|
|
|
|
+ # 直接按 y1 和 x1 排序,确保文本顺序正确
|
|
|
|
|
+ # y_tolerance 用于将相近的 y1 归为同一行(容差范围内视为同一行)
|
|
|
|
|
+ # 同一行内按 x1 从左到右排序
|
|
|
|
|
+ y_tolerance = 5
|
|
|
|
|
+ matched.sort(key=lambda x: (round(x[1] / y_tolerance), x[2])) # 先按 y_group,再按 x1
|
|
|
|
|
+
|
|
|
|
|
+ texts[idx] = "".join([t for t, _, _, _, _, _ in matched])
|
|
|
# 计算平均置信度
|
|
# 计算平均置信度
|
|
|
- avg_score = sum([s for _, _, s, _ in matched]) / len(matched)
|
|
|
|
|
|
|
+ avg_score = sum([s for _, _, _, _, s, _ in matched]) / len(matched)
|
|
|
scores[idx] = avg_score
|
|
scores[idx] = avg_score
|
|
|
# 保存匹配到的 OCR boxes
|
|
# 保存匹配到的 OCR boxes
|
|
|
- matched_boxes_list[idx] = [box for _, _, _, box in matched]
|
|
|
|
|
-
|
|
|
|
|
- # 检测 OCR box 是否跨多个单元格或过大
|
|
|
|
|
- for ocr_item in ocr_items:
|
|
|
|
|
- ocr_bbox = ocr_item["bbox"]
|
|
|
|
|
- # 检测是否跨多个单元格
|
|
|
|
|
- overlapping_cells = self.detect_ocr_box_spanning_cells(ocr_bbox, bboxes, overlap_threshold=0.3)
|
|
|
|
|
- if len(overlapping_cells) >= 2:
|
|
|
|
|
- # OCR box 跨多个单元格,标记所有相关单元格需要二次 OCR
|
|
|
|
|
- for cell_idx in overlapping_cells:
|
|
|
|
|
- if cell_idx not in need_reocr_indices:
|
|
|
|
|
- need_reocr_indices.append(cell_idx)
|
|
|
|
|
- logger.debug(f"检测到 OCR box 跨 {len(overlapping_cells)} 个单元格: {ocr_item['text'][:20]}...")
|
|
|
|
|
-
|
|
|
|
|
- # 检测 OCR box 是否相对于当前单元格过大
|
|
|
|
|
- if self.is_ocr_box_too_large(ocr_bbox, bbox, size_ratio_threshold=1.5):
|
|
|
|
|
- if idx not in need_reocr_indices:
|
|
|
|
|
- need_reocr_indices.append(idx)
|
|
|
|
|
- logger.debug(f"检测到 OCR box 相对于单元格过大 (单元格 {idx}): {ocr_item['text'][:20]}...")
|
|
|
|
|
|
|
+ matched_boxes_list[idx] = [box for _, _, _, _, _, box in matched]
|
|
|
else:
|
|
else:
|
|
|
scores[idx] = 0.0 # 无匹配文本,置信度为0
|
|
scores[idx] = 0.0 # 无匹配文本,置信度为0
|
|
|
|
|
|
|
@@ -189,35 +266,6 @@ class TextFiller:
|
|
|
|
|
|
|
|
return overlapping_cells
|
|
return overlapping_cells
|
|
|
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- def is_ocr_box_too_large(
|
|
|
|
|
- ocr_bbox: List[float],
|
|
|
|
|
- cell_bbox: List[float],
|
|
|
|
|
- size_ratio_threshold: float = 1.5
|
|
|
|
|
- ) -> bool:
|
|
|
|
|
- """
|
|
|
|
|
- 检测 OCR box 是否相对于单元格过大
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- ocr_bbox: OCR box 坐标 [x1, y1, x2, y2]
|
|
|
|
|
- cell_bbox: 单元格坐标 [x1, y1, x2, y2]
|
|
|
|
|
- size_ratio_threshold: 面积比阈值,如果 OCR box 面积 > 单元格面积 * 阈值,则认为过大
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 是否过大
|
|
|
|
|
- """
|
|
|
|
|
- if not ocr_bbox or len(ocr_bbox) < 4 or not cell_bbox or len(cell_bbox) < 4:
|
|
|
|
|
- return False
|
|
|
|
|
-
|
|
|
|
|
- ocr_area = (ocr_bbox[2] - ocr_bbox[0]) * (ocr_bbox[3] - ocr_bbox[1])
|
|
|
|
|
- cell_area = (cell_bbox[2] - cell_bbox[0]) * (cell_bbox[3] - cell_bbox[1])
|
|
|
|
|
-
|
|
|
|
|
- if cell_area <= 0:
|
|
|
|
|
- return False
|
|
|
|
|
-
|
|
|
|
|
- size_ratio = ocr_area / cell_area
|
|
|
|
|
- return size_ratio > size_ratio_threshold
|
|
|
|
|
-
|
|
|
|
|
def second_pass_ocr_fill(
|
|
def second_pass_ocr_fill(
|
|
|
self,
|
|
self,
|
|
|
table_image: np.ndarray,
|
|
table_image: np.ndarray,
|
|
@@ -377,6 +425,10 @@ class TextFiller:
|
|
|
|
|
|
|
|
if x2 > x1 and y2 > y1:
|
|
if x2 > x1 and y2 > y1:
|
|
|
cropped = cell_img[y1:y2, x1:x2]
|
|
cropped = cell_img[y1:y2, x1:x2]
|
|
|
|
|
+ ch, cw = cropped.shape[:2]
|
|
|
|
|
+ # 小图放大
|
|
|
|
|
+ if ch < 64 or cw < 64:
|
|
|
|
|
+ cropped = cv2.resize(cropped, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
|
|
|
if cropped.size > 0:
|
|
if cropped.size > 0:
|
|
|
rec_img_list.append(cropped)
|
|
rec_img_list.append(cropped)
|
|
|
rec_indices.append((cell_idx, box_idx))
|
|
rec_indices.append((cell_idx, box_idx))
|