| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
#!/usr/bin/env python3
"""cell121 parameter sweep.

Sweeps the watermark-removal method, threshold, contrast enhancement, upscale
size, detection box threshold and whole-cell recognition mode.
"""
- from __future__ import annotations
- import json
- import os
- import sys
- from itertools import product
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- import cv2
- import numpy as np
# Make the repository root (two directories above this file's folder)
# importable so the project imports that follow resolve when this file is
# run directly as a script.
_repo_root = Path(__file__).resolve().parents[2]
_repo_root_entry = str(_repo_root)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
- from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
- from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
# Default locations point at the original author's machine. Both can be
# overridden through environment variables so the sweep can run elsewhere
# without editing this file; an unset or empty variable keeps the default.
_DEFAULT_CELL121 = (
    "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/"
    "bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/"
    "彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png"
)
_DEFAULT_MODEL_DIR = (
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)

# Table-cell crop that the sweep runs OCR on (override: CELL121_IMAGE).
CELL121 = Path(os.environ.get("CELL121_IMAGE") or _DEFAULT_CELL121)

# Directory, next to this script, receiving preprocessed images and the report.
OUT_DIR = Path(__file__).parent / "output/彭_广东兴宁农村商业银行/cell121_sweep"

# Directory holding the OCR model weight files (override: CELL121_MODEL_DIR).
MODEL_DIR = Path(os.environ.get("CELL121_MODEL_DIR") or _DEFAULT_MODEL_DIR)

# Text that a trial must contain to count as a full hit.
TARGET = "20240927"
def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
    """Enlarge ``img`` so that both its height and width reach ``min_side``.

    The image is returned unchanged (same object) when it is already large
    enough. Otherwise it is resized with a single uniform factor, using cubic
    interpolation, chosen so that the shorter side reaches ``min_side``.
    """
    height, width = img.shape[:2]
    if min(height, width) >= min_side:
        return img
    # Clamp the divisor to 1 so a zero-length side cannot divide by zero.
    shorter = max(min(height, width), 1)
    scale = max(min_side / shorter, 1.0)
    return cv2.resize(
        img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
    )
def _preprocess(
    raw: np.ndarray,
    *,
    method: str,
    thresh: Optional[int],
    contrast: bool,
    upscale: int,
) -> np.ndarray:
    """Apply watermark removal, optional contrast enhancement, then upscaling.

    Args:
        raw: Input image as loaded by OpenCV.
        method: Watermark-removal method name placed in the "cell" config.
        thresh: Explicit threshold; only forwarded when ``method`` is
            ``"threshold"`` and the value is not ``None``.
        contrast: When true, run contrast enhancement on a grayscale copy
            with ``text_black_target`` forced to 88.
        upscale: Minimum side length handed to ``_upscale``.

    Returns:
        The processed image after ``_upscale``.
    """
    overrides: Dict[str, Any] = {"enabled": True, "method": method}
    if method == "threshold" and thresh is not None:
        overrides["threshold"] = thresh
    cfg = merge_watermark_config("cell", overrides)

    processor = WatermarkProcessor(cfg, scope="cell")
    cleaned, _ = processor.process(raw, force=True)
    if not contrast:
        return _upscale(cleaned, upscale)

    # The enhancer is fed a grayscale image; the result is converted back to
    # three channels before upscaling.
    gray = cv2.cvtColor(cleaned, cv2.COLOR_BGR2GRAY)
    enhance_cfg = dict(cfg.get("contrast_enhancement") or {})
    enhance_cfg["enabled"] = True
    enhance_cfg["text_black_target"] = 88
    enhanced = apply_contrast_enhancement_config(gray, enhance_cfg)
    return _upscale(cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR), upscale)
def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
    """Run one OCR pass and flatten the recognised text into a single string.

    Only the first element of the engine result (``res[0]``) is inspected.
    With ``det`` true the text of each entry is read from ``item[1][0]``;
    with ``det`` false it is read from ``item[0]``.

    Args:
        engine: Object exposing ``ocr(img, det=..., rec=...)``.
        img: Image handed to the engine as is.
        det: Whether text detection is requested.
        rec: Whether text recognition is requested.

    Returns:
        A dict with ``text`` (concatenated, stripped), ``det``, ``rec`` and
        ``n_boxes`` (number of entries in ``res[0]``). If anything raises,
        the dict additionally carries ``error`` with the exception message,
        ``text`` is empty and ``n_boxes`` is 0, so every row has the same
        base keys.
    """
    try:
        res = engine.ocr(img, det=det, rec=rec)
        entries = res[0] if res and res[0] else []
        if det:
            texts: List[str] = [
                str(item[1][0] or "")
                for item in entries
                if item and len(item) >= 2 and item[1]
            ]
        else:
            texts = [
                str(item[0] or "")
                for item in entries
                if isinstance(item, (list, tuple)) and len(item) >= 1
            ]
        return {
            "text": "".join(texts).strip(),
            "det": det,
            "rec": rec,
            "n_boxes": len(entries),
        }
    except Exception as e:
        # Deliberate best-effort: one failing trial must not abort the sweep,
        # so the failure is recorded in the row instead of propagating.
        return {"text": "", "error": str(e), "det": det, "rec": rec, "n_boxes": 0}
