contrast_sweep.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733
  1. #!/usr/bin/env python3
  2. """
  3. 方向1:对比度增强参数网格扫描。
  4. 不去水印,直接对原图做多种对比度增强,验证哪种参数组合能让水印
  5. 在视觉上"淡化"、正文保持清晰,从而使后续 OCR 不受水印干扰。
  6. 用法:
  7. cd ocr_platform/ocr_tools/watermark_lab
  8. # 单张图快速扫描
  9. python contrast_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --quick
  10. # 全量扫描(更多参数组合 + 生成增强对比图)
  11. python contrast_sweep.py ../gan_experiments_lab/test_images/input/
  12. # 同时跑 OCR 整页对比(det+rec,每方法 Top-1 组合)
  13. python contrast_sweep.py input.png --ocr --model-dir /path/to/models
  14. # 每方法 Top-3 组合跑 OCR 对比
  15. python contrast_sweep.py input.png --ocr --ocr-top-n 3
  16. 输出:
  17. output/<stem>/
  18. ├── contrast_report.json # 参数扫描结果汇总(含 OCR 对比结果)
  19. ├── contrast_summary.csv # CSV 表格
  20. ├── quad_compare.png # 四宫格对比图
  21. ├── text_restore_t60_bg248.png # 各组合增强结果图
  22. ├── clahe_cl3.0_t8.png
  23. ├── gamma_g0.5.png
  24. └── ocr/ # OCR 对比结果(--ocr 时生成)
  25. ├── <stem>_original_ocr_spans.png # 原始图 OCR 可视化
  26. ├── <stem>_original_ocr_spans.json # 原始图 OCR JSON
  27. ├── <stem>_<tag>_ocr_spans.png # 各增强组合 OCR 可视化
  28. ├── <stem>_<tag>_ocr_spans.json # 各增强组合 OCR JSON
  29. └── ocr_comparison.json # OCR 差异汇总报告
  30. """
  31. from __future__ import annotations
  32. import argparse
  33. import json
  34. import sys
  35. import time
  36. from itertools import product
  37. from pathlib import Path
  38. from typing import Any, Dict, List, Optional, Sequence, Tuple
  39. import cv2
  40. import numpy as np
# Put the 4th ancestor directory of this file (parents[0] is the file's own
# directory) at the front of sys.path so the project imports below resolve when
# the script is run directly. NOTE(review): presumably this is the repository
# root that contains `ocr_utils` / `ocr_tools` -- confirm the depth against the
# checkout layout.
_repo_root = Path(__file__).resolve().parents[3]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
from loguru import logger
from ocr_utils.watermark.contrast import enhance_document_contrast

