| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430 |
- #!/usr/bin/env python3
- """
- 水印 mask 参数网格扫描:对比 light_on_white / diagonal_midtone / fused 三种策略及参数组合。
- 自动遍历多种参数组合,对每张输入图生成 mask overlay 并写入扫描报告,
- 帮助评估哪种参数组合能最好地覆盖水印区域。
- 用法:
- cd ocr_platform/ocr_tools/watermark_lab
- # 单张图
- python watermark_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png
- # 批量扫描目录
- python watermark_sweep.py ../gan_experiments_lab/test_images/input/
- # 快速模式(缩小参数网格)
- python watermark_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --quick
- # 指定输出目录
- python watermark_sweep.py input.png -o ./my_sweep_out
- # 跳过 mask overlay 图片(仅出 JSON 报告)
- python watermark_sweep.py input.png --no-save-images
- # 同时运行 LaMa 修复(需要指定权重路径)
- python watermark_sweep.py input.png --lama-ckpt /Users/zhch158/models/big-lama/models/best.ckpt
- """
- from __future__ import annotations
- import argparse
- import json
- import sys
- import time
- from itertools import product
- from pathlib import Path
- from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
- import cv2
- import numpy as np
# Path.parents[3] is the directory three levels above the one containing this
# file. It is put at the front of sys.path (only if not already present).
# NOTE(review): presumably this is the repository root and exists so that the
# absolute project imports used further down resolve when the file is run as a
# script — confirm against the repo layout.
_repo_root = Path(__file__).resolve().parents[3]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
- from loguru import logger
- from fused_mask import build_fused_watermark_mask
# Lower-case file suffixes accepted as input images; _collect_images compares
# against Path.suffix.lower(), so matching is case-insensitive.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
- # ── 工具函数 ────────────────────────────────────────────────────
def _render_mask_overlay(bgr: np.ndarray, mask: np.ndarray, color=(0, 0, 255)) -> np.ndarray:
    """Return a copy of ``bgr`` with ``color`` blended over the masked pixels.

    Masked pixels become ``0.4 * original + 0.6 * color`` (truncated to uint8);
    all other pixels are left unchanged. The input image is not modified.

    Args:
        bgr: Image to annotate. The caller in this file passes the result of
            ``cv2.imread``.
        mask: Pixel selector with the same height/width as ``bgr``. Any
            non-zero / True entry selects the pixel, so boolean masks and
            uint8 masks (0/1 or 0/255) are treated the same way.
        color: Per-channel overlay colour, in the same channel order as ``bgr``.

    Returns:
        A new uint8 array with the overlay applied.
    """
    # Coerce to bool first: indexing with an integer mask would be interpreted
    # as fancy (row) indexing instead of a per-pixel selection.
    selected = np.asarray(mask).astype(bool)
    ov = bgr.copy()
    tint = np.array(color, dtype=np.float32) * 0.6
    ov[selected] = (ov[selected] * 0.4 + tint).astype(np.uint8)
    return ov
def _tag_from_config(cfg: Dict[str, Any]) -> str:
    """Build a readable, underscore-joined tag describing one parameter set.

    The tag starts with the mask mode. For ``light_on_white`` and ``fused`` it
    is followed by ``l<light_gray_low>``, ``t<text_protect>`` and the direction
    filter name; ``ca<min_component_area>`` follows, and ``fused`` additionally
    gets ``mk<median_kernel>``. Missing keys are rendered as ``?``.
    """
    mode = cfg.get("mask_mode", "?")
    pieces = [mode]

    if mode in ("light_on_white", "fused"):
        pieces += [
            f"l{cfg.get('light_gray_low', '?')}",
            f"t{cfg.get('text_protect', '?')}",
            f"{cfg.get('direction_filter', '?')}",
        ]

    pieces.append(f"ca{cfg.get('min_component_area', '?')}")

    if mode == "fused":
        pieces.append(f"mk{cfg.get('median_kernel', '?')}")

    return "_".join(pieces)
- # ── 扫描核心 ────────────────────────────────────────────────────
def _collect_images(path: Path) -> List[Path]:
    """Resolve ``path`` to the list of input images.

    Args:
        path: Either a single image file or a directory of images.

    Returns:
        ``[path]`` for a supported image file, or the sorted supported image
        files directly inside the directory (not recursive; may be empty).

    Raises:
        ValueError: ``path`` is a file whose suffix is not a supported format.
        FileNotFoundError: ``path`` is neither a file nor a directory.
    """

    def _has_image_suffix(candidate: Path) -> bool:
        return candidate.suffix.lower() in _IMAGE_SUFFIXES

    if path.is_file():
        if _has_image_suffix(path):
            return [path]
        raise ValueError(f"不支持的图像格式: {path}")

    if path.is_dir():
        return sorted(
            child
            for child in path.iterdir()
            if child.is_file() and _has_image_suffix(child)
        )

    raise FileNotFoundError(path)