def _make_engine(det_thresh: float) -> Any:
    """Build a ``PytorchPaddleOCR`` engine using the given detection box threshold.

    The detection and recognition weight files are looked up under
    ``MODEL_DIR``; ``det_thresh`` is passed as ``det_db_box_thresh``.
    """
    # Imported inside the function so the OCR package is only loaded when an
    # engine is actually constructed.
    from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR

    det_weights = MODEL_DIR / "ch_PP-OCRv5_det_infer.pth"
    rec_weights = MODEL_DIR / "ch_PP-OCRv4_rec_server_doc_infer.pth"
    return PytorchPaddleOCR(
        lang="ch",
        det_model_path=str(det_weights),
        rec_model_path=str(rec_weights),
        det_db_box_thresh=det_thresh,
    )
def main() -> None:
    """Sweep preprocessing and OCR parameters over the cell121 crop.

    For every combination of watermark-removal method, threshold, contrast
    flag, upscale size and detection box threshold, the preprocessed image is
    written to ``OUT_DIR`` and OCR is run in two modes (detection plus
    recognition, and recognition on the whole cell). A trial is a "full" hit
    when its text contains ``TARGET`` and a "partial" hit when its text is at
    least six characters long and all digits. The un-preprocessed image
    (upscaled to 128) is then run as a control, and everything is saved to
    ``cell121_sweep_report.json``. ``total_trials`` counts the sweep trials
    only, not the control runs.

    Raises:
        FileNotFoundError: If ``CELL121`` does not exist.
        OSError: If the file exists but OpenCV cannot decode it.
    """
    if not CELL121.is_file():
        raise FileNotFoundError(CELL121)
    raw = cv2.imread(str(CELL121))
    if raw is None:
        # cv2.imread reports an unreadable file by returning None, not by
        # raising; fail here with a clear message instead of deep in the sweep.
        raise OSError(f"cv2.imread could not decode image: {CELL121}")
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    methods = ["threshold", "masked_adaptive"]
    thresholds = [155, 165, 170, 175, 180, None]
    contrasts = [False, True]
    upscales = [64, 96, 128, 192]
    det_threshs = [0.2, 0.3, 0.4, 0.5]
    ocr_modes = [("det_rec", True, True), ("whole_rec", False, True)]

    results: List[Dict[str, Any]] = []
    hits: List[Dict[str, Any]] = []
    # One engine per detection threshold, created on first use.
    engines: Dict[float, Any] = {}
    total = 0

    for method, thresh, contrast, upscale, det_th in product(
        methods, thresholds, contrasts, upscales, det_threshs
    ):
        # An explicit threshold only applies to the "threshold" method.
        if method != "threshold" and thresh is not None:
            continue
        if det_th not in engines:
            print(f"加载 OCR det_db_box_thresh={det_th} ...")
            engines[det_th] = _make_engine(det_th)
        img = _preprocess(
            raw, method=method, thresh=thresh, contrast=contrast, upscale=upscale
        )
        # "d" marks the method's default threshold (thresh is None).
        thresh_tag = "d" if thresh is None else thresh
        tag = f"{method}_t{thresh_tag}_c{int(contrast)}_u{upscale}_det{det_th}"
        cv2.imwrite(str(OUT_DIR / f"{tag}.png"), img)
        for mode_name, det, rec in ocr_modes:
            total += 1
            ocr = _ocr(engines[det_th], img, det=det, rec=rec)
            row = {
                "tag": tag,
                "method": method,
                "threshold": thresh,
                "contrast": contrast,
                "upscale": upscale,
                "det_db_box_thresh": det_th,
                "ocr_mode": mode_name,
                **ocr,
            }
            results.append(row)
            t = row.get("text", "")
            if TARGET in t:
                row["match"] = "full"
            elif len(t) >= 6 and t.isdigit():
                row["match"] = "partial"
            else:
                continue
            hits.append(row)
            print(f"HIT [{row['match']}] {mode_name} {tag} -> {t!r}")

    # Control run: the raw image, only upscaled, without any preprocessing.
    for det_th in [0.3, 0.5]:
        if det_th not in engines:
            engines[det_th] = _make_engine(det_th)
        for mode_name, det, rec in ocr_modes:
            ocr = _ocr(engines[det_th], _upscale(raw, 128), det=det, rec=rec)
            row = {
                "tag": "raw_upscale128",
                "det_db_box_thresh": det_th,
                "ocr_mode": mode_name,
                **ocr,
            }
            results.append(row)
            if TARGET in (row.get("text") or ""):
                # Same "match" key as sweep hits so all hit rows share a schema.
                row["match"] = "full"
                hits.append(row)

    report = {
        "input": str(CELL121),
        "target": TARGET,
        "total_trials": total,
        "hits": hits,
        "all_results": results,
    }
    out_json = OUT_DIR / "cell121_sweep_report.json"
    out_json.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"\n完成 {total} 次 OCR 试验,命中 {len(hits)} 条")
    print(f"报告: {out_json}")

    if hits:
        print("\n最佳命中:")
        # Stable sort: full matches first, sweep order kept within each group,
        # so partial hits cannot push full hits out of the top ten.
        ranked = sorted(hits, key=lambda h: h.get("match") != "full")
        for h in ranked[:10]:
            print(f" {h.get('ocr_mode')} {h.get('tag')}: {h.get('text')!r}")
    if not any(h.get("match") == "full" for h in hits):
        print(f"未出现完整 {TARGET},请查看 cell121_sweep/*.png 与 report 中 partial 结果")


if __name__ == "__main__":
    main()
|