#!/usr/bin/env python3 """cell121 参数扫描:去水印方式 / threshold / contrast / upscale / det 阈值 / 整格 rec。""" from __future__ import annotations import json import os import sys from itertools import product from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import cv2 import numpy as np _repo_root = Path(__file__).resolve().parents[2] if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config from ocr_utils.watermark.contrast import apply_contrast_enhancement_config CELL121 = Path( "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/" "bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/" "彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png" ) OUT_DIR = Path(__file__).parent / "output/彭_广东兴宁农村商业银行/cell121_sweep" MODEL_DIR = Path( "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/" "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch" ) TARGET = "20240927" def _upscale(img: np.ndarray, min_side: int) -> np.ndarray: h, w = img.shape[:2] if h >= min_side and w >= min_side: return img s = max(min_side / max(h, 1), min_side / max(w, 1), 1.0) return cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC) def _preprocess( raw: np.ndarray, *, method: str, thresh: Optional[int], contrast: bool, upscale: int, ) -> np.ndarray: user: Dict[str, Any] = {"enabled": True, "method": method} if method == "threshold" and thresh is not None: user["threshold"] = thresh cfg = merge_watermark_config("cell", user) img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True) if contrast: gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) ce = dict(cfg.get("contrast_enhancement") or {}) ce["enabled"] = True ce["text_black_target"] = 88 gray = apply_contrast_enhancement_config(gray, ce) img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) return _upscale(img, upscale) def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]: try: res = engine.ocr(img, det=det, rec=rec) texts: List[str] = [] if res and res[0]: if det: for item in res[0]: if item and len(item) >= 2 and item[1]: texts.append(str(item[1][0] or "")) else: for item in res[0]: if isinstance(item, (list, tuple)) and len(item) >= 1: texts.append(str(item[0] or "")) text = "".join(texts).strip() return { "text": text, "det": det, "rec": rec, "n_boxes": len(res[0]) if res and res[0] else 0, } except Exception as e: return {"text": "", "error": str(e), "det": det, "rec": rec} def _make_engine(det_thresh: float) -> Any: from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR return PytorchPaddleOCR( lang="ch", det_model_path=str(MODEL_DIR / "ch_PP-OCRv5_det_infer.pth"), rec_model_path=str(MODEL_DIR / "ch_PP-OCRv4_rec_server_doc_infer.pth"), det_db_box_thresh=det_thresh, ) def main() -> None: if not CELL121.is_file(): raise FileNotFoundError(CELL121) raw = cv2.imread(str(CELL121)) OUT_DIR.mkdir(parents=True, exist_ok=True) methods = ["threshold", "masked_adaptive"] thresholds = [155, 165, 170, 175, 180, None] contrasts = [False, True] upscales = [64, 96, 128, 192] det_threshs = [0.2, 0.3, 0.4, 0.5] ocr_modes = [("det_rec", True, True), ("whole_rec", False, True)] results: List[Dict[str, Any]] = [] hits: List[Dict[str, Any]] = [] engines: Dict[float, Any] = {} total = 0 for method, thresh, contrast, upscale, det_th in product( methods, thresholds, contrasts, upscales, det_threshs ): if method != "threshold" and thresh is not None: continue if det_th not in engines: print(f"加载 OCR det_db_box_thresh={det_th} ...") engines[det_th] = _make_engine(det_th) img = _preprocess( raw, method=method, thresh=thresh, contrast=contrast, upscale=upscale ) tag = ( f"{method}_t{thresh or 'd'}_c{int(contrast)}_u{upscale}_det{det_th}" ) cv2.imwrite(str(OUT_DIR / f"{tag}.png"), img) for mode_name, det, rec in ocr_modes: total += 1 ocr = _ocr(engines[det_th], img, det=det, rec=rec) row = { "tag": tag, "method": method, "threshold": thresh, "contrast": contrast, "upscale": upscale, "det_db_box_thresh": det_th, "ocr_mode": mode_name, **ocr, } results.append(row) t = row.get("text", "") if TARGET in t or (len(t) >= 6 and t.isdigit()): row["match"] = "full" if TARGET in t else "partial" hits.append(row) print(f"HIT [{row['match']}] {mode_name} {tag} -> {t!r}") # 原图对照 for det_th in [0.3, 0.5]: if det_th not in engines: engines[det_th] = _make_engine(det_th) for mode_name, det, rec in ocr_modes: ocr = _ocr(engines[det_th], _upscale(raw, 128), det=det, rec=rec) row = { "tag": "raw_upscale128", "det_db_box_thresh": det_th, "ocr_mode": mode_name, **ocr, } results.append(row) if TARGET in (row.get("text") or ""): hits.append(row) report = { "input": str(CELL121), "target": TARGET, "total_trials": total, "hits": hits, "all_results": results, } out_json = OUT_DIR / "cell121_sweep_report.json" out_json.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") print(f"\n完成 {total} 次 OCR 试验,命中 {len(hits)} 条") print(f"报告: {out_json}") if hits: print("\n最佳命中:") for h in hits[:10]: print(f" {h.get('ocr_mode')} {h.get('tag')}: {h.get('text')!r}") else: print("未出现完整 20240927,请查看 cell121_sweep/*.png 与 report 中 partial 结果") if __name__ == "__main__": main()