#!/usr/bin/env python3 """ 水印 mask 参数网格扫描:对比 light_on_white / diagonal_midtone / fused 三种策略及参数组合。 自动遍历多种参数组合,对每张输入图生成 mask overlay 并写入扫描报告, 帮助评估哪种参数组合能最好地覆盖水印区域。 用法: cd ocr_platform/ocr_tools/watermark_lab # 单张图 python watermark_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png # 批量扫描目录 python watermark_sweep.py ../gan_experiments_lab/test_images/input/ # 快速模式(缩小参数网格) python watermark_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --quick # 指定输出目录 python watermark_sweep.py input.png -o ./my_sweep_out # 跳过 mask overlay 图片(仅出 JSON 报告) python watermark_sweep.py input.png --no-save-images # 同时运行 LaMa 修复(需要指定权重路径) python watermark_sweep.py input.png --lama-ckpt /Users/zhch158/models/big-lama/models/best.ckpt """ from __future__ import annotations import argparse import json import sys import time from itertools import product from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple import cv2 import numpy as np _repo_root = Path(__file__).resolve().parents[3] if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) from loguru import logger from fused_mask import build_fused_watermark_mask _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"} # ── 工具函数 ──────────────────────────────────────────────────── def _render_mask_overlay(bgr: np.ndarray, mask: np.ndarray, color=(0, 0, 255)) -> np.ndarray: """将 mask 以指定颜色叠加到原图上。""" ov = bgr.copy() ov[mask] = (ov[mask] * 0.4 + np.array(color, dtype=np.float32) * 0.6).astype(np.uint8) return ov def _tag_from_config(cfg: Dict[str, Any]) -> str: """从参数配置生成可读标签。""" mode = cfg.get("mask_mode", "?") parts = [mode] if mode in ("light_on_white", "fused"): parts.append(f"l{cfg.get('light_gray_low', '?')}") parts.append(f"t{cfg.get('text_protect', '?')}") parts.append(f"{cfg.get('direction_filter', '?')}") parts.append(f"ca{cfg.get('min_component_area', '?')}") if mode == "fused": parts.append(f"mk{cfg.get('median_kernel', '?')}") return "_".join(parts) # ── 扫描核心 ──────────────────────────────────────────────────── def _collect_images(path: Path) -> List[Path]: """收集输入图片。""" if path.is_file(): if path.suffix.lower() not in _IMAGE_SUFFIXES: raise ValueError(f"不支持的图像格式: {path}") return [path] if not path.is_dir(): raise FileNotFoundError(path) return sorted( p for p in path.iterdir() if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES ) def _build_param_grid(quick: bool = False) -> List[Dict[str, Any]]: """ 构建参数网格。 三组正交维度: 1. mask_mode: light_on_white | diagonal_midtone | fused 2. light_on_white / fused 专属: light_gray_low, text_protect, direction_filter 3. fused 专属: median_kernel 4. 通用: min_component_area """ if quick: modes = ["light_on_white", "fused"] light_gray_lows = [200, 236] text_protects = [110, 130] direction_filters = ["none", "hough"] min_areas = [80, 200] median_kernels = [21, 31] else: modes = ["light_on_white", "diagonal_midtone", "fused"] light_gray_lows = [200, 220, 236] text_protects = [110, 130] direction_filters = ["none", "hough"] min_areas = [80, 200, 500] median_kernels = [21, 31, 41] grid: List[Dict[str, Any]] = [] for mode in modes: if mode == "diagonal_midtone": for ca in min_areas: grid.append({"mask_mode": "diagonal_midtone", "min_component_area": ca}) elif mode == "light_on_white": for lgl, tp, df, ca in product(light_gray_lows, text_protects, direction_filters, min_areas): grid.append({ "mask_mode": "light_on_white", "light_gray_low": lgl, "text_protect": tp, "direction_filter": df, "min_component_area": ca, }) else: # fused for lgl, tp, df, ca, mk in product( light_gray_lows, text_protects, direction_filters, min_areas, median_kernels ): grid.append({ "mask_mode": "fused", "light_gray_low": lgl, "text_protect": tp, "direction_filter": df, "min_component_area": ca, "median_kernel": mk, }) return grid def _build_mask(bgr: np.ndarray, gray: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]: """根据参数配置构建水印 mask。""" mode = cfg["mask_mode"] if mode == "fused": return build_fused_watermark_mask( gray, bgr=bgr, a_enabled=True, a_light_gray_low=cfg["light_gray_low"], a_direction_filter=cfg["direction_filter"], a_text_protect_gray_max=cfg["text_protect"], a_min_component_area=cfg["min_component_area"], b_enabled=True, c_enabled=True, c_median_kernel=cfg["median_kernel"], min_component_area=cfg["min_component_area"], seal_protect=True, ) # 单策略模式 import ocr_utils.watermark.algorithms as _algo if mode == "light_on_white": return _algo.build_watermark_mask( gray, bgr=bgr, mask_mode="light_on_white", light_gray_low=cfg["light_gray_low"], direction_filter=cfg["direction_filter"], text_protect_gray_max=cfg["text_protect"], min_component_area=cfg["min_component_area"], seal_protect=True, ) if mode == "diagonal_midtone": return _algo.build_watermark_mask( gray, bgr=bgr, mask_mode="diagonal_midtone", min_component_area=cfg["min_component_area"], ) raise ValueError(f"未知 mask_mode: {mode}") def run_sweep( input_path: Path, out_dir: Path, *, quick: bool = False, save_images: bool = True, lama_ckpt: Optional[Path] = None, lama_repo: Optional[Path] = None, ) -> Dict[str, Any]: """对单张图执行参数网格扫描。""" bgr = cv2.imread(str(input_path)) if bgr is None: raise RuntimeError(f"无法读取图像: {input_path}") gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) stem = input_path.stem img_out = out_dir / stem img_out.mkdir(parents=True, exist_ok=True) grid = _build_param_grid(quick=quick) logger.info(f" {stem}: {len(grid)} 组参数组合") # LaMa(按需) lama = None if lama_ckpt: try: from lab.gan_experiments_lab.lama_inpaint import LamaInpainter lama = LamaInpainter( device="cpu", model_ckpt_path=str(lama_ckpt), lama_repo_path=str(lama_repo) if lama_repo else None, ) lama.is_available # 触发检测 logger.info(f" LaMa 后端: {lama._backend}") except Exception as e: logger.warning(f" 加载 LaMa 失败: {e}") results: List[Dict[str, Any]] = [] for cfg in grid: tag = _tag_from_config(cfg) t0 = time.perf_counter() try: mask, debug = _build_mask(bgr, gray, cfg) except Exception as e: logger.warning(f" [{tag}] 构建 mask 失败: {e}") continue elapsed = time.perf_counter() - t0 ratio = float(mask.sum() / gray.size) row: Dict[str, Any] = { "tag": tag, **cfg, "wm_mask_ratio": round(ratio, 6), "mask_build_time_s": round(elapsed, 3), } # 融合模式下记录各策略 ratio if cfg["mask_mode"] == "fused" and "strategies" in debug: for strategy_name, sinfo in debug["strategies"].items(): if isinstance(sinfo, dict) and "ratio" in sinfo: row[f"ratio_{strategy_name}"] = round(sinfo["ratio"], 6) # mask overlay 图片 if save_images: overlay = _render_mask_overlay(bgr, mask) overlay_path = img_out / f"{tag}_overlay.png" cv2.imwrite(str(overlay_path), overlay) row["overlay_path"] = str(overlay_path) # LaMa 修复(如启用) if lama and np.any(mask): try: t1 = time.perf_counter() result = lama.inpaint(bgr, mask) lama_time = time.perf_counter() - t1 if result is not None and save_images: inpaint_path = img_out / f"{tag}_inpainted.png" cv2.imwrite(str(inpaint_path), result) row["inpainted_path"] = str(inpaint_path) row["lama_success"] = result is not None row["lama_time_s"] = round(lama_time, 2) except Exception as e: logger.warning(f" [{tag}] LaMa 修复失败: {e}") row["lama_success"] = False results.append(row) # ── 排序 ── # 1. 排除异常的 ratio(如 0 或接近全图) reasonable = [r for r in results if 0.005 < r["wm_mask_ratio"] < 0.80] # 2. 按 mask_ratio 接近中位数排序(太小的可能漏检,太大的可能过检) if reasonable: ratios = [r["wm_mask_ratio"] for r in reasonable] median_ratio = np.median(ratios) logger.info(f" ratio 中位数: {median_ratio:.4f}") def _score(r: Dict[str, Any]) -> float: return -abs(r["wm_mask_ratio"] - median_ratio) reasonable.sort(key=_score, reverse=True) top_n = min(5, len(reasonable)) for i, r in enumerate(reasonable[:top_n]): logger.info( f" Top{i+1}: {r['tag']} ratio={r['wm_mask_ratio']:.4f}" ) # ── 写入报告 ── report = { "input": str(input_path), "output_dir": str(img_out), "n_configs_tested": len(results), "n_reasonable": len(reasonable), "median_ratio": round(float(median_ratio), 6) if reasonable else None, "top_results": reasonable[:5] if reasonable else [], "all_results": results, } report_path = img_out / "sweep_report.json" report_path.write_text( json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8" ) # CSV 摘要 csv_path = img_out / "sweep_summary.csv" if results: csv_keys = [k for k in results[0].keys() if not k.endswith("_path")] lines = [",".join(csv_keys)] for r in results: vals = [str(r.get(k, "")) for k in csv_keys] lines.append(",".join(vals)) csv_path.write_text("\n".join(lines), encoding="utf-8") logger.info(f" 报告: {report_path}") return report # ── CLI ────────────────────────────────────────────────────────── def _build_arg_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description="水印 mask 参数网格扫描(light_on_white / diagonal_midtone / fused)", ) p.add_argument( "input", type=Path, help="单张图片路径或图片目录", ) p.add_argument( "-o", "--output", type=Path, default=None, help="输出根目录,默认 input 同级 sweep_out/", ) p.add_argument( "--quick", action="store_true", help="缩小参数网格", ) p.add_argument( "--no-save-images", action="store_true", help="不写出 mask overlay 图片(仅 JSON 报告)", ) p.add_argument( "--lama-ckpt", type=Path, default=None, help="LaMa 权重文件路径(启用则每组跑 LaMa 修复)", ) p.add_argument( "--lama-repo", type=Path, default=None, help="LaMa 仓库路径(用于导入 saicinpainting)", ) return p def main(argv: Optional[Sequence[str]] = None) -> None: args = _build_arg_parser().parse_args(argv) images = _collect_images(args.input) if not images: raise SystemExit("未找到可扫描的图像") if args.output is not None: out_root = args.output elif args.input.is_file(): out_root = args.input.parent / "sweep_out" else: out_root = args.input / "sweep_out" out_root.mkdir(parents=True, exist_ok=True) logger.info(f"扫描 {len(images)} 张图 -> {out_root}") logger.info(f" quick={args.quick}") summary: List[Dict[str, Any]] = [] for img_path in images: logger.info(f"\n=== {img_path.name} ===") report = run_sweep( img_path, out_root, quick=args.quick, save_images=not args.no_save_images, lama_ckpt=args.lama_ckpt, lama_repo=args.lama_repo, ) summary.append({ "input": report["input"], "n_tested": report["n_configs_tested"], "median_ratio": report["median_ratio"], "report": str(Path(report["output_dir"]) / "sweep_report.json"), }) index_path = out_root / "sweep_index.json" index_path.write_text( json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8" ) logger.info(f"\n全部完成。索引: {index_path}") for s in summary: logger.info( f" {Path(s['input']).name}: " f"{s['n_tested']} 组, median_ratio={s['median_ratio']} -> {s['report']}" ) if __name__ == "__main__": main()