cell_sweep.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. #!/usr/bin/env python3
  2. """
  3. 单元格裁剪图预处理参数扫描:去水印 / threshold / contrast / upscale / det 阈值 / OCR 模式。
  4. 默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
  5. 用法:
  6. python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
  7. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
  8. python cell_sweep.py cell.png --quick --no-save-images
  9. OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import json
  14. import os
  15. import sys
  16. from itertools import product
  17. from pathlib import Path
  18. from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
  19. import cv2
  20. import numpy as np
# Put the directory two levels above this file's folder on sys.path (once),
# ahead of the project-local imports that follow.
# NOTE(review): presumably this is the repository root that contains the
# `ocr_utils` / `ocr_tools` packages — confirm against the repo layout.
_repo_root = Path(__file__).resolve().parents[2]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
  24. from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
  25. from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
# Lower-case file extensions (with leading dot) accepted as input images.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
# Fallback model directory, used when OCR_DET_MODEL_PATH is not set.
# NOTE(review): this is a machine-specific absolute path; on any other host
# pass --model-dir or set OCR_DET_MODEL_PATH / OCR_REC_MODEL_PATH instead.
_DEFAULT_MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)
  31. def _parse_csv_ints(s: str) -> List[Optional[int]]:
  32. out: List[Optional[int]] = []
  33. for part in s.split(","):
  34. part = part.strip()
  35. if not part or part.lower() in ("none", "d", "default"):
  36. out.append(None)
  37. else:
  38. out.append(int(part))
  39. return out
  40. def _parse_csv_floats(s: str) -> List[float]:
  41. return [float(x.strip()) for x in s.split(",") if x.strip()]
  42. def _parse_csv_bools(s: str) -> List[bool]:
  43. out: List[bool] = []
  44. for part in s.split(","):
  45. p = part.strip().lower()
  46. if p in ("1", "true", "yes", "on"):
  47. out.append(True)
  48. elif p in ("0", "false", "no", "off"):
  49. out.append(False)
  50. else:
  51. raise ValueError(f"无效的 bool 值: {part!r}")
  52. return out
  53. def _default_model_dir() -> Path:
  54. det = os.environ.get("OCR_DET_MODEL_PATH")
  55. if det:
  56. return Path(det).parent
  57. return _DEFAULT_MODEL_DIR
  58. def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
  59. h, w = img.shape[:2]
  60. if h >= min_side and w >= min_side:
  61. return img
  62. s = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
  63. return cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
  64. def _preprocess(
  65. raw: np.ndarray,
  66. *,
  67. method: str,
  68. thresh: Optional[int],
  69. contrast: bool,
  70. upscale: int,
  71. text_black_target: int,
  72. ) -> np.ndarray:
  73. user: Dict[str, Any] = {"enabled": True, "method": method}
  74. if method == "threshold" and thresh is not None:
  75. user["threshold"] = thresh
  76. cfg = merge_watermark_config("cell", user)
  77. img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
  78. if contrast:
  79. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  80. ce = dict(cfg.get("contrast_enhancement") or {})
  81. ce["enabled"] = True
  82. ce["text_black_target"] = text_black_target
  83. gray = apply_contrast_enhancement_config(gray, ce)
  84. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  85. return _upscale(img, upscale)
  86. def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
  87. """从 OCR 返回的 (text, score) 或嵌套结构中解析识别结果。"""
  88. if rec_part is None:
  89. return "", 0.0
  90. if isinstance(rec_part, (list, tuple)) and len(rec_part) >= 2:
  91. if isinstance(rec_part[0], (list, tuple, dict)):
  92. return "", 0.0
  93. txt = str(rec_part[0] or "").strip()
  94. try:
  95. sc = float(rec_part[1] or 0.0)
  96. except (TypeError, ValueError):
  97. sc = 0.0
  98. return txt, sc if txt else 0.0
  99. if isinstance(rec_part, (list, tuple)) and len(rec_part) == 1:
  100. txt = str(rec_part[0] or "").strip()
  101. return txt, 0.0
  102. return "", 0.0
  103. def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
  104. """按字符数加权平均识别分(与 pipeline aggregate_line_ocr 一致)。"""
  105. total_len = sum(len(b.get("text") or "") for b in boxes)
  106. if total_len <= 0:
  107. return 0.0
  108. weighted = sum(
  109. len(b.get("text") or "") * float(b.get("score") or 0.0) for b in boxes
  110. )
  111. return weighted / total_len
  112. def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
  113. empty: Dict[str, Any] = {
  114. "text": "",
  115. "score": 0.0,
  116. "boxes": [],
  117. "det": det,
  118. "rec": rec,
  119. "n_boxes": 0,
  120. }
  121. try:
  122. res = engine.ocr(img, det=det, rec=rec)
  123. items = res[0] if res and res[0] is not None else []
  124. boxes_out: List[Dict[str, Any]] = []
  125. if det:
  126. for item in items:
  127. if not item or len(item) < 2:
  128. continue
  129. text, score = _parse_rec_pair(item[1])
  130. bbox = item[0]
  131. if hasattr(bbox, "tolist"):
  132. bbox = bbox.tolist()
  133. entry: Dict[str, Any] = {
  134. "text": text,
  135. "score": round(score, 6),
  136. }
  137. if bbox is not None:
  138. entry["det_bbox"] = bbox
  139. boxes_out.append(entry)
  140. else:
  141. for item in items:
  142. text, score = _parse_rec_pair(item)
  143. if not text and isinstance(item, (list, tuple)) and len(item) >= 1:
  144. text, score = _parse_rec_pair(item[0])
  145. boxes_out.append({"text": text, "score": round(score, 6)})
  146. text = "".join(b["text"] for b in boxes_out if b.get("text")).strip()
  147. agg_score = _aggregate_rec_score(boxes_out)
  148. return {
  149. "text": text,
  150. "score": round(agg_score, 6),
  151. "boxes": boxes_out,
  152. "det": det,
  153. "rec": rec,
  154. "n_boxes": len(boxes_out),
  155. }
  156. except Exception as e:
  157. out = dict(empty)
  158. out["error"] = str(e)
  159. return out
  160. def _make_engine(det_thresh: float, model_dir: Path) -> Any:
  161. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  162. det_path = os.environ.get("OCR_DET_MODEL_PATH") or str(
  163. model_dir / "ch_PP-OCRv5_det_infer.pth"
  164. )
  165. rec_path = os.environ.get("OCR_REC_MODEL_PATH") or str(
  166. model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth"
  167. )
  168. return PytorchPaddleOCR(
  169. lang="ch",
  170. det_model_path=det_path,
  171. rec_model_path=rec_path,
  172. det_db_box_thresh=det_thresh,
  173. )
  174. def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
  175. """优先使用与 pipeline debug 配套的 *_raw.png。"""
  176. if not prefer_raw or path.stem.endswith("_raw"):
  177. return path
  178. raw_path = path.parent / f"{path.stem}_raw{path.suffix}"
  179. if raw_path.is_file():
  180. print(f" 使用原图: {raw_path.name}(跳过 {path.name})")
  181. return raw_path
  182. return path
  183. def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
  184. if path.is_file():
  185. if path.suffix.lower() not in _IMAGE_SUFFIXES:
  186. raise ValueError(f"不支持的图像格式: {path}")
  187. return [resolve_input_image(path, prefer_raw=prefer_raw)]
  188. if not path.is_dir():
  189. raise FileNotFoundError(path)
  190. all_images = sorted(
  191. p
  192. for p in path.iterdir()
  193. if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
  194. )
  195. if not all_images:
  196. raise FileNotFoundError(f"目录内无图像: {path}")
  197. if prefer_raw:
  198. raws = [p for p in all_images if p.stem.endswith("_raw")]
  199. if raws:
  200. return raws
  201. chosen: List[Path] = []
  202. for p in all_images:
  203. if p.stem.endswith("_raw"):
  204. continue
  205. raw_sibling = p.parent / f"{p.stem}_raw{p.suffix}"
  206. if prefer_raw and raw_sibling.is_file():
  207. continue
  208. chosen.append(p)
  209. return chosen or all_images
  210. def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
  211. if not text:
  212. return None
  213. if not target:
  214. return "nonempty"
  215. if target in text:
  216. return "full"
  217. if len(target) >= 6 and target.isdigit() and len(text) >= 6 and text.isdigit():
  218. return "partial"
  219. return None
  220. def run_sweep(
  221. input_path: Path,
  222. out_dir: Path,
  223. *,
  224. prefer_raw: bool,
  225. target: Optional[str],
  226. model_dir: Path,
  227. methods: Sequence[str],
  228. thresholds: Sequence[Optional[int]],
  229. contrasts: Sequence[bool],
  230. upscales: Sequence[int],
  231. det_threshs: Sequence[float],
  232. text_black_target: int,
  233. save_images: bool,
  234. run_baseline: bool,
  235. baseline_upscale: int,
  236. ) -> Dict[str, Any]:
  237. resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
  238. raw = cv2.imread(str(resolved))
  239. if raw is None:
  240. raise RuntimeError(f"无法读取图像: {resolved}")
  241. stem = resolved.stem.removesuffix("_raw") if resolved.stem.endswith("_raw") else resolved.stem
  242. cell_out = out_dir / stem
  243. cell_out.mkdir(parents=True, exist_ok=True)
  244. ocr_modes: List[Tuple[str, bool, bool]] = [
  245. ("det_rec", True, True),
  246. ("whole_rec", False, True),
  247. ]
  248. results: List[Dict[str, Any]] = []
  249. hits: List[Dict[str, Any]] = []
  250. engines: Dict[float, Any] = {}
  251. total = 0
  252. for method, thresh, contrast, upscale, det_th in product(
  253. methods, thresholds, contrasts, upscales, det_threshs
  254. ):
  255. if method != "threshold" and thresh is not None:
  256. continue
  257. if det_th not in engines:
  258. print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
  259. engines[det_th] = _make_engine(det_th, model_dir)
  260. img = _preprocess(
  261. raw,
  262. method=method,
  263. thresh=thresh,
  264. contrast=contrast,
  265. upscale=upscale,
  266. text_black_target=text_black_target,
  267. )
  268. tag = f"{method}_t{thresh or 'd'}_c{int(contrast)}_u{upscale}_det{det_th}"
  269. if save_images:
  270. cv2.imwrite(str(cell_out / f"{tag}.png"), img)
  271. for mode_name, det, rec in ocr_modes:
  272. total += 1
  273. ocr = _ocr(engines[det_th], img, det=det, rec=rec)
  274. row: Dict[str, Any] = {
  275. "tag": tag,
  276. "method": method,
  277. "threshold": thresh,
  278. "contrast": contrast,
  279. "upscale": upscale,
  280. "det_db_box_thresh": det_th,
  281. "ocr_mode": mode_name,
  282. **ocr,
  283. }
  284. results.append(row)
  285. m = _match_hit(row.get("text", ""), target)
  286. if m:
  287. row["match"] = m
  288. hits.append(row)
  289. print(
  290. f" HIT [{m}] {mode_name} {tag} "
  291. f"score={row.get('score')} -> {row.get('text')!r}"
  292. )
  293. if run_baseline:
  294. for det_th in det_threshs:
  295. if det_th not in engines:
  296. engines[det_th] = _make_engine(det_th, model_dir)
  297. base_img = _upscale(raw, baseline_upscale)
  298. if save_images:
  299. cv2.imwrite(str(cell_out / f"baseline_upscale{baseline_upscale}.png"), base_img)
  300. for mode_name, det, rec in ocr_modes:
  301. ocr = _ocr(engines[det_th], base_img, det=det, rec=rec)
  302. row = {
  303. "tag": f"baseline_upscale{baseline_upscale}",
  304. "det_db_box_thresh": det_th,
  305. "ocr_mode": mode_name,
  306. **ocr,
  307. }
  308. results.append(row)
  309. m = _match_hit(row.get("text", ""), target)
  310. if m:
  311. row["match"] = m
  312. hits.append(row)
  313. report = {
  314. "input": str(resolved),
  315. "input_requested": str(input_path),
  316. "output_dir": str(cell_out),
  317. "target": target,
  318. "total_trials": total,
  319. "hits": hits,
  320. "all_results": results,
  321. }
  322. report_path = cell_out / "sweep_report.json"
  323. report_path.write_text(
  324. json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
  325. )
  326. return report
  327. def _build_arg_parser() -> argparse.ArgumentParser:
  328. p = argparse.ArgumentParser(
  329. description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
  330. )
  331. p.add_argument(
  332. "input",
  333. type=Path,
  334. help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
  335. )
  336. p.add_argument(
  337. "-o",
  338. "--output",
  339. type=Path,
  340. default=None,
  341. help="输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
  342. )
  343. p.add_argument(
  344. "-t",
  345. "--target",
  346. default=None,
  347. help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
  348. )
  349. p.add_argument(
  350. "--model-dir",
  351. type=Path,
  352. default=None,
  353. help="PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
  354. )
  355. p.add_argument(
  356. "--no-prefer-raw",
  357. action="store_true",
  358. help="不自动选用同名的 *_raw.png",
  359. )
  360. p.add_argument(
  361. "--quick",
  362. action="store_true",
  363. help="缩小网格(threshold 170,175 × upscale 128,192 × det 0.3,0.5)",
  364. )
  365. p.add_argument(
  366. "--methods",
  367. default="threshold,masked_adaptive",
  368. help="去水印方式,逗号分隔",
  369. )
  370. p.add_argument(
  371. "--thresholds",
  372. default="155,165,170,175,180,none",
  373. help="threshold 法的阈值;none=预设默认",
  374. )
  375. p.add_argument(
  376. "--contrasts",
  377. default="false,true",
  378. help="是否 contrast,逗号分隔 false,true",
  379. )
  380. p.add_argument(
  381. "--upscales",
  382. default="64,96,128,192",
  383. help="最短边放大目标,逗号分隔整数",
  384. )
  385. p.add_argument(
  386. "--det-threshs",
  387. default="0.2,0.3,0.4,0.5",
  388. help="det_db_box_thresh,逗号分隔",
  389. )
  390. p.add_argument(
  391. "--text-black-target",
  392. type=int,
  393. default=88,
  394. help="contrast text_restore 目标黑度",
  395. )
  396. p.add_argument(
  397. "--no-save-images",
  398. action="store_true",
  399. help="不写出中间预处理 png(仅报告)",
  400. )
  401. p.add_argument(
  402. "--no-baseline",
  403. action="store_true",
  404. help="跳过「仅放大、不去水印」对照组",
  405. )
  406. p.add_argument(
  407. "--baseline-upscale",
  408. type=int,
  409. default=128,
  410. help="baseline 对照组的最短边放大",
  411. )
  412. return p
  413. def main(argv: Optional[Sequence[str]] = None) -> None:
  414. args = _build_arg_parser().parse_args(argv)
  415. inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
  416. if not inputs:
  417. raise SystemExit("未找到可扫描的图像")
  418. if args.output is not None:
  419. out_root = args.output
  420. elif args.input.is_file():
  421. out_root = args.input.parent / "sweep_out"
  422. else:
  423. out_root = args.input / "sweep_out"
  424. out_root.mkdir(parents=True, exist_ok=True)
  425. model_dir = args.model_dir or _default_model_dir()
  426. methods = [m.strip() for m in args.methods.split(",") if m.strip()]
  427. if args.quick:
  428. thresholds = [170, 175]
  429. upscales = [128, 192]
  430. det_threshs = [0.3, 0.5]
  431. contrasts = [False, True]
  432. else:
  433. thresholds = _parse_csv_ints(args.thresholds)
  434. upscales = [int(x) for x in args.upscales.split(",") if x.strip()]
  435. det_threshs = _parse_csv_floats(args.det_threshs)
  436. contrasts = _parse_csv_bools(args.contrasts)
  437. print(f"扫描 {len(inputs)} 张图 -> {out_root}")
  438. print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
  439. if args.target:
  440. print(f" target={args.target!r}")
  441. summary: List[Dict[str, Any]] = []
  442. for img_path in inputs:
  443. print(f"\n=== {img_path.name} ===")
  444. report = run_sweep(
  445. img_path,
  446. out_root,
  447. prefer_raw=not args.no_prefer_raw,
  448. target=args.target,
  449. model_dir=model_dir,
  450. methods=methods,
  451. thresholds=thresholds,
  452. contrasts=contrasts,
  453. upscales=upscales,
  454. det_threshs=det_threshs,
  455. text_black_target=args.text_black_target,
  456. save_images=not args.no_save_images,
  457. run_baseline=not args.no_baseline,
  458. baseline_upscale=args.baseline_upscale,
  459. )
  460. summary.append(
  461. {
  462. "input": report["input"],
  463. "hits": len(report["hits"]),
  464. "report": str(Path(report["output_dir"]) / "sweep_report.json"),
  465. }
  466. )
  467. index_path = out_root / "sweep_index.json"
  468. index_path.write_text(
  469. json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
  470. )
  471. print(f"\n全部完成,索引: {index_path}")
  472. for s in summary:
  473. print(f" {s['input']}: {s['hits']} hits -> {s['report']}")
if __name__ == "__main__":
    # When launched with no CLI arguments, rebuild sys.argv from a built-in
    # sample configuration so the script can be run as-is.
    if len(sys.argv) == 1:
        print("ℹ️ 未提供命令行参数,使用默认配置运行...")
        # NOTE(review): the input path below is machine-specific; it only
        # works on a host where that file exists.
        default_config = {
            "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty_raw.png",
            "output": "./output/彭_广东兴宁农村商业银行/cell219_sweep",
            "target": "ATM存折取款",
        }
        # "input" is the positional argument; every other key becomes a flag.
        sys.argv = [sys.argv[0], default_config["input"]]
        for key, value in default_config.items():
            if key == "input":
                continue
            # Config keys use underscores; CLI flags use dashes.
            flag = f"--{key.replace('_', '-')}"
            if isinstance(value, bool) and value:
                # True booleans become bare switches; False ones are omitted.
                sys.argv.append(flag)
            elif not isinstance(value, bool):
                sys.argv.extend([flag, str(value)])
    # main() returns None, so a normal run exits with status 0.
    sys.exit(main())