def _build_param_grid(quick: bool = False) -> List[Dict[str, Any]]:
    """Enumerate every parameter combination to be swept.

    Four groups of dimensions are combined:
      1. ``mask_mode``: light_on_white | diagonal_midtone | fused
      2. light_on_white / fused only: ``light_gray_low``, ``text_protect``,
         ``direction_filter``
      3. fused only: ``median_kernel``
      4. all modes: ``min_component_area``

    Args:
        quick: Use the reduced grid (no ``diagonal_midtone`` mode and fewer
            values per dimension).

    Returns:
        One config dict per combination, grouped by mode in the order the
        modes are listed, each group in ``itertools.product`` order.
    """
    if quick:
        modes = ("light_on_white", "fused")
        gray_lows = (200, 236)
        areas = (80, 200)
        kernels = (21, 31)
    else:
        modes = ("light_on_white", "diagonal_midtone", "fused")
        gray_lows = (200, 220, 236)
        areas = (80, 200, 500)
        kernels = (21, 31, 41)
    # These two dimensions are the same in quick and full mode.
    protects = (110, 130)
    filters = ("none", "hough")

    configs: List[Dict[str, Any]] = []
    for mode in modes:
        if mode == "diagonal_midtone":
            # This mode is only swept over the component-area threshold.
            configs.extend(
                {"mask_mode": "diagonal_midtone", "min_component_area": area}
                for area in areas
            )
            continue

        axes = [gray_lows, protects, filters, areas]
        if mode != "light_on_white":
            # fused adds the median kernel as the innermost dimension.
            axes.append(kernels)

        for combo in product(*axes):
            entry: Dict[str, Any] = {
                "mask_mode": mode,
                "light_gray_low": combo[0],
                "text_protect": combo[1],
                "direction_filter": combo[2],
                "min_component_area": combo[3],
            }
            if len(combo) == 5:
                entry["median_kernel"] = combo[4]
            configs.append(entry)
    return configs