# Accepted input-image suffixes; _collect_images compares them against the
# lower-cased file suffix, so matching is case-insensitive.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
  47. # ── 参数网格 ────────────────────────────────────────────────────
  48. def _build_param_grid(quick: bool = False) -> List[Dict[str, Any]]:
  49. """构建参数网格。
  50. 四个维度:
  51. 1. method: text_restore | clahe | gamma | linear
  52. 2. text_restore 专属: text_black_target + background_threshold
  53. 3. clahe 专属: clip_limit + tile_grid_size
  54. 4. gamma 专属: gamma
  55. """
  56. grid: List[Dict[str, Any]] = []
  57. # ── text_restore ──
  58. if quick:
  59. tbt = [40, 60, 85]
  60. bts = [240, 248]
  61. else:
  62. tbt = [40, 60, 80, 100, 120]
  63. bts = [235, 240, 248, 252]
  64. for target, bg_th in product(tbt, bts):
  65. grid.append({
  66. "method": "text_restore",
  67. "text_black_target": target,
  68. "background_threshold": bg_th,
  69. })
  70. # ── clahe ──
  71. if quick:
  72. cl = [1.0, 3.0, 5.0]
  73. ts = [8, 16]
  74. else:
  75. cl = [0.5, 1.0, 2.0, 3.0, 5.0, 8.0]
  76. ts = [4, 8, 16, 32]
  77. for clip, tile in product(cl, ts):
  78. grid.append({
  79. "method": "clahe",
  80. "clip_limit": clip,
  81. "tile_grid_size": tile,
  82. })
  83. # ── gamma ──
  84. if quick:
  85. gvs = [0.4, 0.55, 0.7, 0.85]
  86. else:
  87. gvs = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  88. for g in gvs:
  89. grid.append({"method": "gamma", "gamma": g})
  90. # ── linear ──
  91. if quick:
  92. bps = [2.0, 5.0]
  93. wps = [95.0, 98.0]
  94. else:
  95. bps = [1.0, 2.0, 5.0, 8.0]
  96. wps = [92.0, 95.0, 98.0]
  97. for bp, wp in product(bps, wps):
  98. grid.append({"method": "linear", "black_percentile": bp, "white_percentile": wp})
  99. return grid
  100. # ── 标签生成 ────────────────────────────────────────────────────
  101. def _tag_from_cfg(cfg: Dict[str, Any]) -> str:
  102. m = cfg["method"]
  103. if m == "text_restore":
  104. return f"{m}_t{cfg['text_black_target']}_bg{cfg['background_threshold']}"
  105. if m == "clahe":
  106. return f"{m}_cl{cfg['clip_limit']}_t{cfg['tile_grid_size']}"
  107. if m == "gamma":
  108. return f"{m}_g{cfg['gamma']}"
  109. if m == "linear":
  110. return f"{m}_b{cfg['black_percentile']}_w{cfg['white_percentile']}"
  111. return m
  112. # ── 工具函数 ────────────────────────────────────────────────────
  113. def _collect_images(path: Path) -> List[Path]:
  114. if path.is_file():
  115. if path.suffix.lower() not in _IMAGE_SUFFIXES:
  116. raise ValueError(f"不支持的图像格式: {path}")
  117. return [path]
  118. if not path.is_dir():
  119. raise FileNotFoundError(path)
  120. return sorted(
  121. p for p in path.iterdir() if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
  122. )
  123. def _compute_watermark_fade_score(
  124. original: np.ndarray, enhanced: np.ndarray, window: int = 31
  125. ) -> float:
  126. """
  127. 量化水印淡化程度。
  128. 原理:大核中值滤波估计背景,残差的方差越小 = 水印纹理越弱。
  129. """
  130. o_f = original.astype(np.float32)
  131. e_f = enhanced.astype(np.float32)
  132. k = max(3, window) | 1
  133. o_bg = cv2.medianBlur(o_f.astype(np.uint8), k).astype(np.float32)
  134. e_bg = cv2.medianBlur(e_f.astype(np.uint8), k).astype(np.float32)
  135. o_res = cv2.absdiff(o_f, o_bg)
  136. e_res = cv2.absdiff(e_f, e_bg)
  137. return float(1.0 - np.var(e_res) / max(np.var(o_res), 1.0))
  138. def _compute_text_sharpness_score(
  139. enhanced: np.ndarray, win: int = 3
  140. ) -> float:
  141. """局部标准差均值,越大 = 文字越清晰。"""
  142. e_f = enhanced.astype(np.float32)
  143. kernel = np.ones((win, win), np.float32) / (win * win)
  144. mean = cv2.filter2D(e_f, -1, kernel)
  145. sq_mean = cv2.filter2D(e_f * e_f, -1, kernel)
  146. var = np.maximum(sq_mean - mean * mean, 0)
  147. return float(np.sqrt(var).mean())
  148. # ── 对比图生成 ──────────────────────────────────────────────────
  149. def _make_quad_compare(
  150. original: np.ndarray,
  151. top_enhanced: List[Tuple[str, np.ndarray]],
  152. ) -> np.ndarray:
  153. """生成四宫格对比图:原图 | 最佳 text_restore | 最佳 clahe | 最佳 gamma。"""
  154. panels = [original]
  155. labels = ["Original"]
  156. for label, img in top_enhanced:
  157. panels.append(img)
  158. labels.append(label)
  159. # 全部转 BGR
  160. bgr_panels: List[np.ndarray] = []
  161. for p in panels:
  162. if p.ndim == 2:
  163. bgr_panels.append(cv2.cvtColor(p, cv2.COLOR_GRAY2BGR))
  164. else:
  165. bgr_panels.append(p)
  166. # 统一高度
  167. h = max(p.shape[0] for p in bgr_panels)
  168. w = max(p.shape[1] for p in bgr_panels)
  169. resized: List[np.ndarray] = []
  170. for p, label in zip(bgr_panels, labels):
  171. if p.shape[0] != h or p.shape[1] != w:
  172. p = cv2.resize(p, (w, h))
  173. bar = np.ones((40, w, 3), dtype=np.uint8) * 240
  174. cv2.putText(bar, label, (12, 28), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 2)
  175. resized.append(np.vstack([p, bar]))
  176. return np.hstack(resized)
  177. # ── OCR(整页对比)───────────────────────────────────────────────
  178. def _poly_to_bbox(poly: List[List[float]]) -> List[int]:
  179. """四点 polygon 转轴对齐 bbox [x0,y0,x1,y1]."""
  180. xs = [p[0] for p in poly]
  181. ys = [p[1] for p in poly]
  182. return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]
  183. def _ocr_full_page(engine: Any, img: np.ndarray) -> List[Dict[str, Any]]:
  184. """整页 OCR(det+rec),返回 spans 列表。
  185. 每个 span: {"poly": [[x,y],...], "bbox": [x0,y0,x1,y1], "text": "...", "confidence": 0.9}
  186. """
  187. res = engine.ocr(img, det=True, rec=True)
  188. items = res[0] if res and res[0] is not None else []
  189. spans: List[Dict[str, Any]] = []
  190. for item in items:
  191. if not item or len(item) < 2:
  192. continue
  193. box, rec_part = item[0], item[1]
  194. text = str(rec_part[0] if isinstance(rec_part, (list, tuple)) else rec_part)
  195. conf = float(rec_part[1]) if isinstance(rec_part, (list, tuple)) and len(rec_part) > 1 else 0.0
  196. poly = [[float(p[0]), float(p[1])] for p in box] if box else []
  197. spans.append({
  198. "poly": poly,
  199. "bbox": _poly_to_bbox(poly) if poly else [],
  200. "text": text.strip(),
  201. "confidence": round(conf, 4),
  202. })
  203. from ocr_tools.universal_doc_parser.core.layout_utils import SpanMatcher
  204. spans = SpanMatcher.remove_duplicate_spans(spans)
  205. return spans
  206. def _save_ocr_debug_for_sweep(
  207. image: np.ndarray,
  208. spans: List[Dict[str, Any]],
  209. out_dir: Path,
  210. tag: str,
  211. ) -> Dict[str, str]:
  212. """保存 OCR 可视化图和 JSON。复用 module_debug_viz.draw_ocr_spans_cv2。"""
  213. from ocr_utils.module_debug_viz import draw_ocr_spans_cv2
  214. ocr_dir = out_dir / "ocr"
  215. ocr_dir.mkdir(parents=True, exist_ok=True)
  216. # 可视化图
  217. vis = draw_ocr_spans_cv2(image, spans)
  218. img_path = ocr_dir / f"{tag}_ocr_spans.png"
  219. cv2.imwrite(str(img_path), vis)
  220. # JSON
  221. json_data = {
  222. "tag": tag,
  223. "count": len(spans),
  224. "spans": [
  225. {
  226. "bbox": s.get("bbox"),
  227. "poly": s.get("poly"),
  228. "text": s.get("text"),
  229. "confidence": s.get("confidence"),
  230. }
  231. for s in spans
  232. ],
  233. }
  234. json_path = ocr_dir / f"{tag}_ocr_spans.json"
  235. json_path.write_text(json.dumps(json_data, ensure_ascii=False, indent=2), encoding="utf-8")
  236. logger.info(f" OCR debug: {img_path}")
  237. return {"image": str(img_path), "json": str(json_path)}
  238. def _compare_ocr_results(
  239. orig_spans: List[Dict[str, Any]],
  240. enh_spans: List[Dict[str, Any]],
  241. iou_threshold: float = 0.5,
  242. ) -> Dict[str, Any]:
  243. """对比两组 OCR spans 的检测+识别差异。
  244. Returns:
  245. {
  246. "detection": { orig_count, enh_count, matched, new, missing },
  247. "recognition": { text_changed_count, char_diff_rate, details: [...] },
  248. "summary": "一句话摘要"
  249. }
  250. """
  251. def _bbox_iou(a: List[int], b: List[int]) -> float:
  252. if not a or not b:
  253. return 0.0
  254. xa = max(a[0], b[0])
  255. ya = max(a[1], b[1])
  256. xb = min(a[2], b[2])
  257. yb = min(a[3], b[3])
  258. inter = max(0, xb - xa) * max(0, yb - ya)
  259. area_a = max(0, a[2] - a[0]) * max(0, a[3] - a[1])
  260. area_b = max(0, b[2] - b[0]) * max(0, b[3] - b[1])
  261. union = area_a + area_b - inter
  262. return inter / union if union > 0 else 0.0
  263. # ── 检测对比 ──
  264. orig_boxes = [s.get("bbox", []) for s in orig_spans]
  265. enh_boxes = [s.get("bbox", []) for s in enh_spans]
  266. matched_orig_idxs: set = set()
  267. matched_enh_idxs: set = set()
  268. recognition_details: List[Dict[str, Any]] = []
  269. for i, ob in enumerate(orig_boxes):
  270. if not ob:
  271. continue
  272. best_j, best_iou = -1, 0.0
  273. for j, eb in enumerate(enh_boxes):
  274. if j in matched_enh_idxs or not eb:
  275. continue
  276. iou = _bbox_iou(ob, eb)
  277. if iou > best_iou:
  278. best_iou, best_j = iou, j
  279. if best_iou >= iou_threshold:
  280. matched_orig_idxs.add(i)
  281. matched_enh_idxs.add(best_j)
  282. orig_text = orig_spans[i].get("text", "")
  283. enh_text = enh_spans[best_j].get("text", "")
  284. orig_score = orig_spans[i].get("confidence", 0)
  285. enh_score = enh_spans[best_j].get("confidence", 0)
  286. rec_detail: Dict[str, Any] = {
  287. "orig_bbox": ob,
  288. "orig_text": orig_text,
  289. "orig_score": orig_score,
  290. "enh_text": enh_text,
  291. "enh_score": enh_score,
  292. "iou": round(best_iou, 4),
  293. }
  294. if orig_text != enh_text:
  295. rec_detail["text_changed"] = True
  296. else:
  297. rec_detail["text_changed"] = False
  298. recognition_details.append(rec_detail)
  299. new_boxes = len(enh_boxes) - len(matched_enh_idxs)
  300. missing_boxes = len(orig_boxes) - len(matched_orig_idxs)
  301. # 字符差异率
  302. orig_concat = "".join(s.get("text", "") for s in orig_spans)
  303. enh_concat = "".join(s.get("text", "") for s in enh_spans)
  304. total_chars = max(len(orig_concat), len(enh_concat), 1)
  305. char_diff = sum(1 for a, b in zip(orig_concat, enh_concat) if a != b) + abs(
  306. len(orig_concat) - len(enh_concat)
  307. )
  308. char_diff_rate = round(char_diff / total_chars, 4)
  309. detection = {
  310. "orig_count": len(orig_boxes),
  311. "enh_count": len(enh_boxes),
  312. "matched": len(matched_orig_idxs),
  313. "new": new_boxes,
  314. "missing": missing_boxes,
  315. }
  316. recognition = {
  317. "text_changed_count": len(recognition_details),
  318. "char_diff_rate": char_diff_rate,
  319. "details": recognition_details[:50], # 最多保存50条差异明细
  320. }
  321. summary = (
  322. f"检测: {detection['orig_count']}→{detection['enh_count']} (匹配{detection['matched']}, "
  323. f"新增{detection['new']}, 遗失{detection['missing']}); "
  324. f"识别: 文字变化{recognition['text_changed_count']}处, 字符差异率{char_diff_rate:.2%}"
  325. )
  326. return {"detection": detection, "recognition": recognition, "summary": summary}
  327. def _load_paddle_engine(model_dir: Path, det_thresh: float = 0.3):
  328. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  329. det = model_dir / "ch_PP-OCRv5_det_infer.pth"
  330. rec = model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth"
  331. return PytorchPaddleOCR(
  332. lang="ch",
  333. det_model_path=str(det) if det.exists() else None,
  334. rec_model_path=str(rec) if rec.exists() else None,
  335. det_db_box_thresh=det_thresh,
  336. )
  337. # ── 扫描核心 ────────────────────────────────────────────────────
  338. def run_sweep(
  339. input_path: Path,
  340. out_dir: Path,
  341. *,
  342. quick: bool = False,
  343. save_images: bool = True,
  344. ocr_enabled: bool = False,
  345. model_dir: Optional[Path] = None,
  346. ocr_top_n: int = 1,
  347. ocr_all: bool = False,
  348. ) -> Dict[str, Any]:
  349. bgr = cv2.imread(str(input_path))
  350. if bgr is None:
  351. raise RuntimeError(f"无法读取图像: {input_path}")
  352. gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
  353. stem = input_path.stem
  354. img_out = out_dir / stem
  355. img_out.mkdir(parents=True, exist_ok=True)
  356. grid = _build_param_grid(quick=quick)
  357. logger.info(f" {stem}: {len(grid)} 组参数组合")
  358. engine = None
  359. baseline_spans: List[Dict[str, Any]] = []
  360. if ocr_enabled:
  361. try:
  362. md = model_dir or Path(
  363. "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
  364. "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
  365. )
  366. engine = _load_paddle_engine(md)
  367. logger.info(" OCR 引擎已加载")
  368. # 基线 OCR(原始灰度图)
  369. baseline_spans = _ocr_full_page(engine, gray)
  370. logger.info(f" 基线 OCR: {len(baseline_spans)} 个文本块")
  371. _save_ocr_debug_for_sweep(bgr, baseline_spans, img_out, f"{stem}_original")
  372. except Exception as e:
  373. logger.warning(f" OCR 引擎加载失败: {e}")
  374. results: List[Dict[str, Any]] = []
  375. # 按 method 分组, 便于后面取各类别最优
  376. method_groups: Dict[str, List[Dict[str, Any]]] = {}
  377. for cfg in grid:
  378. tag = _tag_from_cfg(cfg)
  379. t0 = time.perf_counter()
  380. try:
  381. enhanced = enhance_document_contrast(gray, **cfg)
  382. except Exception as e:
  383. logger.warning(f" [{tag}] 增强失败: {e}")
  384. continue
  385. elapsed = time.perf_counter() - t0
  386. fade = _compute_watermark_fade_score(gray, enhanced)
  387. sharpness = _compute_text_sharpness_score(enhanced)
  388. # 综合分:fade(水印淡化) 和 sharpness(文字清晰度) 同等权重
  389. combined = round(fade * 0.5 + sharpness / max(sharpness, 0.01) * 0.5, 4)
  390. row: Dict[str, Any] = {
  391. "tag": tag,
  392. **cfg,
  393. "fade_score": round(fade, 6),
  394. "sharpness_score": round(sharpness, 4),
  395. "combined_score": round(combined, 4),
  396. "time_ms": round(elapsed * 1000, 1),
  397. }
  398. if save_images:
  399. out_path = img_out / f"{tag}.png"
  400. cv2.imwrite(str(out_path), enhanced)
  401. row["image_path"] = str(out_path)
  402. results.append(row)
  403. method = cfg["method"]
  404. method_groups.setdefault(method, []).append(row)
  405. # ── 排序 ──
  406. results.sort(key=lambda r: -r["combined_score"])
  407. for mname, entries in method_groups.items():
  408. entries.sort(key=lambda r: -r["combined_score"])
  409. # Top 各方法最优
  410. tops: List[Tuple[str, str, float]] = []
  411. for mname, entries in method_groups.items():
  412. if entries:
  413. top = entries[0]
  414. tops.append((mname, top["tag"], top["combined_score"]))
  415. logger.info(f" [{mname}] Top: {top['tag']} combined={top['combined_score']:.4f}")
  416. logger.info(f" 全局 Top1: {results[0]['tag']} combined={results[0]['combined_score']:.4f}")
  417. # ── 阶段二:OCR 对比(整页)─────────────────────────────────
  418. ocr_comparisons: List[Dict[str, Any]] = []
  419. if engine and baseline_spans:
  420. # 选择要跑 OCR 的组合列表
  421. if ocr_all:
  422. ocr_candidates = results
  423. else:
  424. ocr_candidates: List[Dict[str, Any]] = []
  425. for mname, entries in method_groups.items():
  426. for r in entries[:ocr_top_n]:
  427. if r not in ocr_candidates:
  428. ocr_candidates.append(r)
  429. logger.info(f" OCR 对比 {len(ocr_candidates)} 个组合(每方法 Top-{ocr_top_n})")
  430. for r in ocr_candidates:
  431. tag = r["tag"]
  432. enhanced_path = r.get("image_path")
  433. if enhanced_path:
  434. enhanced_bgr = cv2.imread(enhanced_path)
  435. if enhanced_bgr is None:
  436. logger.warning(f" [{tag}] 无法读取增强图")
  437. continue
  438. enhanced_gray = cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2GRAY)
  439. else:
  440. # 从 raws 重新生成
  441. enhanced_gray = enhance_document_contrast(gray, **{
  442. k: v for k, v in r.items()
  443. if k not in ("tag", "image_path", "fade_score", "sharpness_score", "combined_score", "time_ms")
  444. })
  445. enhanced_bgr = cv2.cvtColor(enhanced_gray, cv2.COLOR_GRAY2BGR)
  446. try:
  447. enh_spans = _ocr_full_page(engine, enhanced_gray)
  448. except Exception as e:
  449. logger.warning(f" [{tag}] OCR 失败: {e}")
  450. continue
  451. _save_ocr_debug_for_sweep(enhanced_bgr, enh_spans, img_out, f"{stem}_{tag}")
  452. cmp = _compare_ocr_results(baseline_spans, enh_spans)
  453. cmp["tag"] = tag
  454. cmp["method"] = r["method"]
  455. r["ocr_comparison"] = cmp
  456. ocr_comparisons.append(cmp)
  457. for k, v in cmp["detection"].items():
  458. r[f"ocr_det_{k}"] = v
  459. for k, v in cmp["recognition"].items():
  460. if not isinstance(v, list):
  461. r[f"ocr_rec_{k}"] = v
  462. logger.info(f" [{tag}] {cmp['summary']}")
  463. # ── 四宫格对比图 ──
  464. if save_images and len(tops) >= 3:
  465. selected_labels = []
  466. selected_imgs = []
  467. seen_methods = set()
  468. for r in results:
  469. m = r["method"]
  470. if m in seen_methods:
  471. continue
  472. seen_methods.add(m)
  473. path = r.get("image_path")
  474. if path:
  475. img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  476. if img is not None:
  477. selected_labels.append(r["tag"])
  478. selected_imgs.append(img)
  479. if len(selected_imgs) >= 3:
  480. break
  481. if selected_imgs:
  482. quad = _make_quad_compare(gray, list(zip(selected_labels, selected_imgs)))
  483. quad_path = img_out / "quad_compare.png"
  484. cv2.imwrite(str(quad_path), quad)
  485. logger.info(f" 四宫格对比图: {quad_path}")
  486. # ── 报告 ──
  487. report: Dict[str, Any] = {
  488. "input": str(input_path),
  489. "output_dir": str(img_out),
  490. "n_configs_tested": len(results),
  491. "top_overall": results[0] if results else None,
  492. "top_by_method": {
  493. m: e[0] for m, e in method_groups.items() if e
  494. },
  495. }
  496. if engine:
  497. baseline_text = "".join(s.get("text", "") for s in baseline_spans)
  498. report["baseline_ocr"] = {
  499. "span_count": len(baseline_spans),
  500. "full_text": baseline_text,
  501. }
  502. report["ocr_comparisons"] = {
  503. "n_compared": len(ocr_comparisons),
  504. "results": ocr_comparisons,
  505. }
  506. report_path = img_out / "contrast_report.json"
  507. report_path.write_text(
  508. json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
  509. )
  510. # 单独的 OCR 对比汇总报告(含完整检测+识别对比数据)
  511. if engine:
  512. ocr_summary_path = img_out / "ocr" / "ocr_comparison.json"
  513. ocr_summary_path.parent.mkdir(parents=True, exist_ok=True)
  514. ocr_summary = {
  515. "input": str(input_path),
  516. "baseline_spans": len(baseline_spans),
  517. "compared": ocr_comparisons,
  518. }
  519. ocr_summary_path.write_text(
  520. json.dumps(ocr_summary, ensure_ascii=False, indent=2), encoding="utf-8"
  521. )
  522. # CSV
  523. if results:
  524. csv_keys = [k for k in results[0].keys() if not k.endswith("_path") and k != "ocr_comparison"]
  525. lines = [",".join(csv_keys)]
  526. for r in results:
  527. lines.append(",".join(str(r.get(k, "")) for k in csv_keys))
  528. (img_out / "contrast_summary.csv").write_text("\n".join(lines), encoding="utf-8")
  529. logger.info(f" 报告: {report_path}")
  530. return report
  531. # ── CLI ──────────────────────────────────────────────────────────
  532. def _build_arg_parser() -> argparse.ArgumentParser:
  533. p = argparse.ArgumentParser(
  534. description="对比度增强参数网格扫描(不去水印,直接增强前后对比)",
  535. )
  536. p.add_argument("input", type=Path, help="单张图片路径或图片目录")
  537. p.add_argument("-o", "--output", type=Path, default=None,
  538. help="输出根目录,默认 input 同级 contrast_out/<stem>")
  539. p.add_argument("--quick", action="store_true",
  540. help="缩小参数网格")
  541. p.add_argument("--no-save-images", action="store_true",
  542. help="不写出增强结果图")
  543. p.add_argument("--ocr", action="store_true",
  544. help="启用整页 OCR 对比(det+rec):基线 OCR + Top-N 增强图 OCR,输出 spans 可视化和 JSON")
  545. p.add_argument("--ocr-top-n", type=int, default=1,
  546. help="OCR 对比时每方法取 Top-N 组合(默认 1)")
  547. p.add_argument("--ocr-all", action="store_true",
  548. help="对所有参数组合跑 OCR 对比(覆盖 --ocr-top-n)")
  549. p.add_argument("--model-dir", type=Path, default=None,
  550. help="PaddleOCR 模型目录")
  551. return p
  552. def main(argv: Optional[Sequence[str]] = None) -> None:
  553. args = _build_arg_parser().parse_args(argv)
  554. images = _collect_images(args.input)
  555. if not images:
  556. raise SystemExit("未找到可扫描的图像")
  557. if args.output is not None:
  558. out_root = args.output
  559. elif args.input.is_file():
  560. out_root = args.input.parent / "contrast_out"
  561. else:
  562. out_root = args.input / "contrast_out"
  563. out_root.mkdir(parents=True, exist_ok=True)
  564. logger.info(f"扫描 {len(images)} 张图 -> {out_root}")
  565. logger.info(f" quick={args.quick} ocr={args.ocr} ocr_top_n={args.ocr_top_n} ocr_all={args.ocr_all}")
  566. summary: List[Dict[str, Any]] = []
  567. for img_path in images:
  568. logger.info(f"\n=== {img_path.name} ===")
  569. report = run_sweep(
  570. img_path,
  571. out_root,
  572. quick=args.quick,
  573. save_images=not args.no_save_images,
  574. ocr_enabled=args.ocr,
  575. model_dir=args.model_dir,
  576. ocr_top_n=args.ocr_top_n,
  577. ocr_all=args.ocr_all,
  578. )
  579. to = report.get("top_overall")
  580. summary.append({
  581. "input": report["input"],
  582. "n_tested": report["n_configs_tested"],
  583. "top_tag": to["tag"] if to else None,
  584. "top_combined": to["combined_score"] if to else None,
  585. "report": str(Path(report["output_dir"]) / "contrast_report.json"),
  586. })
  587. index_path = out_root / "contrast_index.json"
  588. index_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
  589. logger.info(f"\n全部完成。索引: {index_path}")
  590. for s in summary:
  591. logger.info(f" {Path(s['input']).name}: Top={s['top_tag']} combined={s['top_combined']}")
  592. if __name__ == "__main__":
  593. # python contrast_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --ocr --ocr-top-n 1
  594. if len(sys.argv) == 1:
  595. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  596. default_config = {
  597. "input": "../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png",
  598. "output": "./output/彭_广东兴宁农村商业银行_page_002/contrast_sweep",
  599. "ocr": True,
  600. "ocr_top_n": 3,
  601. "quick": True,
  602. }
  603. sys.argv = [sys.argv[0], default_config["input"]]
  604. for key, value in default_config.items():
  605. if key == "input":
  606. continue
  607. flag = f"--{key.replace('_', '-')}"
  608. if isinstance(value, bool) and value:
  609. sys.argv.append(flag)
  610. elif not isinstance(value, bool):
  611. sys.argv.extend([flag, str(value)])
  612. sys.exit(main())