watermark_sweep.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. #!/usr/bin/env python3
  2. """
  3. 水印 mask 参数网格扫描:对比 light_on_white / diagonal_midtone / fused 三种策略及参数组合。
  4. 自动遍历多种参数组合,对每张输入图生成 mask overlay 并写入扫描报告,
  5. 帮助评估哪种参数组合能最好地覆盖水印区域。
  6. 用法:
  7. cd ocr_platform/ocr_tools/watermark_lab
  8. # 单张图
  9. python watermark_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png
  10. # 批量扫描目录
  11. python watermark_sweep.py ../gan_experiments_lab/test_images/input/
  12. # 快速模式(缩小参数网格)
  13. python watermark_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --quick
  14. # 指定输出目录
  15. python watermark_sweep.py input.png -o ./my_sweep_out
  16. # 跳过 mask overlay 图片(仅出 JSON 报告)
  17. python watermark_sweep.py input.png --no-save-images
  18. # 同时运行 LaMa 修复(需要指定权重路径)
  19. python watermark_sweep.py input.png --lama-ckpt /Users/zhch158/models/big-lama/models/best.ckpt
  20. """
from __future__ import annotations

import argparse
import csv
import json
import sys
import time
from itertools import product
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import cv2
import numpy as np
# Prepend parents[3] of this file (four directory levels up from the file
# itself) to sys.path, skipping the insert if it is already present.
# NOTE(review): presumably this is the repository root and is what makes the
# lazily imported project packages (e.g. `ocr_utils.watermark.algorithms`,
# `lab.gan_experiments_lab.lama_inpaint`) resolvable when run as a script —
# confirm. Keep it above the project imports that follow.
_repo_root = Path(__file__).resolve().parents[3]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
  34. from loguru import logger
  35. from fused_mask import build_fused_watermark_mask
  36. _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
  37. # ── 工具函数 ────────────────────────────────────────────────────
  38. def _render_mask_overlay(bgr: np.ndarray, mask: np.ndarray, color=(0, 0, 255)) -> np.ndarray:
  39. """将 mask 以指定颜色叠加到原图上。"""
  40. ov = bgr.copy()
  41. ov[mask] = (ov[mask] * 0.4 + np.array(color, dtype=np.float32) * 0.6).astype(np.uint8)
  42. return ov
  43. def _tag_from_config(cfg: Dict[str, Any]) -> str:
  44. """从参数配置生成可读标签。"""
  45. mode = cfg.get("mask_mode", "?")
  46. parts = [mode]
  47. if mode in ("light_on_white", "fused"):
  48. parts.append(f"l{cfg.get('light_gray_low', '?')}")
  49. parts.append(f"t{cfg.get('text_protect', '?')}")
  50. parts.append(f"{cfg.get('direction_filter', '?')}")
  51. parts.append(f"ca{cfg.get('min_component_area', '?')}")
  52. if mode == "fused":
  53. parts.append(f"mk{cfg.get('median_kernel', '?')}")
  54. return "_".join(parts)
  55. # ── 扫描核心 ────────────────────────────────────────────────────
  56. def _collect_images(path: Path) -> List[Path]:
  57. """收集输入图片。"""
  58. if path.is_file():
  59. if path.suffix.lower() not in _IMAGE_SUFFIXES:
  60. raise ValueError(f"不支持的图像格式: {path}")
  61. return [path]
  62. if not path.is_dir():
  63. raise FileNotFoundError(path)
  64. return sorted(
  65. p for p in path.iterdir() if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
  66. )
  67. def _build_param_grid(quick: bool = False) -> List[Dict[str, Any]]:
  68. """
  69. 构建参数网格。
  70. 三组正交维度:
  71. 1. mask_mode: light_on_white | diagonal_midtone | fused
  72. 2. light_on_white / fused 专属: light_gray_low, text_protect, direction_filter
  73. 3. fused 专属: median_kernel
  74. 4. 通用: min_component_area
  75. """
  76. if quick:
  77. modes = ["light_on_white", "fused"]
  78. light_gray_lows = [200, 236]
  79. text_protects = [110, 130]
  80. direction_filters = ["none", "hough"]
  81. min_areas = [80, 200]
  82. median_kernels = [21, 31]
  83. else:
  84. modes = ["light_on_white", "diagonal_midtone", "fused"]
  85. light_gray_lows = [200, 220, 236]
  86. text_protects = [110, 130]
  87. direction_filters = ["none", "hough"]
  88. min_areas = [80, 200, 500]
  89. median_kernels = [21, 31, 41]
  90. grid: List[Dict[str, Any]] = []
  91. for mode in modes:
  92. if mode == "diagonal_midtone":
  93. for ca in min_areas:
  94. grid.append({"mask_mode": "diagonal_midtone", "min_component_area": ca})
  95. elif mode == "light_on_white":
  96. for lgl, tp, df, ca in product(light_gray_lows, text_protects, direction_filters, min_areas):
  97. grid.append({
  98. "mask_mode": "light_on_white",
  99. "light_gray_low": lgl,
  100. "text_protect": tp,
  101. "direction_filter": df,
  102. "min_component_area": ca,
  103. })
  104. else: # fused
  105. for lgl, tp, df, ca, mk in product(
  106. light_gray_lows, text_protects, direction_filters, min_areas, median_kernels
  107. ):
  108. grid.append({
  109. "mask_mode": "fused",
  110. "light_gray_low": lgl,
  111. "text_protect": tp,
  112. "direction_filter": df,
  113. "min_component_area": ca,
  114. "median_kernel": mk,
  115. })
  116. return grid
  117. def _build_mask(bgr: np.ndarray, gray: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
  118. """根据参数配置构建水印 mask。"""
  119. mode = cfg["mask_mode"]
  120. if mode == "fused":
  121. return build_fused_watermark_mask(
  122. gray,
  123. bgr=bgr,
  124. a_enabled=True,
  125. a_light_gray_low=cfg["light_gray_low"],
  126. a_direction_filter=cfg["direction_filter"],
  127. a_text_protect_gray_max=cfg["text_protect"],
  128. a_min_component_area=cfg["min_component_area"],
  129. b_enabled=True,
  130. c_enabled=True,
  131. c_median_kernel=cfg["median_kernel"],
  132. min_component_area=cfg["min_component_area"],
  133. seal_protect=True,
  134. )
  135. # 单策略模式
  136. import ocr_utils.watermark.algorithms as _algo
  137. if mode == "light_on_white":
  138. return _algo.build_watermark_mask(
  139. gray,
  140. bgr=bgr,
  141. mask_mode="light_on_white",
  142. light_gray_low=cfg["light_gray_low"],
  143. direction_filter=cfg["direction_filter"],
  144. text_protect_gray_max=cfg["text_protect"],
  145. min_component_area=cfg["min_component_area"],
  146. seal_protect=True,
  147. )
  148. if mode == "diagonal_midtone":
  149. return _algo.build_watermark_mask(
  150. gray,
  151. bgr=bgr,
  152. mask_mode="diagonal_midtone",
  153. min_component_area=cfg["min_component_area"],
  154. )
  155. raise ValueError(f"未知 mask_mode: {mode}")
  156. def run_sweep(
  157. input_path: Path,
  158. out_dir: Path,
  159. *,
  160. quick: bool = False,
  161. save_images: bool = True,
  162. lama_ckpt: Optional[Path] = None,
  163. lama_repo: Optional[Path] = None,
  164. ) -> Dict[str, Any]:
  165. """对单张图执行参数网格扫描。"""
  166. bgr = cv2.imread(str(input_path))
  167. if bgr is None:
  168. raise RuntimeError(f"无法读取图像: {input_path}")
  169. gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
  170. stem = input_path.stem
  171. img_out = out_dir / stem
  172. img_out.mkdir(parents=True, exist_ok=True)
  173. grid = _build_param_grid(quick=quick)
  174. logger.info(f" {stem}: {len(grid)} 组参数组合")
  175. # LaMa(按需)
  176. lama = None
  177. if lama_ckpt:
  178. try:
  179. from lab.gan_experiments_lab.lama_inpaint import LamaInpainter
  180. lama = LamaInpainter(
  181. device="cpu",
  182. model_ckpt_path=str(lama_ckpt),
  183. lama_repo_path=str(lama_repo) if lama_repo else None,
  184. )
  185. lama.is_available # 触发检测
  186. logger.info(f" LaMa 后端: {lama._backend}")
  187. except Exception as e:
  188. logger.warning(f" 加载 LaMa 失败: {e}")
  189. results: List[Dict[str, Any]] = []
  190. for cfg in grid:
  191. tag = _tag_from_config(cfg)
  192. t0 = time.perf_counter()
  193. try:
  194. mask, debug = _build_mask(bgr, gray, cfg)
  195. except Exception as e:
  196. logger.warning(f" [{tag}] 构建 mask 失败: {e}")
  197. continue
  198. elapsed = time.perf_counter() - t0
  199. ratio = float(mask.sum() / gray.size)
  200. row: Dict[str, Any] = {
  201. "tag": tag,
  202. **cfg,
  203. "wm_mask_ratio": round(ratio, 6),
  204. "mask_build_time_s": round(elapsed, 3),
  205. }
  206. # 融合模式下记录各策略 ratio
  207. if cfg["mask_mode"] == "fused" and "strategies" in debug:
  208. for strategy_name, sinfo in debug["strategies"].items():
  209. if isinstance(sinfo, dict) and "ratio" in sinfo:
  210. row[f"ratio_{strategy_name}"] = round(sinfo["ratio"], 6)
  211. # mask overlay 图片
  212. if save_images:
  213. overlay = _render_mask_overlay(bgr, mask)
  214. overlay_path = img_out / f"{tag}_overlay.png"
  215. cv2.imwrite(str(overlay_path), overlay)
  216. row["overlay_path"] = str(overlay_path)
  217. # LaMa 修复(如启用)
  218. if lama and np.any(mask):
  219. try:
  220. t1 = time.perf_counter()
  221. result = lama.inpaint(bgr, mask)
  222. lama_time = time.perf_counter() - t1
  223. if result is not None and save_images:
  224. inpaint_path = img_out / f"{tag}_inpainted.png"
  225. cv2.imwrite(str(inpaint_path), result)
  226. row["inpainted_path"] = str(inpaint_path)
  227. row["lama_success"] = result is not None
  228. row["lama_time_s"] = round(lama_time, 2)
  229. except Exception as e:
  230. logger.warning(f" [{tag}] LaMa 修复失败: {e}")
  231. row["lama_success"] = False
  232. results.append(row)
  233. # ── 排序 ──
  234. # 1. 排除异常的 ratio(如 0 或接近全图)
  235. reasonable = [r for r in results if 0.005 < r["wm_mask_ratio"] < 0.80]
  236. # 2. 按 mask_ratio 接近中位数排序(太小的可能漏检,太大的可能过检)
  237. if reasonable:
  238. ratios = [r["wm_mask_ratio"] for r in reasonable]
  239. median_ratio = np.median(ratios)
  240. logger.info(f" ratio 中位数: {median_ratio:.4f}")
  241. def _score(r: Dict[str, Any]) -> float:
  242. return -abs(r["wm_mask_ratio"] - median_ratio)
  243. reasonable.sort(key=_score, reverse=True)
  244. top_n = min(5, len(reasonable))
  245. for i, r in enumerate(reasonable[:top_n]):
  246. logger.info(
  247. f" Top{i+1}: {r['tag']} ratio={r['wm_mask_ratio']:.4f}"
  248. )
  249. # ── 写入报告 ──
  250. report = {
  251. "input": str(input_path),
  252. "output_dir": str(img_out),
  253. "n_configs_tested": len(results),
  254. "n_reasonable": len(reasonable),
  255. "median_ratio": round(float(median_ratio), 6) if reasonable else None,
  256. "top_results": reasonable[:5] if reasonable else [],
  257. "all_results": results,
  258. }
  259. report_path = img_out / "sweep_report.json"
  260. report_path.write_text(
  261. json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
  262. )
  263. # CSV 摘要
  264. csv_path = img_out / "sweep_summary.csv"
  265. if results:
  266. csv_keys = [k for k in results[0].keys() if not k.endswith("_path")]
  267. lines = [",".join(csv_keys)]
  268. for r in results:
  269. vals = [str(r.get(k, "")) for k in csv_keys]
  270. lines.append(",".join(vals))
  271. csv_path.write_text("\n".join(lines), encoding="utf-8")
  272. logger.info(f" 报告: {report_path}")
  273. return report
  274. # ── CLI ──────────────────────────────────────────────────────────
  275. def _build_arg_parser() -> argparse.ArgumentParser:
  276. p = argparse.ArgumentParser(
  277. description="水印 mask 参数网格扫描(light_on_white / diagonal_midtone / fused)",
  278. )
  279. p.add_argument(
  280. "input",
  281. type=Path,
  282. help="单张图片路径或图片目录",
  283. )
  284. p.add_argument(
  285. "-o", "--output",
  286. type=Path,
  287. default=None,
  288. help="输出根目录,默认 input 同级 sweep_out/<stem>",
  289. )
  290. p.add_argument(
  291. "--quick",
  292. action="store_true",
  293. help="缩小参数网格",
  294. )
  295. p.add_argument(
  296. "--no-save-images",
  297. action="store_true",
  298. help="不写出 mask overlay 图片(仅 JSON 报告)",
  299. )
  300. p.add_argument(
  301. "--lama-ckpt",
  302. type=Path,
  303. default=None,
  304. help="LaMa 权重文件路径(启用则每组跑 LaMa 修复)",
  305. )
  306. p.add_argument(
  307. "--lama-repo",
  308. type=Path,
  309. default=None,
  310. help="LaMa 仓库路径(用于导入 saicinpainting)",
  311. )
  312. return p
  313. def main(argv: Optional[Sequence[str]] = None) -> None:
  314. args = _build_arg_parser().parse_args(argv)
  315. images = _collect_images(args.input)
  316. if not images:
  317. raise SystemExit("未找到可扫描的图像")
  318. if args.output is not None:
  319. out_root = args.output
  320. elif args.input.is_file():
  321. out_root = args.input.parent / "sweep_out"
  322. else:
  323. out_root = args.input / "sweep_out"
  324. out_root.mkdir(parents=True, exist_ok=True)
  325. logger.info(f"扫描 {len(images)} 张图 -> {out_root}")
  326. logger.info(f" quick={args.quick}")
  327. summary: List[Dict[str, Any]] = []
  328. for img_path in images:
  329. logger.info(f"\n=== {img_path.name} ===")
  330. report = run_sweep(
  331. img_path,
  332. out_root,
  333. quick=args.quick,
  334. save_images=not args.no_save_images,
  335. lama_ckpt=args.lama_ckpt,
  336. lama_repo=args.lama_repo,
  337. )
  338. summary.append({
  339. "input": report["input"],
  340. "n_tested": report["n_configs_tested"],
  341. "median_ratio": report["median_ratio"],
  342. "report": str(Path(report["output_dir"]) / "sweep_report.json"),
  343. })
  344. index_path = out_root / "sweep_index.json"
  345. index_path.write_text(
  346. json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
  347. )
  348. logger.info(f"\n全部完成。索引: {index_path}")
  349. for s in summary:
  350. logger.info(
  351. f" {Path(s['input']).name}: "
  352. f"{s['n_tested']} 组, median_ratio={s['median_ratio']} -> {s['report']}"
  353. )
  354. if __name__ == "__main__":
  355. main()