cell121_sweep.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #!/usr/bin/env python3
  2. """cell121 参数扫描:去水印方式 / threshold / contrast / upscale / det 阈值 / 整格 rec。"""
  3. from __future__ import annotations
  4. import json
  5. import os
  6. import sys
  7. from itertools import product
  8. from pathlib import Path
  9. from typing import Any, Dict, List, Optional, Tuple
  10. import cv2
  11. import numpy as np
# Make the project packages importable when this file is run directly as a
# script: take the directory two levels above this file's own folder
# (presumably the repository root -- confirm against the repo layout) and put
# it at the front of sys.path if it is not already listed.
_repo_root = Path(__file__).resolve().parents[2]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
  15. from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
  16. from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
# Input image for the sweep: one PNG given as a hard-coded absolute path.
# NOTE(review): the path is machine-specific; main() raises FileNotFoundError
# if it does not exist.
CELL121 = Path(
    "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/"
    "bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/"
    "彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png"
)
# Output directory (next to this script) for the preprocessed images and the
# JSON report written by main().
OUT_DIR = Path(__file__).parent / "output/彭_广东兴宁农村商业银行/cell121_sweep"
# Directory containing the .pth weight files that _make_engine() loads.
MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)
# Digit string searched for in the OCR output; a result containing it is
# recorded as a "full" match.
TARGET = "20240927"
  28. def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
  29. h, w = img.shape[:2]
  30. if h >= min_side and w >= min_side:
  31. return img
  32. s = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
  33. return cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
  34. def _preprocess(
  35. raw: np.ndarray,
  36. *,
  37. method: str,
  38. thresh: Optional[int],
  39. contrast: bool,
  40. upscale: int,
  41. ) -> np.ndarray:
  42. user: Dict[str, Any] = {"enabled": True, "method": method}
  43. if method == "threshold" and thresh is not None:
  44. user["threshold"] = thresh
  45. cfg = merge_watermark_config("cell", user)
  46. img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
  47. if contrast:
  48. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  49. ce = dict(cfg.get("contrast_enhancement") or {})
  50. ce["enabled"] = True
  51. ce["text_black_target"] = 88
  52. gray = apply_contrast_enhancement_config(gray, ce)
  53. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  54. return _upscale(img, upscale)
  55. def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
  56. try:
  57. res = engine.ocr(img, det=det, rec=rec)
  58. texts: List[str] = []
  59. if res and res[0]:
  60. if det:
  61. for item in res[0]:
  62. if item and len(item) >= 2 and item[1]:
  63. texts.append(str(item[1][0] or ""))
  64. else:
  65. for item in res[0]:
  66. if isinstance(item, (list, tuple)) and len(item) >= 1:
  67. texts.append(str(item[0] or ""))
  68. text = "".join(texts).strip()
  69. return {
  70. "text": text,
  71. "det": det,
  72. "rec": rec,
  73. "n_boxes": len(res[0]) if res and res[0] else 0,
  74. }
  75. except Exception as e:
  76. return {"text": "", "error": str(e), "det": det, "rec": rec}
  77. def _make_engine(det_thresh: float) -> Any:
  78. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  79. return PytorchPaddleOCR(
  80. lang="ch",
  81. det_model_path=str(MODEL_DIR / "ch_PP-OCRv5_det_infer.pth"),
  82. rec_model_path=str(MODEL_DIR / "ch_PP-OCRv4_rec_server_doc_infer.pth"),
  83. det_db_box_thresh=det_thresh,
  84. )
  85. def main() -> None:
  86. if not CELL121.is_file():
  87. raise FileNotFoundError(CELL121)
  88. raw = cv2.imread(str(CELL121))
  89. OUT_DIR.mkdir(parents=True, exist_ok=True)
  90. methods = ["threshold", "masked_adaptive"]
  91. thresholds = [155, 165, 170, 175, 180, None]
  92. contrasts = [False, True]
  93. upscales = [64, 96, 128, 192]
  94. det_threshs = [0.2, 0.3, 0.4, 0.5]
  95. ocr_modes = [("det_rec", True, True), ("whole_rec", False, True)]
  96. results: List[Dict[str, Any]] = []
  97. hits: List[Dict[str, Any]] = []
  98. engines: Dict[float, Any] = {}
  99. total = 0
  100. for method, thresh, contrast, upscale, det_th in product(
  101. methods, thresholds, contrasts, upscales, det_threshs
  102. ):
  103. if method != "threshold" and thresh is not None:
  104. continue
  105. if det_th not in engines:
  106. print(f"加载 OCR det_db_box_thresh={det_th} ...")
  107. engines[det_th] = _make_engine(det_th)
  108. img = _preprocess(
  109. raw, method=method, thresh=thresh, contrast=contrast, upscale=upscale
  110. )
  111. tag = (
  112. f"{method}_t{thresh or 'd'}_c{int(contrast)}_u{upscale}_det{det_th}"
  113. )
  114. cv2.imwrite(str(OUT_DIR / f"{tag}.png"), img)
  115. for mode_name, det, rec in ocr_modes:
  116. total += 1
  117. ocr = _ocr(engines[det_th], img, det=det, rec=rec)
  118. row = {
  119. "tag": tag,
  120. "method": method,
  121. "threshold": thresh,
  122. "contrast": contrast,
  123. "upscale": upscale,
  124. "det_db_box_thresh": det_th,
  125. "ocr_mode": mode_name,
  126. **ocr,
  127. }
  128. results.append(row)
  129. t = row.get("text", "")
  130. if TARGET in t or (len(t) >= 6 and t.isdigit()):
  131. row["match"] = "full" if TARGET in t else "partial"
  132. hits.append(row)
  133. print(f"HIT [{row['match']}] {mode_name} {tag} -> {t!r}")
  134. # 原图对照
  135. for det_th in [0.3, 0.5]:
  136. if det_th not in engines:
  137. engines[det_th] = _make_engine(det_th)
  138. for mode_name, det, rec in ocr_modes:
  139. ocr = _ocr(engines[det_th], _upscale(raw, 128), det=det, rec=rec)
  140. row = {
  141. "tag": "raw_upscale128",
  142. "det_db_box_thresh": det_th,
  143. "ocr_mode": mode_name,
  144. **ocr,
  145. }
  146. results.append(row)
  147. if TARGET in (row.get("text") or ""):
  148. hits.append(row)
  149. report = {
  150. "input": str(CELL121),
  151. "target": TARGET,
  152. "total_trials": total,
  153. "hits": hits,
  154. "all_results": results,
  155. }
  156. out_json = OUT_DIR / "cell121_sweep_report.json"
  157. out_json.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
  158. print(f"\n完成 {total} 次 OCR 试验,命中 {len(hits)} 条")
  159. print(f"报告: {out_json}")
  160. if hits:
  161. print("\n最佳命中:")
  162. for h in hits[:10]:
  163. print(f" {h.get('ocr_mode')} {h.get('tag')}: {h.get('text')!r}")
  164. else:
  165. print("未出现完整 20240927,请查看 cell121_sweep/*.png 与 report 中 partial 结果")
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()