def _build_mask(bgr: np.ndarray, gray: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Build the watermark mask for one parameter configuration.

    ``fused`` is delegated to ``build_fused_watermark_mask``; the two
    single-strategy modes are delegated to
    ``ocr_utils.watermark.algorithms.build_watermark_mask``.

    Args:
        bgr: Colour image, forwarded to the builder as ``bgr=``.
        gray: Grayscale version of the same image (first positional argument
            of the builders).
        cfg: One entry of the parameter grid; ``mask_mode`` selects the
            builder and the remaining keys supply its arguments.

    Returns:
        Whatever the selected builder returns; the caller unpacks it as
        ``(mask, debug)``.

    Raises:
        ValueError: ``mask_mode`` is not one of the three known modes.
    """
    mode = cfg["mask_mode"]

    if mode == "fused":
        fused_kwargs = {
            "bgr": bgr,
            "a_enabled": True,
            "a_light_gray_low": cfg["light_gray_low"],
            "a_direction_filter": cfg["direction_filter"],
            "a_text_protect_gray_max": cfg["text_protect"],
            "a_min_component_area": cfg["min_component_area"],
            "b_enabled": True,
            "c_enabled": True,
            "c_median_kernel": cfg["median_kernel"],
            "min_component_area": cfg["min_component_area"],
            "seal_protect": True,
        }
        return build_fused_watermark_mask(gray, **fused_kwargs)

    # Single-strategy modes: the algorithms module is only imported when needed.
    import ocr_utils.watermark.algorithms as _algo

    if mode == "light_on_white":
        single_kwargs = {
            "light_gray_low": cfg["light_gray_low"],
            "direction_filter": cfg["direction_filter"],
            "text_protect_gray_max": cfg["text_protect"],
            "min_component_area": cfg["min_component_area"],
            "seal_protect": True,
        }
    elif mode == "diagonal_midtone":
        single_kwargs = {"min_component_area": cfg["min_component_area"]}
    else:
        raise ValueError(f"未知 mask_mode: {mode}")

    return _algo.build_watermark_mask(gray, bgr=bgr, mask_mode=mode, **single_kwargs)
def _load_lama_inpainter(lama_ckpt: Path, lama_repo: Optional[Path]) -> Optional[Any]:
    """Create the LaMa inpainter, or return None (after a warning) on any failure."""
    try:
        from lab.gan_experiments_lab.lama_inpaint import LamaInpainter

        inpainter = LamaInpainter(
            device="cpu",
            model_ckpt_path=str(lama_ckpt),
            lama_repo_path=str(lama_repo) if lama_repo else None,
        )
        # Per the original note, touching this attribute triggers the
        # availability/backend detection.
        inpainter.is_available
        logger.info(f" LaMa 后端: {inpainter._backend}")
    except Exception as e:
        # Best effort: the sweep still runs without inpainting. Returning None
        # (instead of a half-initialised object) avoids retrying a broken
        # inpainter for every parameter combination.
        logger.warning(f" 加载 LaMa 失败: {e}")
        return None
    return inpainter


def _write_csv_summary(csv_path: Path, rows: List[Dict[str, Any]]) -> None:
    """Write one CSV line per result row, using the union of all non-path keys."""
    import csv

    # Rows are heterogeneous (e.g. only fused rows carry median_kernel and
    # ratio_* keys), so the header is the first-seen-ordered union of keys.
    columns = list(
        dict.fromkeys(
            key for row in rows for key in row if not key.endswith("_path")
        )
    )
    with csv_path.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(
            fh,
            fieldnames=columns,
            restval="",
            extrasaction="ignore",  # *_path keys are deliberately left out
            lineterminator="\n",
        )
        writer.writeheader()
        writer.writerows(rows)


def run_sweep(
    input_path: Path,
    out_dir: Path,
    *,
    quick: bool = False,
    save_images: bool = True,
    lama_ckpt: Optional[Path] = None,
    lama_repo: Optional[Path] = None,
) -> Dict[str, Any]:
    """Run the parameter-grid sweep on a single image and write its reports.

    For every grid entry a mask is built, its coverage ratio and build time
    are recorded, and optionally an overlay image and a LaMa inpainting result
    are written. Results land in ``out_dir / input_path.stem`` as
    ``sweep_report.json`` and ``sweep_summary.csv``.

    Args:
        input_path: Image file to sweep.
        out_dir: Root output directory; a sub-directory named after the image
            stem is created inside it.
        quick: Use the reduced parameter grid.
        save_images: Write overlay (and inpainted) PNG files.
        lama_ckpt: LaMa checkpoint path; when given, inpainting is attempted
            for every non-empty mask.
        lama_repo: Optional LaMa repository path forwarded to the inpainter.

    Returns:
        The report dict that was serialised to ``sweep_report.json``.

    Raises:
        RuntimeError: The image could not be read.
    """
    bgr = cv2.imread(str(input_path))
    if bgr is None:
        raise RuntimeError(f"无法读取图像: {input_path}")
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

    stem = input_path.stem
    img_out = out_dir / stem
    img_out.mkdir(parents=True, exist_ok=True)

    grid = _build_param_grid(quick=quick)
    logger.info(f" {stem}: {len(grid)} 组参数组合")

    # LaMa is optional and only loaded when a checkpoint was supplied.
    lama = _load_lama_inpainter(lama_ckpt, lama_repo) if lama_ckpt else None

    results: List[Dict[str, Any]] = []
    for cfg in grid:
        tag = _tag_from_config(cfg)
        t0 = time.perf_counter()
        try:
            mask, debug = _build_mask(bgr, gray, cfg)
        except Exception as e:
            # One failing combination must not abort the whole sweep.
            logger.warning(f" [{tag}] 构建 mask 失败: {e}")
            continue
        elapsed = time.perf_counter() - t0

        # Count selected pixels rather than summing values, so a 0/255 mask
        # yields the same ratio as a boolean one.
        ratio = float(np.count_nonzero(mask) / gray.size)
        row: Dict[str, Any] = {
            "tag": tag,
            **cfg,
            "wm_mask_ratio": round(ratio, 6),
            "mask_build_time_s": round(elapsed, 3),
        }

        # In fused mode, also record the per-strategy ratios from the debug info.
        if cfg["mask_mode"] == "fused" and isinstance(debug, dict) and "strategies" in debug:
            for strategy_name, sinfo in debug["strategies"].items():
                if isinstance(sinfo, dict) and "ratio" in sinfo:
                    # float() keeps numpy scalars JSON-serialisable.
                    row[f"ratio_{strategy_name}"] = round(float(sinfo["ratio"]), 6)

        if save_images:
            overlay = _render_mask_overlay(bgr, mask)
            overlay_path = img_out / f"{tag}_overlay.png"
            cv2.imwrite(str(overlay_path), overlay)
            row["overlay_path"] = str(overlay_path)

        if lama and np.any(mask):
            try:
                t1 = time.perf_counter()
                result = lama.inpaint(bgr, mask)
                lama_time = time.perf_counter() - t1
                if result is not None and save_images:
                    inpaint_path = img_out / f"{tag}_inpainted.png"
                    cv2.imwrite(str(inpaint_path), result)
                    row["inpainted_path"] = str(inpaint_path)
                row["lama_success"] = result is not None
                row["lama_time_s"] = round(lama_time, 2)
            except Exception as e:
                logger.warning(f" [{tag}] LaMa 修复失败: {e}")
                row["lama_success"] = False

        results.append(row)

    # Ranking: drop implausible ratios (near zero or near the whole image),
    # then prefer configs whose ratio is closest to the median — very small
    # ratios likely miss the watermark, very large ones likely over-detect.
    reasonable = [r for r in results if 0.005 < r["wm_mask_ratio"] < 0.80]
    median_ratio: Optional[float] = None
    if reasonable:
        median_ratio = float(np.median([r["wm_mask_ratio"] for r in reasonable]))
        logger.info(f" ratio 中位数: {median_ratio:.4f}")
        # list.sort is stable, so ties keep their grid order.
        reasonable.sort(key=lambda r: abs(r["wm_mask_ratio"] - median_ratio))
        for i, r in enumerate(reasonable[:5]):
            logger.info(
                f" Top{i+1}: {r['tag']} ratio={r['wm_mask_ratio']:.4f}"
            )

    report = {
        "input": str(input_path),
        "output_dir": str(img_out),
        "n_configs_tested": len(results),
        "n_reasonable": len(reasonable),
        "median_ratio": round(median_ratio, 6) if median_ratio is not None else None,
        "top_results": reasonable[:5],
        "all_results": results,
    }
    report_path = img_out / "sweep_report.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    # The CSV summary is only written when at least one config succeeded.
    if results:
        _write_csv_summary(img_out / "sweep_summary.csv", results)

    logger.info(f" 报告: {report_path}")
    return report
- # ── CLI ──────────────────────────────────────────────────────────
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the sweep script.

    Arguments: positional ``input`` (file or directory), ``-o/--output``,
    the boolean switches ``--quick`` and ``--no-save-images``, and the
    optional paths ``--lama-ckpt`` and ``--lama-repo``.
    """
    parser = argparse.ArgumentParser(
        description="水印 mask 参数网格扫描(light_on_white / diagonal_midtone / fused)",
    )

    parser.add_argument("input", type=Path, help="单张图片路径或图片目录")
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=None,
        help="输出根目录,默认 input 同级 sweep_out/<stem>",
    )

    # Plain on/off switches.
    switches = (
        ("--quick", "缩小参数网格"),
        ("--no-save-images", "不写出 mask overlay 图片(仅 JSON 报告)"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    # Optional filesystem paths for the LaMa backend.
    lama_paths = (
        ("--lama-ckpt", "LaMa 权重文件路径(启用则每组跑 LaMa 修复)"),
        ("--lama-repo", "LaMa 仓库路径(用于导入 saicinpainting)"),
    )
    for flag, help_text in lama_paths:
        parser.add_argument(flag, type=Path, default=None, help=help_text)

    return parser
def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: sweep every input image and write an index file.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv``.

    Raises:
        SystemExit: No image was found at the given input path.
    """
    args = _build_arg_parser().parse_args(argv)

    images = _collect_images(args.input)
    if not images:
        raise SystemExit("未找到可扫描的图像")

    # Default output root: "sweep_out" next to a single file, or inside the
    # input directory when a directory was given.
    out_root = args.output
    if out_root is None:
        base_dir = args.input.parent if args.input.is_file() else args.input
        out_root = base_dir / "sweep_out"
    out_root.mkdir(parents=True, exist_ok=True)

    logger.info(f"扫描 {len(images)} 张图 -> {out_root}")
    logger.info(f" quick={args.quick}")

    write_images = not args.no_save_images
    summary: List[Dict[str, Any]] = []
    for img_path in images:
        logger.info(f"\n=== {img_path.name} ===")
        report = run_sweep(
            img_path,
            out_root,
            quick=args.quick,
            save_images=write_images,
            lama_ckpt=args.lama_ckpt,
            lama_repo=args.lama_repo,
        )
        report_file = Path(report["output_dir"]) / "sweep_report.json"
        summary.append({
            "input": report["input"],
            "n_tested": report["n_configs_tested"],
            "median_ratio": report["median_ratio"],
            "report": str(report_file),
        })

    # One index over all per-image reports.
    index_path = out_root / "sweep_index.json"
    index_text = json.dumps(summary, ensure_ascii=False, indent=2)
    index_path.write_text(index_text, encoding="utf-8")

    logger.info(f"\n全部完成。索引: {index_path}")
    for entry in summary:
        logger.info(
            f" {Path(entry['input']).name}: "
            f"{entry['n_tested']} 组, median_ratio={entry['median_ratio']} -> {entry['report']}"
        )
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|