|
|
@@ -9,7 +9,6 @@ import cv2
|
|
|
import numpy as np
|
|
|
import os
|
|
|
import re
|
|
|
-from pathlib import Path
|
|
|
from loguru import logger
|
|
|
|
|
|
from ocr_utils.coordinate_utils import CoordinateUtils
|
|
|
@@ -42,7 +41,259 @@ class TextFiller:
|
|
|
self.horizontal_secondary_overlap_ratio: float = config.get(
|
|
|
"horizontal_secondary_overlap_ratio", 0.15
|
|
|
)
|
|
|
-
|
|
|
+ sp_cfg = config.get("second_pass_ocr") or {}
|
|
|
+ if not isinstance(sp_cfg, dict):
|
|
|
+ sp_cfg = {}
|
|
|
+ self.second_pass_line_min_score: float = float(sp_cfg.get("line_min_score", 0.8))
|
|
|
+ self.second_pass_drop_low: bool = bool(sp_cfg.get("drop_low_score_blocks", True))
|
|
|
+ self.second_pass_whole_fallback: bool = bool(sp_cfg.get("whole_cell_fallback", True))
|
|
|
+ self.second_pass_prefer_whole_on_tie: bool = bool(
|
|
|
+ sp_cfg.get("prefer_whole_on_tie", True)
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def sanitize_debug_filename(text: str, max_length: int = 50) -> str:
|
|
|
+ """清理调试文件名中的非法字符。"""
|
|
|
+ if not text:
|
|
|
+ return "empty"
|
|
|
+ illegal_chars = r'[/\\:*?"<>|]'
|
|
|
+ sanitized = re.sub(illegal_chars, "_", text)
|
|
|
+ if len(sanitized) > max_length:
|
|
|
+ sanitized = sanitized[:max_length]
|
|
|
+ sanitized = sanitized.strip("_").strip()
|
|
|
+ return sanitized if sanitized else "empty"
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def aggregate_line_ocr(
|
|
|
+ blocks: List[Tuple[str, float]],
|
|
|
+ *,
|
|
|
+ line_min_score: float = 0.8,
|
|
|
+ drop_low_score_blocks: bool = True,
|
|
|
+ ) -> Tuple[str, float]:
|
|
|
+ """
|
|
|
+ 合并分行 OCR 结果:可选丢弃低分块,置信度按字符数加权平均。
|
|
|
+ """
|
|
|
+ if not blocks:
|
|
|
+ return "", 0.0
|
|
|
+ kept: List[Tuple[str, float]] = []
|
|
|
+ for text, score in blocks:
|
|
|
+ t = (text or "").strip()
|
|
|
+ if not t:
|
|
|
+ continue
|
|
|
+ if drop_low_score_blocks and score < line_min_score:
|
|
|
+ continue
|
|
|
+ kept.append((t, score))
|
|
|
+ if not kept:
|
|
|
+ return "", 0.0
|
|
|
+ total_len = sum(len(t) for t, _ in kept)
|
|
|
+ if total_len <= 0:
|
|
|
+ return "", 0.0
|
|
|
+ combined = "".join(t for t, _ in kept)
|
|
|
+ weighted = sum(len(t) * s for t, s in kept) / total_len
|
|
|
+ return combined, weighted
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _det_box_center(box: Any) -> Tuple[float, float]:
|
|
|
+ if not box or len(box) < 4:
|
|
|
+ return 0.0, 0.0
|
|
|
+ if isinstance(box[0], (list, tuple)):
|
|
|
+ xs = [float(p[0]) for p in box]
|
|
|
+ ys = [float(p[1]) for p in box]
|
|
|
+ else:
|
|
|
+ xs = [float(box[i]) for i in range(0, len(box), 2)]
|
|
|
+ ys = [float(box[i]) for i in range(1, len(box), 2)]
|
|
|
+ return sum(xs) / len(xs), sum(ys) / len(ys)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _det_box_xyxy_float(box: Any) -> Optional[Tuple[float, float, float, float]]:
|
|
|
+ """检测框转 (x1, y1, x2, y2),用于阅读顺序排序。"""
|
|
|
+ if not box or len(box) < 4:
|
|
|
+ return None
|
|
|
+ if isinstance(box[0], (list, tuple)):
|
|
|
+ xs = [float(p[0]) for p in box]
|
|
|
+ ys = [float(p[1]) for p in box]
|
|
|
+ else:
|
|
|
+ xs = [float(box[i]) for i in range(0, len(box), 2)]
|
|
|
+ ys = [float(box[i]) for i in range(1, len(box), 2)]
|
|
|
+ x1, x2 = min(xs), max(xs)
|
|
|
+ y1, y2 = min(ys), max(ys)
|
|
|
+ if x2 <= x1 or y2 <= y1:
|
|
|
+ return None
|
|
|
+ return x1, y1, x2, y2
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def sort_det_boxes_reading_order(
|
|
|
+ dt_boxes: List[Any],
|
|
|
+ img_h: int,
|
|
|
+ img_w: int,
|
|
|
+ ) -> List[Any]:
|
|
|
+ """
|
|
|
+ 检测框阅读顺序:先按行(从上到下),行内从左到右。
|
|
|
+ 使用行聚类,避免仅用中心点 (cy,cx) 时同行框因高度差导致乱序。
|
|
|
+ """
|
|
|
+ if not dt_boxes:
|
|
|
+ return []
|
|
|
+ typed: List[Tuple[Any, float, float, float, float]] = []
|
|
|
+ for box in dt_boxes:
|
|
|
+ xyxy = TextFiller._det_box_xyxy_float(box)
|
|
|
+ if xyxy is None:
|
|
|
+ continue
|
|
|
+ x1, y1, x2, y2 = xyxy
|
|
|
+ typed.append((box, x1, y1, x2, y2))
|
|
|
+ if not typed:
|
|
|
+ return []
|
|
|
+ if len(typed) == 1:
|
|
|
+ return [typed[0][0]]
|
|
|
+
|
|
|
+ heights = [t[4] - t[2] for t in typed]
|
|
|
+ median_h = sorted(heights)[len(heights) // 2]
|
|
|
+ row_thresh = max(median_h * 0.5, 4.0)
|
|
|
+
|
|
|
+ # 先按行中心 y 排序,再聚类为行
|
|
|
+ typed.sort(key=lambda t: (t[2] + t[4]) * 0.5)
|
|
|
+ rows: List[List[Tuple[Any, float, float, float, float]]] = []
|
|
|
+ for item in typed:
|
|
|
+ cy = (item[2] + item[4]) * 0.5
|
|
|
+ if not rows:
|
|
|
+ rows.append([item])
|
|
|
+ continue
|
|
|
+ row_cy = sum((it[2] + it[4]) * 0.5 for it in rows[-1]) / len(rows[-1])
|
|
|
+ if abs(cy - row_cy) <= row_thresh:
|
|
|
+ rows[-1].append(item)
|
|
|
+ else:
|
|
|
+ rows.append([item])
|
|
|
+
|
|
|
+ ordered: List[Any] = []
|
|
|
+ for row in rows:
|
|
|
+ row.sort(key=lambda t: t[1]) # x1 从左到右
|
|
|
+ ordered.extend(item[0] for item in row)
|
|
|
+ return ordered
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _det_box_to_xyxy(box: Any, w: int, h: int) -> Optional[Tuple[int, int, int, int]]:
|
|
|
+ if not box or len(box) < 4:
|
|
|
+ return None
|
|
|
+ if isinstance(box[0], (list, tuple)):
|
|
|
+ xs = [p[0] for p in box]
|
|
|
+ ys = [p[1] for p in box]
|
|
|
+ else:
|
|
|
+ xs = [box[i] for i in range(0, len(box), 2)]
|
|
|
+ ys = [box[i] for i in range(1, len(box), 2)]
|
|
|
+ x1 = int(max(0, min(xs)))
|
|
|
+ y1 = int(max(0, min(ys)))
|
|
|
+ x2 = int(min(w, max(xs)))
|
|
|
+ y2 = int(min(h, max(ys)))
|
|
|
+ if x2 <= x1 or y2 <= y1:
|
|
|
+ return None
|
|
|
+ return x1, y1, x2, y2
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _parse_single_rec_item(rec_item: Any) -> Tuple[str, float]:
|
|
|
+ if rec_item is None:
|
|
|
+ return "", 0.0
|
|
|
+ if isinstance(rec_item, tuple) and len(rec_item) >= 2:
|
|
|
+ return str(rec_item[0] or "").strip(), float(rec_item[1] or 0.0)
|
|
|
+ if isinstance(rec_item, list) and len(rec_item) >= 2:
|
|
|
+ if isinstance(rec_item[0], (list, tuple, dict)):
|
|
|
+ texts_list: List[str] = []
|
|
|
+ scores_list: List[float] = []
|
|
|
+ for item in rec_item:
|
|
|
+ t, s = TextFiller._parse_single_rec_item(item)
|
|
|
+ if t:
|
|
|
+ texts_list.append(t)
|
|
|
+ scores_list.append(s)
|
|
|
+ if texts_list:
|
|
|
+ combined = "".join(texts_list)
|
|
|
+ total_len = sum(len(t) for t in texts_list)
|
|
|
+ if total_len > 0:
|
|
|
+ weighted = sum(len(t) * s for t, s in zip(texts_list, scores_list)) / total_len
|
|
|
+ return combined, weighted
|
|
|
+ return combined, sum(scores_list) / len(scores_list)
|
|
|
+ return "", 0.0
|
|
|
+ return str(rec_item[0] or "").strip(), float(rec_item[1] or 0.0)
|
|
|
+ if isinstance(rec_item, dict):
|
|
|
+ txt = str(rec_item.get("text") or rec_item.get("label") or "").strip()
|
|
|
+ sc = float(rec_item.get("score") or rec_item.get("confidence") or 0.0)
|
|
|
+ return txt, sc
|
|
|
+ return "", 0.0
|
|
|
+
|
|
|
+ def _extract_ocr_batch_results(self, rec_res: Any) -> List[Any]:
|
|
|
+ if not rec_res:
|
|
|
+ return []
|
|
|
+ if isinstance(rec_res, list) and len(rec_res) > 0:
|
|
|
+ if isinstance(rec_res[0], list):
|
|
|
+ return rec_res[0]
|
|
|
+ return rec_res
|
|
|
+ return []
|
|
|
+
|
|
|
+ def _recognize_whole_cell(self, cell_img: np.ndarray) -> Tuple[str, float]:
|
|
|
+ try:
|
|
|
+ rec_res = self.ocr_engine.ocr(cell_img, det=False, rec=True)
|
|
|
+ items = self._extract_ocr_batch_results(rec_res)
|
|
|
+ if not items:
|
|
|
+ return "", 0.0
|
|
|
+ return self._parse_single_rec_item(items[0] if len(items) == 1 else items)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"整格 OCR 失败: {e}")
|
|
|
+ return "", 0.0
|
|
|
+
|
|
|
+ def _recognize_cell_lines(self, cell_img: np.ndarray) -> List[Tuple[str, float]]:
|
|
|
+ """det 分行后逐行识别,检测框按阅读顺序(上行下、左到右)排序。"""
|
|
|
+ blocks: List[Tuple[str, float]] = []
|
|
|
+ try:
|
|
|
+ det_res = self.ocr_engine.ocr(cell_img, det=True, rec=False)
|
|
|
+ dt_boxes = []
|
|
|
+ if det_res and len(det_res) > 0:
|
|
|
+ dt_boxes = det_res[0] if det_res[0] else []
|
|
|
+ if not dt_boxes:
|
|
|
+ return blocks
|
|
|
+ h, w = cell_img.shape[:2]
|
|
|
+ sorted_boxes = self.sort_det_boxes_reading_order(dt_boxes, h, w)
|
|
|
+ rec_img_list: List[np.ndarray] = []
|
|
|
+ for box in sorted_boxes:
|
|
|
+ xyxy = self._det_box_to_xyxy(box, w, h)
|
|
|
+ if xyxy is None:
|
|
|
+ continue
|
|
|
+ x1, y1, x2, y2 = xyxy
|
|
|
+ cropped = cell_img[y1:y2, x1:x2]
|
|
|
+ if cropped.size > 0:
|
|
|
+ rec_img_list.append(cropped)
|
|
|
+ if not rec_img_list:
|
|
|
+ return blocks
|
|
|
+ rec_res = self.ocr_engine.ocr(rec_img_list, det=False, rec=True)
|
|
|
+ rec_items = self._extract_ocr_batch_results(rec_res)
|
|
|
+ for rec_item in rec_items:
|
|
|
+ text, score = self._parse_single_rec_item(rec_item)
|
|
|
+ if text:
|
|
|
+ blocks.append((text, score))
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"分行 OCR 失败: {e}")
|
|
|
+ return blocks
|
|
|
+
|
|
|
+ def _pick_line_vs_whole(
|
|
|
+ self,
|
|
|
+ line_text: str,
|
|
|
+ line_score: float,
|
|
|
+ whole_text: str,
|
|
|
+ whole_score: float,
|
|
|
+ ) -> Tuple[str, float, str]:
|
|
|
+ """返回 (text, score, strategy) strategy in lines|whole|tie_whole|tie_lines."""
|
|
|
+ if not self.second_pass_whole_fallback:
|
|
|
+ return line_text, line_score, "lines"
|
|
|
+ if not whole_text and line_text:
|
|
|
+ return line_text, line_score, "lines"
|
|
|
+ if whole_text and not line_text:
|
|
|
+ return whole_text, whole_score, "whole"
|
|
|
+ if not whole_text and not line_text:
|
|
|
+ return "", 0.0, "empty"
|
|
|
+ if line_score > whole_score:
|
|
|
+ return line_text, line_score, "lines"
|
|
|
+ if line_score < whole_score:
|
|
|
+ return whole_text, whole_score, "whole"
|
|
|
+ if self.second_pass_prefer_whole_on_tie and whole_text:
|
|
|
+ return whole_text, whole_score, "tie_whole"
|
|
|
+ return line_text, line_score, "tie_lines"
|
|
|
+
|
|
|
@staticmethod
|
|
|
def calculate_dynamic_confidence_threshold(text: str, base_threshold: float = 0.9) -> float:
|
|
|
"""
|
|
|
@@ -453,94 +704,54 @@ class TextFiller:
|
|
|
pdf_type: str = 'ocr', # 'ocr' 或 'txt'
|
|
|
force_all: bool = False,
|
|
|
output_dir: Optional[str] = None,
|
|
|
+ debug_prefix: Optional[str] = None,
|
|
|
) -> List[str]:
|
|
|
"""
|
|
|
- 二次OCR统一封装:
|
|
|
- - 对空文本单元格裁剪图块并少量外扩
|
|
|
- - 对低置信度文本进行重识别
|
|
|
- - 对竖排单元格(高宽比大)进行旋转后识别
|
|
|
- - 对 OCR 误合并的单元格进行重识别(OCR box 跨多个单元格或过大)
|
|
|
- - [New] force_all=True: 强制对所有单元格进行裁剪识别 (Full-page OCR 作为 fallback)
|
|
|
- - [New] output_dir: 输出目录,如果提供则保存单元格OCR图片用于调试
|
|
|
-
|
|
|
- Args:
|
|
|
- table_image: 表格图像
|
|
|
- bboxes: 单元格坐标列表
|
|
|
- texts: 当前文本列表
|
|
|
- scores: 当前置信度列表
|
|
|
- need_reocr_indices: 需要二次 OCR 的单元格索引列表(OCR 误合并检测结果)
|
|
|
- pdf_type: str, # 'ocr' 或 'txt'
|
|
|
- force_all: 是否强制对所有单元格进行 OCR (Default: False)
|
|
|
- output_dir: 单元格 OCR 调试目录(通常为 debug/table_recognition_wired/tablecell_ocr/)
|
|
|
+ 二次OCR:分行 det+rec(低分块丢弃、长度加权置信度)+ 整格 det=False 兜底择优。
|
|
|
+ debug 图落盘至 output_dir/{debug_prefix}/cell{idx}_{text}.png
|
|
|
"""
|
|
|
try:
|
|
|
if not self.ocr_engine:
|
|
|
return texts
|
|
|
-
|
|
|
- # 如果没有传入 scores,则默认全为 1.0(仅处理空文本)
|
|
|
+
|
|
|
if scores is None:
|
|
|
scores = [1.0 if t else 0.0 for t in texts]
|
|
|
-
|
|
|
- # 如果没有传入 need_reocr_indices,初始化为空列表
|
|
|
if need_reocr_indices is None:
|
|
|
need_reocr_indices = []
|
|
|
|
|
|
- cell_ocr_dir = None
|
|
|
+ cell_ocr_dir: Optional[str] = None
|
|
|
if output_dir:
|
|
|
cell_ocr_dir = output_dir
|
|
|
+ if debug_prefix:
|
|
|
+ safe_prefix = self.sanitize_debug_filename(debug_prefix, max_length=120)
|
|
|
+ cell_ocr_dir = os.path.join(output_dir, safe_prefix)
|
|
|
os.makedirs(cell_ocr_dir, exist_ok=True)
|
|
|
|
|
|
h_img, w_img = table_image.shape[:2]
|
|
|
margin = self.cell_crop_margin
|
|
|
-
|
|
|
- # 触发二次OCR的阈值
|
|
|
- trigger_score_thresh = 0.90
|
|
|
+ trigger_score_thresh = 0.90
|
|
|
|
|
|
crop_list: List[np.ndarray] = []
|
|
|
crop_indices: List[int] = []
|
|
|
|
|
|
- # 收集需要二次OCR的裁剪块
|
|
|
for i, t in enumerate(texts):
|
|
|
bbox = bboxes[i]
|
|
|
w_box = bbox[2] - bbox[0]
|
|
|
h_box = bbox[3] - bbox[1]
|
|
|
-
|
|
|
- # 判断是否需要二次OCR
|
|
|
+
|
|
|
need_reocr = False
|
|
|
- reocr_reason = ""
|
|
|
-
|
|
|
if force_all:
|
|
|
need_reocr = True
|
|
|
- reocr_reason = "强制全量OCR"
|
|
|
- else:
|
|
|
- # 1. OCR 误合并:OCR box 跨多个单元格或过大, 跨单元格中的一个单元格的文本可能是''
|
|
|
- if i in need_reocr_indices:
|
|
|
- need_reocr = True
|
|
|
- reocr_reason = "OCR误合并"
|
|
|
- # 2. 文本为空且置信度不是极高
|
|
|
- elif (not t or not t.strip()) and scores[i] < 0.95:
|
|
|
- if pdf_type == 'txt':
|
|
|
- # PDF文本模式下,空文本不触发二次OCR
|
|
|
- need_reocr = False
|
|
|
- else:
|
|
|
- need_reocr = True
|
|
|
- reocr_reason = "空文本"
|
|
|
- # 3. 置信度过低
|
|
|
- elif scores[i] < trigger_score_thresh:
|
|
|
- need_reocr = True
|
|
|
- reocr_reason = "低置信度"
|
|
|
- # 4. 竖排单元格 (高宽比 > 2.5) 且置信度不是极高
|
|
|
- elif h_box > w_box * 2.5 and scores[i] < 0.95:
|
|
|
- need_reocr = True
|
|
|
- reocr_reason = "竖排文本"
|
|
|
-
|
|
|
- if not need_reocr:
|
|
|
- continue
|
|
|
-
|
|
|
- # if reocr_reason:
|
|
|
- # logger.debug(f"单元格 {i} 触发二次OCR: {reocr_reason} (文本: '{t[:30]}...')")
|
|
|
-
|
|
|
- if i >= len(bboxes):
|
|
|
+ elif i in need_reocr_indices:
|
|
|
+ need_reocr = True
|
|
|
+ elif (not t or not t.strip()) and scores[i] < 0.95:
|
|
|
+ need_reocr = pdf_type != 'txt'
|
|
|
+ elif scores[i] < trigger_score_thresh:
|
|
|
+ need_reocr = True
|
|
|
+ elif h_box > w_box * 2.5 and scores[i] < 0.95:
|
|
|
+ need_reocr = True
|
|
|
+
|
|
|
+ if not need_reocr or i >= len(bboxes):
|
|
|
continue
|
|
|
|
|
|
x1, y1, x2, y2 = map(int, bboxes[i])
|
|
|
@@ -556,186 +767,64 @@ class TextFiller:
|
|
|
continue
|
|
|
|
|
|
ch, cw = cell_img.shape[:2]
|
|
|
- # 小图放大
|
|
|
if ch < 64 or cw < 64:
|
|
|
- cell_img = cv2.resize(cell_img, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
|
|
|
- ch, cw = cell_img.shape[:2]
|
|
|
- logger.debug(f"单元格({texts[i] if i < len(texts) and len(texts[i]) else 'empty'}) {i} 裁剪后图像过小,放大至 {cw}x{ch} 像素")
|
|
|
-
|
|
|
- # 竖排文本旋转为横排
|
|
|
- # 由于表格已经是正视的,不需要再考虑旋转角度
|
|
|
- # if ch > cw * 2.0:
|
|
|
- # cell_img = cv2.rotate(cell_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
|
|
+ cell_img = cv2.resize(
|
|
|
+ cell_img, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC
|
|
|
+ )
|
|
|
+ logger.debug(
|
|
|
+ f"单元格 {i} 裁剪过小,放大至 {cell_img.shape[1]}x{cell_img.shape[0]} 像素"
|
|
|
+ )
|
|
|
|
|
|
crop_list.append(cell_img)
|
|
|
crop_indices.append(i)
|
|
|
|
|
|
if not crop_list:
|
|
|
return texts
|
|
|
-
|
|
|
- logger.info(f"触发二次OCR: {len(crop_list)} 个单元格 (总数 {len(texts)})")
|
|
|
-
|
|
|
- # 先批量检测文本块,再批量识别(提高效率)
|
|
|
- # Step 1: 批量检测
|
|
|
- det_results = []
|
|
|
- for cell_img in crop_list:
|
|
|
- try:
|
|
|
- det_res = self.ocr_engine.ocr(cell_img, det=True, rec=False)
|
|
|
- if det_res and len(det_res) > 0:
|
|
|
- dt_boxes = det_res[0]
|
|
|
- det_results.append(dt_boxes if dt_boxes else [])
|
|
|
- else:
|
|
|
- det_results.append([])
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"单元格文本检测失败: {e}")
|
|
|
- det_results.append([])
|
|
|
-
|
|
|
- # Step 2: 从检测框中裁剪图像并批量识别
|
|
|
- rec_img_list = []
|
|
|
- rec_indices = []
|
|
|
- for cell_idx, dt_boxes in enumerate(det_results):
|
|
|
- if not dt_boxes:
|
|
|
- continue
|
|
|
- cell_img = crop_list[cell_idx]
|
|
|
- h, w = cell_img.shape[:2]
|
|
|
-
|
|
|
- for box_idx, box in enumerate(dt_boxes):
|
|
|
- if not box or len(box) < 4:
|
|
|
- continue
|
|
|
- # 将检测框转换为bbox格式并裁剪
|
|
|
- if isinstance(box[0], (list, tuple)):
|
|
|
- # 多边形格式
|
|
|
- xs = [p[0] for p in box]
|
|
|
- ys = [p[1] for p in box]
|
|
|
- x1, y1 = int(max(0, min(xs))), int(max(0, min(ys)))
|
|
|
- x2, y2 = int(min(w, max(xs))), int(min(h, max(ys)))
|
|
|
- else:
|
|
|
- # bbox格式
|
|
|
- xs = [box[i] for i in range(0, len(box), 2)]
|
|
|
- ys = [box[i] for i in range(1, len(box), 2)]
|
|
|
- x1, y1 = int(max(0, min(xs))), int(max(0, min(ys)))
|
|
|
- x2, y2 = int(min(w, max(xs))), int(min(h, max(ys)))
|
|
|
-
|
|
|
- if x2 > x1 and y2 > y1:
|
|
|
- cropped = cell_img[y1:y2, x1:x2]
|
|
|
- ch, cw = cropped.shape[:2]
|
|
|
- if cropped.size > 0:
|
|
|
- rec_img_list.append(cropped)
|
|
|
- rec_indices.append((cell_idx, box_idx))
|
|
|
-
|
|
|
- # Step 3: 批量识别
|
|
|
- results = [[] for _ in crop_list]
|
|
|
- if rec_img_list:
|
|
|
- try:
|
|
|
- rec_res = self.ocr_engine.ocr(rec_img_list, det=False, rec=True)
|
|
|
- if rec_res and len(rec_res) > 0:
|
|
|
- rec_results = rec_res[0] if isinstance(rec_res[0], list) else rec_res
|
|
|
- # 将识别结果回填到对应的单元格
|
|
|
- for (cell_idx, box_idx), rec_item in zip(rec_indices, rec_results):
|
|
|
- if rec_item:
|
|
|
- if isinstance(rec_item, (list, tuple)) and len(rec_item) >= 2:
|
|
|
- text = str(rec_item[0] or "").strip()
|
|
|
- score = float(rec_item[1] or 0.0)
|
|
|
- if text:
|
|
|
- results[cell_idx].append((text, score))
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"批量识别失败: {e}")
|
|
|
-
|
|
|
- # 解析为 (text, score) - 支持合并多个文本块
|
|
|
- def _parse_item(res_item) -> Tuple[str, float]:
|
|
|
- if res_item is None:
|
|
|
- return "", 0.0
|
|
|
-
|
|
|
- # 列表形式:包含多个文本块,需要合并
|
|
|
- if isinstance(res_item, list) and len(res_item) > 0:
|
|
|
- texts_list = []
|
|
|
- scores_list = []
|
|
|
-
|
|
|
- for item in res_item:
|
|
|
- if isinstance(item, tuple) and len(item) >= 2:
|
|
|
- text = str(item[0] or "").strip()
|
|
|
- score = float(item[1] or 0.0)
|
|
|
- if text:
|
|
|
- texts_list.append(text)
|
|
|
- scores_list.append(score)
|
|
|
- elif isinstance(item, list) and len(item) >= 2:
|
|
|
- text = str(item[0] or "").strip()
|
|
|
- score = float(item[1] or 0.0)
|
|
|
- if text:
|
|
|
- texts_list.append(text)
|
|
|
- scores_list.append(score)
|
|
|
- elif isinstance(item, dict):
|
|
|
- text = str(item.get("text") or item.get("label") or "").strip()
|
|
|
- score = float(item.get("score") or item.get("confidence") or 0.0)
|
|
|
- if text:
|
|
|
- texts_list.append(text)
|
|
|
- scores_list.append(score)
|
|
|
-
|
|
|
- if texts_list:
|
|
|
- combined_text = "".join(texts_list)
|
|
|
- avg_score = sum(scores_list) / len(scores_list) if scores_list else 0.0
|
|
|
- return combined_text, avg_score
|
|
|
- return "", 0.0
|
|
|
-
|
|
|
- # 直接 (text, score)
|
|
|
- if isinstance(res_item, tuple) and len(res_item) >= 2:
|
|
|
- return str(res_item[0] or ""), float(res_item[1] or 0.0)
|
|
|
-
|
|
|
- # 字典形式
|
|
|
- if isinstance(res_item, dict):
|
|
|
- txt = str(res_item.get("text") or res_item.get("label") or "")
|
|
|
- sc = float(res_item.get("score") or res_item.get("confidence") or 0.0)
|
|
|
- return txt, sc
|
|
|
-
|
|
|
- return "", 0.0
|
|
|
|
|
|
- # 对齐长度,避免越界
|
|
|
- n = min(len(results) if isinstance(results, list) else 0, len(crop_list), len(crop_indices))
|
|
|
+ logger.info(f"触发二次OCR: {len(crop_list)} 个单元格 (总数 {len(texts)})")
|
|
|
base_conf_th = self.ocr_conf_threshold
|
|
|
+ line_min = self.second_pass_line_min_score
|
|
|
+ drop_low = self.second_pass_drop_low
|
|
|
|
|
|
- # 辅助函数:清理文件名中的非法字符
|
|
|
- def sanitize_filename(text: str, max_length: int = 50) -> str:
|
|
|
- """清理文件名,移除非法字符并限制长度"""
|
|
|
- if not text:
|
|
|
- return "empty"
|
|
|
- # 替换或删除文件名中的非法字符
|
|
|
- # Windows/Linux 文件名非法字符: / \ : * ? " < > |
|
|
|
- illegal_chars = r'[/\\:*?"<>|]'
|
|
|
- sanitized = re.sub(illegal_chars, '_', text)
|
|
|
- # 限制长度
|
|
|
- if len(sanitized) > max_length:
|
|
|
- sanitized = sanitized[:max_length]
|
|
|
- # 移除首尾空格和下划线
|
|
|
- sanitized = sanitized.strip('_').strip()
|
|
|
- return sanitized if sanitized else "empty"
|
|
|
-
|
|
|
- for k in range(n):
|
|
|
- text_k, score_k = _parse_item(results[k])
|
|
|
+ for k, cell_img in enumerate(crop_list):
|
|
|
cell_idx = crop_indices[k]
|
|
|
- cell_img = crop_list[k]
|
|
|
-
|
|
|
- # 保存单元格OCR图片用于调试
|
|
|
+
|
|
|
+ line_blocks = self._recognize_cell_lines(cell_img)
|
|
|
+ line_text, line_score = self.aggregate_line_ocr(
|
|
|
+ line_blocks,
|
|
|
+ line_min_score=line_min,
|
|
|
+ drop_low_score_blocks=drop_low,
|
|
|
+ )
|
|
|
+
|
|
|
+ whole_text, whole_score = ("", 0.0)
|
|
|
+ if self.second_pass_whole_fallback and line_score < base_conf_th:
|
|
|
+ whole_text, whole_score = self._recognize_whole_cell(cell_img)
|
|
|
+
|
|
|
+ final_text, final_score, strategy = self._pick_line_vs_whole(
|
|
|
+ line_text, line_score, whole_text, whole_score
|
|
|
+ )
|
|
|
+
|
|
|
if cell_ocr_dir and cell_img is not None:
|
|
|
try:
|
|
|
- # 生成文件名:序号_识别内容
|
|
|
- sanitized_text = sanitize_filename(text_k)
|
|
|
- filename = f"{cell_idx:03d}_{sanitized_text}.png"
|
|
|
- filepath = os.path.join(cell_ocr_dir, filename)
|
|
|
- cv2.imwrite(filepath, cell_img)
|
|
|
+ tag = self.sanitize_debug_filename(final_text or "empty")
|
|
|
+ filename = f"cell{cell_idx:03d}_{strategy}_{tag}.png"
|
|
|
+ cv2.imwrite(os.path.join(cell_ocr_dir, filename), cell_img)
|
|
|
except Exception as e:
|
|
|
logger.warning(f"保存单元格OCR图片失败 (cell {cell_idx}): {e}")
|
|
|
-
|
|
|
- if text_k:
|
|
|
- # 根据文本长度动态调整置信度阈值
|
|
|
- dynamic_conf_th = self.calculate_dynamic_confidence_threshold(text_k, base_conf_th)
|
|
|
-
|
|
|
- if score_k >= dynamic_conf_th:
|
|
|
- texts[cell_idx] = text_k
|
|
|
- else:
|
|
|
- logger.debug(
|
|
|
- f"单元格 {cell_idx} 二次OCR结果置信度({score_k:.2f})低于动态阈值({dynamic_conf_th:.2f}) "
|
|
|
- f"[文本长度={len(text_k)}, 基准阈值={base_conf_th:.2f}]: '{text_k[:30]}...'"
|
|
|
- )
|
|
|
+
|
|
|
+ if not final_text:
|
|
|
+ continue
|
|
|
+
|
|
|
+ dynamic_conf_th = self.calculate_dynamic_confidence_threshold(
|
|
|
+ final_text, base_conf_th
|
|
|
+ )
|
|
|
+ if final_score >= dynamic_conf_th:
|
|
|
+ texts[cell_idx] = final_text
|
|
|
+ else:
|
|
|
+ logger.debug(
|
|
|
+ f"单元格 {cell_idx} 二次OCR({strategy}) 置信度({final_score:.2f}) "
|
|
|
+ f"低于动态阈值({dynamic_conf_th:.2f}): '{final_text[:30]}...'"
|
|
|
+ )
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.warning(f"二次OCR失败: {e}")
|