cell_sweep.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971
  1. #!/usr/bin/env python3
  2. """
  3. 单元格裁剪图预处理参数扫描:去水印 / contrast(clahe/gamma/linear/text_restore)/ upscale / det 阈值 / OCR 模式。
  4. 支持 contrast 在放大前/后执行两种顺序对比。
  5. 默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
  6. 用法:
  7. python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
  8. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
  9. python cell_sweep.py cell.png --quick --no-save-images
  10. python cell_sweep.py cell.png --contrast-orders before_upscale,after_upscale
  11. OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
  12. # 统计出的最优参数 tag: threshold_t150_cl_1.0_8_ob_u128_det0.5
  13. # 对目录下所有 *_raw.png 验证适配性
  14. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only
  15. # 自定义最优参数
  16. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only \
  17. --best-config threshold_t150_cl_1.0_8_ob_u128_det0.5
  18. # 指定目标文字,自动统计 HIT 命中率
  19. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only -t "交易类型"
  20. """
  21. from __future__ import annotations
  22. import argparse
  23. import json
  24. import os
  25. import sys
  26. from itertools import product
  27. from pathlib import Path
  28. from typing import Any, Dict, List, Optional, Sequence, Tuple
  29. import cv2
  30. import numpy as np
# Make the project-local packages imported below resolvable when this file is
# executed directly: ``parents[3]`` is the directory three levels above the
# folder containing this file (presumably the repository root - confirm
# against the actual checkout layout).
_repo_root = Path(__file__).resolve().parents[3]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
# These imports must stay below the sys.path tweak above.
from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
from ocr_utils.watermark.contrast import enhance_document_contrast
# Lower-case file extensions (with leading dot) accepted as input images.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
# Fallback model directory, used by _default_model_dir() when the
# OCR_DET_MODEL_PATH environment variable is not set.
# NOTE(review): this is an absolute path on one developer machine.
_DEFAULT_MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)
  41. def _parse_csv_ints(s: str) -> List[Optional[int]]:
  42. out: List[Optional[int]] = []
  43. for part in s.split(","):
  44. part = part.strip()
  45. if not part or part.lower() in ("none", "d", "default"):
  46. out.append(None)
  47. else:
  48. out.append(int(part))
  49. return out
  50. def _parse_csv_floats(s: str) -> List[float]:
  51. return [float(x.strip()) for x in s.split(",") if x.strip()]
  52. def _parse_csv_bools(s: str) -> List[bool]:
  53. out: List[bool] = []
  54. for part in s.split(","):
  55. p = part.strip().lower()
  56. if p in ("1", "true", "yes", "on"):
  57. out.append(True)
  58. elif p in ("0", "false", "no", "off"):
  59. out.append(False)
  60. else:
  61. raise ValueError(f"无效的 bool 值: {part!r}")
  62. return out
  63. def _default_model_dir() -> Path:
  64. det = os.environ.get("OCR_DET_MODEL_PATH")
  65. if det:
  66. return Path(det).parent
  67. return _DEFAULT_MODEL_DIR
  68. def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
  69. h, w = img.shape[:2]
  70. if h >= min_side and w >= min_side:
  71. return img
  72. s = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
  73. return cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
  74. # ── 对比度增强方法(clahe / gamma / linear / text_restore / none)──
  75. def _apply_contrast(
  76. gray: np.ndarray,
  77. *,
  78. method: str,
  79. clip_limit: float = 1.0,
  80. tile_grid_size: int = 8,
  81. gamma: float = 0.85,
  82. black_percentile: float = 2.0,
  83. white_percentile: float = 98.0,
  84. text_black_target: int = 85,
  85. background_threshold: int = 248,
  86. ) -> np.ndarray:
  87. """对灰度图应用对比度增强;method="none" 时原样返回。"""
  88. if method == "none":
  89. return gray
  90. if method == "text_restore":
  91. return enhance_document_contrast(
  92. gray, method="text_restore",
  93. text_black_target=text_black_target,
  94. background_threshold=background_threshold,
  95. )
  96. if method == "clahe":
  97. return enhance_document_contrast(
  98. gray, method="clahe",
  99. clip_limit=clip_limit, tile_grid_size=tile_grid_size,
  100. )
  101. if method == "gamma":
  102. return enhance_document_contrast(gray, method="gamma", gamma=gamma)
  103. if method == "linear":
  104. return enhance_document_contrast(
  105. gray, method="linear",
  106. black_percentile=black_percentile,
  107. white_percentile=white_percentile,
  108. )
  109. return gray
  110. def _contrast_tag(cfg: Dict[str, Any]) -> str:
  111. """生成 contrast 配置的短标签。"""
  112. m = cfg.get("method", "none")
  113. if m == "none":
  114. return "c0"
  115. if m == "text_restore":
  116. return f"tr_{cfg.get('text_black_target', 85)}"
  117. if m == "clahe":
  118. return f"cl_{cfg.get('clip_limit', 1.0)}_{cfg.get('tile_grid_size', 8)}"
  119. if m == "gamma":
  120. return f"gm_{cfg.get('gamma', 0.85)}"
  121. if m == "linear":
  122. return f"ln_{cfg.get('black_percentile', 2.0)}_{cfg.get('white_percentile', 98.0)}"
  123. return m
  124. def _build_contrast_grid(quick: bool = False) -> List[Dict[str, Any]]:
  125. """构建 contrast 参数网格(对齐 contrast_sweep.py 的设计)。
  126. 返回列表,每个元素是一个 Dict,至少包含 "method" 字段。
  127. """
  128. grid: List[Dict[str, Any]] = [{"method": "none"}] # 对照组:不增强
  129. # text_restore
  130. if quick:
  131. tbt = [60, 85]
  132. bts = [240, 248]
  133. else:
  134. tbt = [60, 85, 100, 120]
  135. bts = [240, 248, 252]
  136. for target, bg_th in product(tbt, bts):
  137. grid.append({"method": "text_restore", "text_black_target": target, "background_threshold": bg_th})
  138. # clahe
  139. if quick:
  140. cl = [1.0, 2.0]
  141. ts = [4, 8]
  142. else:
  143. cl = [0.5, 1.0, 2.0, 3.0, 5.0]
  144. ts = [4, 8]
  145. for clip, tile in product(cl, ts):
  146. grid.append({"method": "clahe", "clip_limit": clip, "tile_grid_size": tile})
  147. # # gamma
  148. # if quick:
  149. # gvs = [0.5, 0.85]
  150. # else:
  151. # gvs = [0.4, 0.55, 0.7, 0.85]
  152. # for g in gvs:
  153. # grid.append({"method": "gamma", "gamma": g})
  154. # # linear
  155. # if quick:
  156. # bps = [2.0, 5.0]
  157. # wps = [95.0, 98.0]
  158. # else:
  159. # bps = [2.0, 5.0, 8.0]
  160. # wps = [95.0, 98.0]
  161. # for bp, wp in product(bps, wps):
  162. # grid.append({"method": "linear", "black_percentile": bp, "white_percentile": wp})
  163. return grid
  164. def _preprocess(
  165. raw: np.ndarray,
  166. *,
  167. method: str,
  168. thresh: Optional[int],
  169. contrast_cfg: Dict[str, Any],
  170. upscale: int,
  171. contrast_order: str = "before_upscale",
  172. ) -> np.ndarray:
  173. """预处理管线:去水印 → [contrast] → 放大(或去水印 → 放大 → contrast)。
  174. method="none" 时跳过去水印,直接从原图开始处理。
  175. """
  176. if method == "none":
  177. img = raw.copy() # 不处理水印,直接使用原图
  178. else:
  179. user: Dict[str, Any] = {"enabled": True, "method": method}
  180. if method == "threshold" and thresh is not None:
  181. user["threshold"] = thresh
  182. cfg = merge_watermark_config("cell", user)
  183. img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
  184. contrast_method = contrast_cfg.get("method", "none")
  185. if contrast_method != "none" and contrast_order == "before_upscale":
  186. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  187. gray = _apply_contrast(gray, **contrast_cfg)
  188. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  189. img = _upscale(img, upscale)
  190. if contrast_method != "none" and contrast_order == "after_upscale":
  191. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  192. gray = _apply_contrast(gray, **contrast_cfg)
  193. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  194. return img
  195. def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
  196. """从 OCR 返回的 (text, score) 或嵌套结构中解析识别结果。"""
  197. if rec_part is None:
  198. return "", 0.0
  199. if isinstance(rec_part, (list, tuple)) and len(rec_part) >= 2:
  200. if isinstance(rec_part[0], (list, tuple, dict)):
  201. return "", 0.0
  202. txt = str(rec_part[0] or "").strip()
  203. try:
  204. sc = float(rec_part[1] or 0.0)
  205. except (TypeError, ValueError):
  206. sc = 0.0
  207. return txt, sc if txt else 0.0
  208. if isinstance(rec_part, (list, tuple)) and len(rec_part) == 1:
  209. txt = str(rec_part[0] or "").strip()
  210. return txt, 0.0
  211. return "", 0.0
  212. def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
  213. """按字符数加权平均识别分(与 pipeline aggregate_line_ocr 一致)。"""
  214. total_len = sum(len(b.get("text") or "") for b in boxes)
  215. if total_len <= 0:
  216. return 0.0
  217. weighted = sum(
  218. len(b.get("text") or "") * float(b.get("score") or 0.0) for b in boxes
  219. )
  220. return weighted / total_len
  221. def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
  222. empty: Dict[str, Any] = {
  223. "text": "",
  224. "score": 0.0,
  225. "boxes": [],
  226. "det": det,
  227. "rec": rec,
  228. "n_boxes": 0,
  229. }
  230. try:
  231. res = engine.ocr(img, det=det, rec=rec)
  232. items = res[0] if res and res[0] is not None else []
  233. boxes_out: List[Dict[str, Any]] = []
  234. if det:
  235. for item in items:
  236. if not item or len(item) < 2:
  237. continue
  238. text, score = _parse_rec_pair(item[1])
  239. bbox = item[0]
  240. if hasattr(bbox, "tolist"):
  241. bbox = bbox.tolist()
  242. entry: Dict[str, Any] = {
  243. "text": text,
  244. "score": round(score, 6),
  245. }
  246. if bbox is not None:
  247. entry["det_bbox"] = bbox
  248. boxes_out.append(entry)
  249. else:
  250. for item in items:
  251. text, score = _parse_rec_pair(item)
  252. if not text and isinstance(item, (list, tuple)) and len(item) >= 1:
  253. text, score = _parse_rec_pair(item[0])
  254. boxes_out.append({"text": text, "score": round(score, 6)})
  255. text = "".join(b["text"] for b in boxes_out if b.get("text")).strip()
  256. agg_score = _aggregate_rec_score(boxes_out)
  257. return {
  258. "text": text,
  259. "score": round(agg_score, 6),
  260. "boxes": boxes_out,
  261. "det": det,
  262. "rec": rec,
  263. "n_boxes": len(boxes_out),
  264. }
  265. except Exception as e:
  266. out = dict(empty)
  267. out["error"] = str(e)
  268. return out
  269. def _make_engine(det_thresh: float, model_dir: Path) -> Any:
  270. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  271. det_path = os.environ.get("OCR_DET_MODEL_PATH") or str(
  272. model_dir / "ch_PP-OCRv5_det_infer.pth"
  273. )
  274. rec_path = os.environ.get("OCR_REC_MODEL_PATH") or str(
  275. model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth"
  276. )
  277. return PytorchPaddleOCR(
  278. lang="ch",
  279. det_model_path=det_path,
  280. rec_model_path=rec_path,
  281. det_db_box_thresh=det_thresh,
  282. )
  283. def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
  284. """优先使用与 pipeline debug 配套的 *_raw.png。"""
  285. if not prefer_raw or path.stem.endswith("_raw"):
  286. return path
  287. raw_path = path.parent / f"{path.stem}_raw{path.suffix}"
  288. if raw_path.is_file():
  289. print(f" 使用原图: {raw_path.name}(跳过 {path.name})")
  290. return raw_path
  291. return path
  292. def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
  293. if path.is_file():
  294. if path.suffix.lower() not in _IMAGE_SUFFIXES:
  295. raise ValueError(f"不支持的图像格式: {path}")
  296. return [resolve_input_image(path, prefer_raw=prefer_raw)]
  297. if not path.is_dir():
  298. raise FileNotFoundError(path)
  299. all_images = sorted(
  300. p
  301. for p in path.iterdir()
  302. if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
  303. )
  304. if not all_images:
  305. raise FileNotFoundError(f"目录内无图像: {path}")
  306. if prefer_raw:
  307. raws = [p for p in all_images if p.stem.endswith("_raw")]
  308. if raws:
  309. return raws
  310. chosen: List[Path] = []
  311. for p in all_images:
  312. if p.stem.endswith("_raw"):
  313. continue
  314. raw_sibling = p.parent / f"{p.stem}_raw{p.suffix}"
  315. if prefer_raw and raw_sibling.is_file():
  316. continue
  317. chosen.append(p)
  318. return chosen or all_images
  319. def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
  320. if not text:
  321. return None
  322. if not target:
  323. return "nonempty"
  324. if target in text:
  325. return "full"
  326. if len(target) >= 6 and target.isdigit() and len(text) >= 6 and text.isdigit():
  327. return "partial"
  328. return None
  329. def run_sweep(
  330. input_path: Path,
  331. out_dir: Path,
  332. *,
  333. prefer_raw: bool,
  334. target: Optional[str],
  335. model_dir: Path,
  336. methods: Sequence[str],
  337. thresholds: Sequence[Optional[int]],
  338. contrast_grid: List[Dict[str, Any]],
  339. contrast_orders: Sequence[str],
  340. upscales: Sequence[int],
  341. det_threshs: Sequence[float],
  342. save_images: bool,
  343. run_baseline: bool,
  344. baseline_upscale: int,
  345. ) -> Dict[str, Any]:
  346. resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
  347. raw = cv2.imread(str(resolved))
  348. if raw is None:
  349. raise RuntimeError(f"无法读取图像: {resolved}")
  350. stem = resolved.stem.removesuffix("_raw") if resolved.stem.endswith("_raw") else resolved.stem
  351. cell_out = out_dir / stem
  352. cell_out.mkdir(parents=True, exist_ok=True)
  353. ocr_modes: List[Tuple[str, bool, bool]] = [
  354. ("det_rec", True, True),
  355. ("whole_rec", False, True),
  356. ]
  357. results: List[Dict[str, Any]] = []
  358. hits: List[Dict[str, Any]] = []
  359. engines: Dict[float, Any] = {}
  360. total = 0
  361. for method, thresh, contrast_cfg, c_order, upscale, det_th in product(
  362. methods, thresholds, contrast_grid, contrast_orders, upscales, det_threshs
  363. ):
  364. # 过滤无效组合:非 threshold 方法不需要阈值
  365. if method not in ("threshold",):
  366. if thresh is not None:
  367. continue
  368. if det_th not in engines:
  369. print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
  370. engines[det_th] = _make_engine(det_th, model_dir)
  371. img = _preprocess(
  372. raw,
  373. method=method,
  374. thresh=thresh,
  375. contrast_cfg=contrast_cfg,
  376. upscale=upscale,
  377. contrast_order=c_order,
  378. )
  379. c_tag = _contrast_tag(contrast_cfg)
  380. o_tag = "b" if c_order == "before_upscale" else "a"
  381. tag = f"{method}_t{thresh or 'd'}_{c_tag}_o{o_tag}_u{upscale}_det{det_th}"
  382. if save_images:
  383. cv2.imwrite(str(cell_out / f"{tag}.png"), img)
  384. for mode_name, det, rec in ocr_modes:
  385. total += 1
  386. ocr = _ocr(engines[det_th], img, det=det, rec=rec)
  387. row: Dict[str, Any] = {
  388. "tag": tag,
  389. "method": method,
  390. "threshold": thresh,
  391. "contrast_method": contrast_cfg.get("method", "none"),
  392. "contrast_order": c_order,
  393. "contrast_cfg": contrast_cfg,
  394. "upscale": upscale,
  395. "det_db_box_thresh": det_th,
  396. "ocr_mode": mode_name,
  397. **ocr,
  398. }
  399. results.append(row)
  400. m = _match_hit(row.get("text", ""), target)
  401. if m:
  402. row["match"] = m
  403. hits.append(row)
  404. print(
  405. f" HIT [{m}] {mode_name} {tag} "
  406. f"score={row.get('score')} -> {row.get('text')!r}"
  407. )
  408. if run_baseline:
  409. for det_th in det_threshs:
  410. if det_th not in engines:
  411. engines[det_th] = _make_engine(det_th, model_dir)
  412. base_img = _upscale(raw, baseline_upscale)
  413. if save_images:
  414. cv2.imwrite(str(cell_out / f"baseline_upscale{baseline_upscale}.png"), base_img)
  415. for mode_name, det, rec in ocr_modes:
  416. ocr = _ocr(engines[det_th], base_img, det=det, rec=rec)
  417. row = {
  418. "tag": f"baseline_upscale{baseline_upscale}",
  419. "det_db_box_thresh": det_th,
  420. "ocr_mode": mode_name,
  421. **ocr,
  422. }
  423. results.append(row)
  424. m = _match_hit(row.get("text", ""), target)
  425. if m:
  426. row["match"] = m
  427. hits.append(row)
  428. report = {
  429. "input": str(resolved),
  430. "input_requested": str(input_path),
  431. "output_dir": str(cell_out),
  432. "target": target,
  433. "total_trials": total,
  434. "hits": hits,
  435. "all_results": results,
  436. }
  437. report_path = cell_out / "sweep_report.json"
  438. report_path.write_text(
  439. json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
  440. )
  441. # ── 结论报告:按 OCR score 排序,分组对比 ──
  442. _print_conclusions(stem, results, target)
  443. return report
  444. def _print_conclusions(
  445. stem: str,
  446. results: List[Dict[str, Any]],
  447. target: Optional[str],
  448. ) -> None:
  449. """打印实验结论:按 OCR score 排序,分组展示最优组合。"""
  450. if not results:
  451. return
  452. print(f"\n{'='*70}")
  453. print(f" 实验结论: {stem}")
  454. if target:
  455. print(f" 目标文字: {target}")
  456. print(f"{'='*70}")
  457. # 取 det_rec 模式的结果(优先用检测+识别完整结果)
  458. dr_results = [r for r in results if r.get("ocr_mode") == "det_rec" and r.get("text")]
  459. if not dr_results:
  460. dr_results = [r for r in results if r.get("text")]
  461. if not dr_results:
  462. print(" (无有效 OCR 结果)")
  463. return
  464. # ── 1. 全局 Top-5 ──
  465. scored = sorted(dr_results, key=lambda r: -(r.get("score") or 0))
  466. print("\n 全局 OCR 得分 Top-5:")
  467. for i, r in enumerate(scored[:5], 1):
  468. print(f" {i}. score={r.get('score', 0):.4f} text={r.get('text', '')!r}")
  469. print(f" tag={r.get('tag', '')}")
  470. # ── 2. 按 contrast 方法分组最佳 ──
  471. print("\n 按 contrast 方法分组最优(score 最高):")
  472. groups: Dict[str, List[Dict[str, Any]]] = {}
  473. for r in scored:
  474. cm = r.get("contrast_method", "?")
  475. groups.setdefault(cm, []).append(r)
  476. for cm in sorted(groups.keys()):
  477. best = groups[cm][0]
  478. wm = best.get("method", "?")
  479. print(f" [{cm}] 最佳: score={best.get('score', 0):.4f} "
  480. f"wm={wm} upscale={best.get('upscale')} "
  481. f"text={best.get('text', '')!r}")
  482. # ── 3. 有 watermark 处理 vs 无 watermark 处理对比 ──
  483. print("\n 去水印开关对比(同 contrast 方法,最高 score):")
  484. wm_groups: Dict[str, Dict[str, Any]] = {}
  485. for r in scored:
  486. cm = r.get("contrast_method", "?")
  487. wm = r.get("method", "?") if r.get("method") != "none" else "无去水印"
  488. key = f"{cm}|{wm}"
  489. cur_score = r.get("score") or 0
  490. prev_score = (wm_groups.get(key) or {}).get("score") or 0
  491. if key not in wm_groups or cur_score > prev_score:
  492. wm_groups[key] = r
  493. for cm in sorted(set(r.get("contrast_method", "?") for r in scored)):
  494. wm_rows = [r for k, r in wm_groups.items() if k.startswith(cm + "|")]
  495. if wm_rows:
  496. best_row = max(wm_rows, key=lambda r: r.get("score") or 0)
  497. wm_label = "无去水印" if best_row.get("method") == "none" else best_row.get("method", "?")
  498. print(f" [{cm}] 最优: wm={wm_label} score={best_row.get('score', 0):.4f} "
  499. f"text={best_row.get('text', '')!r}")
  500. # ── 4. 放大顺序对比 ──
  501. print("\n 放大前/后对比(同方法,最高 score):")
  502. order_data: Dict[str, Dict[str, Any]] = {}
  503. for r in scored:
  504. cm = r.get("contrast_method", "?")
  505. co = r.get("contrast_order", "?")
  506. key = f"{cm}|{co}"
  507. cur_score = r.get("score") or 0
  508. prev_score = (order_data.get(key) or {}).get("score") or 0
  509. if key not in order_data or cur_score > prev_score:
  510. order_data[key] = r
  511. for cm in sorted(set(r.get("contrast_method", "?") for r in scored)):
  512. b_score = (order_data.get(f"{cm}|before_upscale") or {}).get("score") or 0
  513. a_score = (order_data.get(f"{cm}|after_upscale") or {}).get("score") or 0
  514. better = "放大前" if b_score > a_score else ("放大后" if a_score > b_score else "持平")
  515. if b_score or a_score:
  516. print(f" [{cm}] 放大前={b_score:.4f} 放大后={a_score:.4f} 更优: {better}")
  517. # ── 5. HIT 命中率统计 ──
  518. if target:
  519. hit_count = sum(1 for r in results if r.get("match"))
  520. hit_by_cm: Dict[str, int] = {}
  521. for r in results:
  522. if r.get("match"):
  523. cm = r.get("contrast_method", "?")
  524. hit_by_cm[cm] = hit_by_cm.get(cm, 0) + 1
  525. print(f"\n HIT 命中率 (target={target}): {hit_count}/{len(results)}")
  526. for cm in sorted(hit_by_cm.keys()):
  527. print(f" [{cm}] HIT={hit_by_cm[cm]}")
  528. print(f"{'='*70}\n")
  529. def _parse_best_config(tag: str) -> Dict[str, Any]:
  530. """解析最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5。
  531. tag 格式: {method}_t{thresh}_{c_tag}_o{b|a}_u{upscale}_det{det_th}
  532. """
  533. import re
  534. cfg: Dict[str, Any] = {}
  535. tag = tag.strip()
  536. # 解析 method: threshold | masked_adaptive | none
  537. m = re.match(r"(threshold|masked_adaptive|none)_t(\w+?)_(.+?)_o([ba])_u(\d+)_det([\d.]+)$", tag)
  538. if not m:
  539. raise ValueError(f"无法解析 best-config tag: {tag!r}")
  540. method, thresh_str, c_part, order_char, upscale, det_th = m.groups()
  541. cfg["method"] = method
  542. cfg["threshold"] = int(thresh_str) if thresh_str.isdigit() else None
  543. cfg["contrast_order"] = "before_upscale" if order_char == "b" else "after_upscale"
  544. cfg["upscale"] = int(upscale)
  545. cfg["det_db_box_thresh"] = float(det_th)
  546. # 解析 contrast 部分: cl_1.0_8 | tr_85 | gm_0.85 | ln_2.0_98.0 | c0
  547. if c_part == "c0":
  548. cfg["contrast_cfg"] = {"method": "none"}
  549. elif c_part.startswith("cl_"):
  550. parts = c_part.split("_")
  551. cfg["contrast_cfg"] = {"method": "clahe", "clip_limit": float(parts[1]), "tile_grid_size": int(parts[2])}
  552. elif c_part.startswith("tr_"):
  553. parts = c_part.split("_")
  554. cfg["contrast_cfg"] = {"method": "text_restore", "text_black_target": int(parts[1])}
  555. elif c_part.startswith("gm_"):
  556. parts = c_part.split("_")
  557. cfg["contrast_cfg"] = {"method": "gamma", "gamma": float(parts[1])}
  558. elif c_part.startswith("ln_"):
  559. parts = c_part.split("_")
  560. cfg["contrast_cfg"] = {"method": "linear", "black_percentile": float(parts[1]), "white_percentile": float(parts[2])}
  561. else:
  562. raise ValueError(f"无法解析 contrast tag: {c_part!r} (in {tag})")
  563. return cfg
  564. def run_best_config(
  565. input_path: Path,
  566. out_dir: Path,
  567. *,
  568. prefer_raw: bool,
  569. best_cfg: Dict[str, Any],
  570. model_dir: Path,
  571. save_images: bool,
  572. ) -> Dict[str, Any]:
  573. """对单图用指定最优参数跑一次 OCR。"""
  574. resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
  575. raw = cv2.imread(str(resolved))
  576. if raw is None:
  577. raise RuntimeError(f"无法读取图像: {resolved}")
  578. stem = resolved.stem.removesuffix("_raw") if resolved.stem.endswith("_raw") else resolved.stem
  579. cell_out = out_dir / stem
  580. cell_out.mkdir(parents=True, exist_ok=True)
  581. engine = _make_engine(best_cfg["det_db_box_thresh"], model_dir)
  582. img = _preprocess(
  583. raw,
  584. method=best_cfg["method"],
  585. thresh=best_cfg.get("threshold"),
  586. contrast_cfg=best_cfg["contrast_cfg"],
  587. upscale=best_cfg["upscale"],
  588. contrast_order=best_cfg["contrast_order"],
  589. )
  590. tag = best_cfg.get("_tag", "best")
  591. if save_images:
  592. cv2.imwrite(str(cell_out / f"{tag}.png"), img)
  593. ocr = _ocr(engine, img, det=True, rec=True)
  594. row: Dict[str, Any] = {
  595. "tag": tag,
  596. "method": best_cfg["method"],
  597. "threshold": best_cfg.get("threshold"),
  598. "contrast_method": best_cfg["contrast_cfg"].get("method", "none"),
  599. "contrast_order": best_cfg["contrast_order"],
  600. "contrast_cfg": best_cfg["contrast_cfg"],
  601. "upscale": best_cfg["upscale"],
  602. "det_db_box_thresh": best_cfg["det_db_box_thresh"],
  603. "ocr_mode": "det_rec",
  604. **ocr,
  605. }
  606. report = {
  607. "input": str(resolved),
  608. "input_requested": str(input_path),
  609. "output_dir": str(cell_out),
  610. "result": row,
  611. }
  612. report_path = cell_out / "best_result.json"
  613. report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
  614. return report
  615. def _build_arg_parser() -> argparse.ArgumentParser:
  616. p = argparse.ArgumentParser(
  617. description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
  618. )
  619. p.add_argument(
  620. "input",
  621. type=Path,
  622. help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
  623. )
  624. p.add_argument(
  625. "-o",
  626. "--output",
  627. type=Path,
  628. default=None,
  629. help="输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
  630. )
  631. p.add_argument(
  632. "-t",
  633. "--target",
  634. default=None,
  635. help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
  636. )
  637. p.add_argument(
  638. "--model-dir",
  639. type=Path,
  640. default=None,
  641. help="PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
  642. )
  643. p.add_argument(
  644. "--no-prefer-raw",
  645. action="store_true",
  646. help="不自动选用同名的 *_raw.png",
  647. )
  648. p.add_argument(
  649. "--quick",
  650. action="store_true",
  651. help="缩小网格(threshold 155,165 × upscale 128,192 × det 0.5 × contrast 精简)",
  652. )
  653. p.add_argument(
  654. "--methods",
  655. default="threshold,masked_adaptive,none",
  656. help="去水印方式,逗号分隔;none=不去水印",
  657. )
  658. p.add_argument(
  659. "--thresholds",
  660. default="155,165,none",
  661. help="threshold 法的阈值;none=预设默认",
  662. )
  663. p.add_argument(
  664. "--contrast-orders",
  665. default="before_upscale,after_upscale",
  666. help="contrast 执行顺序: before_upscale(放大前), after_upscale(放大后), 逗号组合",
  667. )
  668. p.add_argument(
  669. "--upscales",
  670. default="128,192",
  671. help="最短边放大目标,逗号分隔整数",
  672. )
  673. p.add_argument(
  674. "--det-threshs",
  675. # default="0.2,0.3,0.4,0.5",
  676. default="0.5",
  677. help="det_db_box_thresh,逗号分隔",
  678. )
  679. p.add_argument(
  680. "--no-save-images",
  681. action="store_true",
  682. help="不写出中间预处理 png(仅报告)",
  683. )
  684. p.add_argument(
  685. "--no-baseline",
  686. action="store_true",
  687. help="跳过「仅放大、不去水印」对照组",
  688. )
  689. p.add_argument(
  690. "--baseline-upscale",
  691. type=int,
  692. default=192,
  693. help="baseline 对照组的最短边放大",
  694. )
  695. p.add_argument(
  696. "--best-only",
  697. action="store_true",
  698. help="不跑参数网格,对目录下所有图用 --best-config 指定参数跑一次,验证适配性",
  699. )
  700. p.add_argument(
  701. "--best-config",
  702. default="threshold_t150_cl_1.0_8_ob_u128_det0.5",
  703. help="最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5",
  704. )
  705. return p
  706. def main(argv: Optional[Sequence[str]] = None) -> None:
  707. args = _build_arg_parser().parse_args(argv)
  708. inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
  709. if not inputs:
  710. raise SystemExit("未找到可扫描的图像")
  711. if args.output is not None:
  712. out_root = args.output
  713. elif args.input.is_file():
  714. out_root = args.input.parent / "sweep_out"
  715. else:
  716. out_root = args.input / "sweep_out"
  717. out_root.mkdir(parents=True, exist_ok=True)
  718. model_dir = args.model_dir or _default_model_dir()
  719. if args.best_only:
  720. # 验证适配性模式:对目录下所有图用最优参数跑一次
  721. best_cfg = _parse_best_config(args.best_config)
  722. best_cfg["_tag"] = args.best_config
  723. print(f"最佳参数验证模式: {args.best_config}")
  724. print(f" 解析: method={best_cfg['method']} contrast={best_cfg['contrast_cfg'].get('method')} "
  725. f"upscale={best_cfg['upscale']} order={best_cfg['contrast_order']}")
  726. print(f" 共 {len(inputs)} 张图")
  727. all_texts: List[Dict[str, Any]] = []
  728. hit_count = 0
  729. for img_path in inputs:
  730. report = run_best_config(
  731. img_path, out_root,
  732. prefer_raw=not args.no_prefer_raw,
  733. best_cfg=best_cfg,
  734. model_dir=model_dir,
  735. save_images=not args.no_save_images,
  736. )
  737. result = report["result"]
  738. text = result.get("text", "")
  739. score = result.get("score", 0)
  740. all_texts.append({
  741. "input": img_path.name,
  742. "text": text,
  743. "score": score,
  744. "report": str(Path(report["output_dir"]) / "best_result.json"),
  745. })
  746. m = _match_hit(text, args.target)
  747. hit_info = f" [HIT: {m}]" if m else ""
  748. print(f" {img_path.name}: score={score:.4f} text={text!r}{hit_info}")
  749. if m:
  750. hit_count += 1
  751. # 汇总
  752. summary_path = out_root / "best_summary.json"
  753. summary_data = {
  754. "best_config": args.best_config,
  755. "total": len(all_texts),
  756. "hits": hit_count,
  757. "target": args.target,
  758. "results": all_texts,
  759. }
  760. summary_path.write_text(json.dumps(summary_data, ensure_ascii=False, indent=2), encoding="utf-8")
  761. print(f"\n汇总: {hit_count}/{len(all_texts)} HIT -> {summary_path}")
  762. return
  763. # 正常参数网格扫描模式
  764. methods = [m.strip() for m in args.methods.split(",") if m.strip()]
  765. contrast_orders = [o.strip() for o in args.contrast_orders.split(",") if o.strip()]
  766. if args.quick:
  767. thresholds = [150, 155]
  768. upscales = [128, 192]
  769. det_threshs = [0.5]
  770. else:
  771. thresholds = _parse_csv_ints(args.thresholds)
  772. upscales = [int(x) for x in args.upscales.split(",") if x.strip()]
  773. det_threshs = _parse_csv_floats(args.det_threshs)
  774. contrast_grid = _build_contrast_grid(quick=args.quick)
  775. print(f"扫描 {len(inputs)} 张图 -> {out_root}")
  776. print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
  777. print(f" contrast_methods={len(contrast_grid)} orders={contrast_orders}")
  778. if args.target:
  779. print(f" target={args.target!r}")
  780. summary: List[Dict[str, Any]] = []
  781. for img_path in inputs:
  782. print(f"\n=== {img_path.name} ===")
  783. report = run_sweep(
  784. img_path,
  785. out_root,
  786. prefer_raw=not args.no_prefer_raw,
  787. target=args.target,
  788. model_dir=model_dir,
  789. methods=methods,
  790. thresholds=thresholds,
  791. contrast_grid=contrast_grid,
  792. contrast_orders=contrast_orders,
  793. upscales=upscales,
  794. det_threshs=det_threshs,
  795. save_images=not args.no_save_images,
  796. run_baseline=not args.no_baseline,
  797. baseline_upscale=args.baseline_upscale,
  798. )
  799. summary.append(
  800. {
  801. "input": report["input"],
  802. "hits": len(report["hits"]),
  803. "report": str(Path(report["output_dir"]) / "sweep_report.json"),
  804. }
  805. )
  806. index_path = out_root / "sweep_index.json"
  807. index_path.write_text(
  808. json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
  809. )
  810. print(f"\n全部完成,索引: {index_path}")
  811. for s in summary:
  812. print(f" {s['input']}: {s['hits']} hits -> {s['report']}")
  813. if __name__ == "__main__":
  814. if len(sys.argv) == 1:
  815. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  816. default_config = {
  817. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty_raw.png",
  818. # "output": "./output/彭_广东兴宁农村商业银行/cell219_sweep",
  819. # "target": "ATM存折取款",
  820. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell007_whole_longer_易型交类_raw.png",
  821. # "output": "./output/彭_广东兴宁农村商业银行/cell007_sweep",
  822. # "target": "交易类型",
  823. # "quick": True,
  824. # "input": "/Users/zhch158/workspace/data/流水分析/钟_广东陆丰农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/钟_广东陆丰农村商业银行_page_001_0/cell217_empty_empty_raw.png",
  825. # "output": "./output/钟_广东陆丰农村商业银行/cell217_sweep",
  826. # "target": "专项资金",
  827. # "quick": True,
  828. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0",
  829. # "output": "./output/彭_广东兴宁农村商业银行",
  830. # "best-config": "threshold_t150_cl_1.0_8_ob_u128_det0.5",
  831. # "best-only": True,
  832. "input": "/Users/zhch158/workspace/data/流水分析/钟_广东陆丰农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/钟_广东陆丰农村商业银行_page_001_0",
  833. "output": "./output/钟_广东陆丰农村商业银行",
  834. # "best-config": "threshold_t150_cl_1.0_8_ob_u128_det0.5",
  835. "best-config": "threshold_t150_cl_1.0_4_ob_u128_det0.5",
  836. "best-only": True,
  837. }
  838. sys.argv = [sys.argv[0], default_config["input"]]
  839. for key, value in default_config.items():
  840. if key == "input":
  841. continue
  842. flag = f"--{key.replace('_', '-')}"
  843. if isinstance(value, bool) and value:
  844. sys.argv.append(flag)
  845. elif not isinstance(value, bool):
  846. sys.argv.extend([flag, str(value)])
  847. sys.exit(main())