| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466 |
- """
- 去水印评估脚本:对比 baseline (masked_adaptive) 与 LaMa GAN 方法。
- 用法:
- cd ocr_platform/ocr_tools/gan_experiments_lab
- # 对 test_images/input/ 下所有图片做对比
- python evaluate.py
- # 指定输入/输出目录
- python evaluate.py --input ./test_images/synthetic/ --output ./output/synthetic_compare
- # 有clean参考图时计算 PSNR/SSIM
- python evaluate.py --input ./test_images/synthetic/ --clean-dir ./test_images/clean/
- 生成物:
- output/compare/ — 对比图 (原图 | baseline | GAN,检测到水印掩膜时另加第四联掩膜叠加图)
- output/inpainted/ — baseline 与 GAN 修复结果
- output/mask_debug/ — 掩膜可视化
- output/metrics/ — 评估指标 JSON
- """
- from __future__ import annotations
- import argparse
- import json
- import sys
- import time
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- import cv2
- import numpy as np
- # 将 ocr_platform 根目录加入 sys.path,以便导入 ocr_utils
- _repo_root = Path(__file__).resolve().parents[3]
- if str(_repo_root) not in sys.path:
- sys.path.insert(0, str(_repo_root))
- from loguru import logger
- from PIL import Image
- from ocr_utils.watermark import (
- WatermarkProcessor,
- build_watermark_mask,
- detect_watermark,
- merge_watermark_config,
- render_watermark_mask_overlay,
- )
- from lama_inpaint import LamaInpainter
- # ── 评估指标 ────────────────────────────────────────────────────
- def _to_gray(img: np.ndarray) -> np.ndarray:
- if img.ndim == 3:
- return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float64)
- return img.astype(np.float64)
def compute_psnr(
    img1: np.ndarray,
    img2: np.ndarray,
    *,
    data_range: float = 255.0,
) -> float:
    """Compute the PSNR (in dB) between two images on their grayscale versions.

    Both inputs go through ``_to_gray`` (3-D inputs are converted from BGR),
    so a grayscale result can be compared against a BGR reference.

    Args:
        img1: First image (gray or BGR).
        img2: Second image (gray or BGR), same height/width as *img1*.
        data_range: Maximum possible pixel value (255 for 8-bit images).

    Returns:
        PSNR in dB. Identical images (MSE below 1e-10) are capped at 100.0.

    Raises:
        ValueError: If the images are empty or their grayscale shapes differ.
    """
    g1, g2 = _to_gray(img1), _to_gray(img2)
    if g1.shape != g2.shape:
        # Without this check numpy either raises an opaque broadcast error or,
        # for broadcast-compatible shapes, silently compares the wrong pixels.
        raise ValueError(
            f"PSNR requires images of equal size, got {g1.shape} vs {g2.shape}"
        )
    if g1.size == 0:
        raise ValueError("PSNR is undefined for empty images")
    mse = float(np.mean((g1 - g2) ** 2))
    if mse < 1e-10:
        return 100.0
    return float(20 * np.log10(data_range / np.sqrt(mse)))
def compute_ssim(
    img1: np.ndarray,
    img2: np.ndarray,
    *,
    data_range: float = 255.0,
    win_size: int = 11,
    sigma: float = 1.5,
) -> float:
    """Compute the mean SSIM between two images on their grayscale versions.

    Local statistics use a ``win_size`` x ``win_size`` Gaussian window with
    standard deviation *sigma* (default 11x11, 1.5) and reflected borders.
    The function returns the mean of the resulting SSIM map.

    Args:
        img1: First image (gray or BGR).
        img2: Second image (gray or BGR), same height/width as *img1*.
        data_range: Maximum possible pixel value (255 for 8-bit images).
        win_size: Side length of the Gaussian window.
        sigma: Standard deviation of the Gaussian window.

    Returns:
        Mean SSIM (approximately 1.0 for identical images).

    Raises:
        ValueError: If the images are empty or their grayscale shapes differ.
    """
    g1, g2 = _to_gray(img1), _to_gray(img2)
    if g1.shape != g2.shape:
        raise ValueError(
            f"SSIM requires images of equal size, got {g1.shape} vs {g2.shape}"
        )
    if g1.size == 0:
        raise ValueError("SSIM is undefined for empty images")

    # Stabilising constants of the SSIM formula (K1 = 0.01, K2 = 0.03).
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    kernel = cv2.getGaussianKernel(win_size, sigma)
    window = np.outer(kernel, kernel)
    window /= window.sum()

    def _local_mean(x: np.ndarray) -> np.ndarray:
        """Gaussian-weighted local mean with reflected borders."""
        return cv2.filter2D(x, -1, window, borderType=cv2.BORDER_REFLECT)

    mu1 = _local_mean(g1)
    mu2 = _local_mean(g2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    # Local (co)variances via E[XY] - E[X]E[Y].
    sigma1_sq = _local_mean(g1 * g1) - mu1_sq
    sigma2_sq = _local_mean(g2 * g2) - mu2_sq
    sigma12 = _local_mean(g1 * g2) - mu1_mu2

    num = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denom = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim_map = num / (denom + 1e-10)
    return float(ssim_map.mean())
- # ── 水印配置 ────────────────────────────────────────────────────
def _baseline_config() -> Dict[str, Any]:
    """Build the page-scope watermark config for the baseline run.

    Overrides passed to ``merge_watermark_config("page", ...)``:
    ``masked_adaptive`` removal with threshold 175, plus ``text_restore``
    contrast enhancement with a text black target of 85.
    """
    contrast = {"enabled": True, "method": "text_restore", "text_black_target": 85}
    overrides: Dict[str, Any] = {
        "method": "masked_adaptive",
        "threshold": 175,
        "contrast_enhancement": contrast,
    }
    return merge_watermark_config("page", overrides)
def _gan_wm_config() -> Dict[str, Any]:
    """Build the page-scope watermark config used by the GAN pipeline.

    Same method and threshold as the baseline, but with no explicit
    ``contrast_enhancement`` override. Whatever ``merge_watermark_config``
    yields for "page" is used for that section.
    """
    overrides = dict(method="masked_adaptive", threshold=175)
    return merge_watermark_config("page", overrides)
- # ── 单图处理 ────────────────────────────────────────────────────
def _load_image(path: Path) -> np.ndarray:
    """Load an image file as a BGR ndarray (OpenCV channel order).

    The image is converted to RGB by PIL first, which also flattens palette,
    grayscale and alpha modes into three channels. The channels are then
    swapped to BGR. The file is opened in a ``with`` block, so its handle is
    released even for multi-frame formats (e.g. TIFF), which PIL keeps open
    after loading.
    """
    with Image.open(str(path)) as pil:
        rgb = np.array(pil.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
def _run_baseline(bgr: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Run ``WatermarkProcessor`` removal on *bgr* with config *cfg*.

    Args:
        bgr: Input image (BGR).
        cfg: Watermark config, e.g. from ``_baseline_config()``.

    Returns:
        ``(result, debug)``: the processed image as an ndarray, and the dict
        passed as ``removal_debug`` (presumably populated by
        ``WatermarkProcessor.process``). The stage output of ``process`` is
        discarded.
    """
    removal_debug: Dict[str, Any] = {}
    processor = WatermarkProcessor(cfg, scope="page")
    output, _stages = processor.process(bgr, apply_removal=True, removal_debug=removal_debug)
    return np.asarray(output), removal_debug
def _run_gan(
    bgr: np.ndarray,
    wm_cfg: Dict[str, Any],
    inpainter: LamaInpainter,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Remove the watermark with LaMa inpainting, falling back to the baseline.

    Steps:
      1. Build a watermark mask with ``build_watermark_mask``, passing the
         ``wm_cfg["mask"]`` entries as keyword arguments.
      2. Empty mask: return ``_run_baseline(bgr, wm_cfg)`` output
         (``mode="gan_no_mask"``).
      3. Otherwise inpaint with *inpainter*. On success, convert the result to
         grayscale and apply ``wm_cfg["contrast_enhancement"]``
         (``mode="gan"``).
      4. If inpainting returns ``None`` or raises, fall back to
         ``_run_baseline(bgr, wm_cfg)`` (``mode="gan_fallback"``).

    Args:
        bgr: Input image (BGR).
        wm_cfg: Watermark config, e.g. from ``_gan_wm_config()``.
        inpainter: LaMa wrapper. ``inpaint(bgr, mask)`` is expected to return
            the repaired BGR image, or ``None`` on failure.

    Returns:
        ``(result, debug)``. ``debug`` always contains ``mode`` and ``wm_mask``
        (the raw mask array). Non-ndarray entries of the mask builder's debug
        dict are copied in as well.
    """
    debug: Dict[str, Any] = {"mode": "gan"}
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    mask_cfg = wm_cfg.get("mask", {})
    wm_mask, mask_debug = build_watermark_mask(gray, bgr=bgr, **mask_cfg)
    # Array-valued debug entries are dropped. The mask itself is stored
    # explicitly because the caller uses it for visualisation.
    debug.update({k: v for k, v in mask_debug.items() if not isinstance(v, np.ndarray)})
    debug["wm_mask"] = wm_mask

    # count_nonzero is correct for bool masks and 0/255 uint8 masks alike.
    # sum() would report 255x the pixel count for the latter.
    mask_pixels = int(np.count_nonzero(wm_mask))
    if mask_pixels == 0:
        logger.info(" 未检测到水印区域,跳过GAN")
        debug["mode"] = "gan_no_mask"
        clean_gray, _ = _run_baseline(bgr, wm_cfg)
        return clean_gray, debug

    logger.info(f" 水印区域: {mask_pixels} 像素 "
                f"({100 * mask_pixels / wm_mask.size:.2f}%)")

    fallback_reason = "gan_inference_failed"
    t0 = time.perf_counter()
    try:
        result = inpainter.inpaint(bgr, wm_mask)
    except Exception as exc:  # any inference error must take the fallback path, not abort the batch
        logger.exception(f" GAN推理异常: {exc!r}")
        result = None
        fallback_reason = "gan_inference_error"
        debug["gan_error"] = repr(exc)
    elapsed = time.perf_counter() - t0

    if result is not None:
        debug["mode"] = "gan"
        debug["gan_success"] = True
        debug["gan_inference_time_s"] = round(elapsed, 2)
        logger.info(f" GAN修复成功 ({elapsed:.1f}s)")
        # Imported lazily: only needed on the success path.
        from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
        ce_cfg = wm_cfg.get("contrast_enhancement")
        # Assumes the inpainter returns a 3-channel BGR image.
        result_gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
        result_gray = apply_contrast_enhancement_config(result_gray, ce_cfg)
        return result_gray, debug

    # GAN failed (returned None or raised): fall back to the baseline pipeline.
    logger.warning(" GAN修复失败,回退 baseline")
    debug["mode"] = "gan_fallback"
    debug["fallback_reason"] = fallback_reason
    clean_gray, fallback_debug = _run_baseline(bgr, wm_cfg)
    debug["fallback_debug"] = fallback_debug
    return clean_gray, debug
- # ── 输出 ──────────────────────────────────────────────────────────
def _make_compare_image(
    bgr: np.ndarray,
    baseline_gray: np.ndarray,
    gan_result: np.ndarray,
    wm_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Build a horizontal side-by-side comparison strip.

    Panels, left to right: original, baseline result, GAN result, and
    (only when *wm_mask* has any nonzero pixel) the mask overlay rendered on
    the original. Each panel gets a 36-px light-grey label bar underneath.

    Args:
        bgr: Original image (BGR). Its height/width is the reference size.
        baseline_gray: Baseline output: gray, (H, W, 1), BGR or BGRA.
        gan_result: GAN output: gray, (H, W, 1), BGR or BGRA.
        wm_mask: Optional watermark mask for the overlay panel.

    Returns:
        A single 3-channel image with all labelled panels concatenated
        horizontally.
    """
    ref_h, ref_w = bgr.shape[:2]

    def _as_panel(arr: np.ndarray) -> np.ndarray:
        """Coerce an image to 3-channel BGR at the reference size."""
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[:, :, 0]
        if arr.ndim == 2:
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        elif arr.shape[2] == 4:
            arr = cv2.cvtColor(arr, cv2.COLOR_BGRA2BGR)
        # hstack below needs identical heights; equal widths keep panels aligned.
        if arr.shape[:2] != (ref_h, ref_w):
            arr = cv2.resize(arr, (ref_w, ref_h))
        return arr

    panels = [_as_panel(bgr), _as_panel(baseline_gray), _as_panel(gan_result)]
    labels = ["Original", "Baseline (masked_adaptive)", "GAN (LaMa)"]
    if wm_mask is not None and np.any(wm_mask):
        panels.append(_as_panel(render_watermark_mask_overlay(bgr, wm_mask)))
        labels.append("Watermark Mask")

    labeled = []
    for panel, label in zip(panels, labels):
        bar = np.full((36, panel.shape[1], 3), 240, dtype=np.uint8)
        cv2.putText(bar, label, (12, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1)
        labeled.append(np.vstack([panel, bar]))
    # All panels share the reference height, so no padding is required.
    return np.hstack(labeled)
def _save_result(
    stem: str,
    result: np.ndarray,
    output_dir: Path,
    prefix: str = "",
) -> Path:
    """Save *result* as ``<output_dir>/<stem>_<prefix>.png`` and return the path.

    Despite its name, *prefix* is appended after the stem as a tag
    (``"baseline"`` -> ``<stem>_baseline.png``). Grayscale (2-D) and BGR
    (3-D) arrays are both written as-is. *output_dir* is created if missing.

    Raises:
        OSError: If the image cannot be PNG-encoded or the file cannot be
            written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / f"{stem}_{prefix}.png"
    # Encode in memory and write through pathlib instead of cv2.imwrite.
    # imwrite reports most failures only via its return value (which was
    # ignored) and is known to fail on non-ASCII paths on Windows.
    ok, encoded = cv2.imencode(".png", result)
    if not ok:
        raise OSError(f"failed to PNG-encode image for {path}")
    path.write_bytes(encoded.tobytes())
    return path
def _save_metrics_json(
    metrics_list: List[Dict[str, Any]],
    output_dir: Path,
) -> None:
    """Write *metrics_list* to ``<output_dir>/metrics.json`` (UTF-8, indent=2).

    The directory is created if needed. NumPy scalar values (``np.bool_``,
    ``np.float32``, ``np.int64``, ...) are converted to native Python values,
    because the stdlib encoder raises ``TypeError`` on them. Any other
    non-JSON type still raises.
    """

    def _numpy_default(obj: Any) -> Any:
        """json ``default`` hook: unwrap NumPy scalars, reject everything else."""
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / "metrics.json"
    text = json.dumps(metrics_list, ensure_ascii=False, indent=2, default=_numpy_default)
    path.write_text(text, encoding="utf-8")
    logger.info(f"评估指标: {path}")
- # ── 主函数 ────────────────────────────────────────────────────────
# Image extensions accepted as inputs. Also tried, in this order, when
# looking up clean reference images.
_IMAGE_SUFFIXES: Tuple[str, ...] = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")


def _find_clean_reference(clean_dir: Path, stem: str) -> Optional[np.ndarray]:
    """Return the clean reference image for *stem* from *clean_dir*, or ``None``.

    Tries ``<stem><ext>`` first, then the stem with every ``_watermarked``
    substring removed, using extensions in ``_IMAGE_SUFFIXES`` order.
    """
    names = [stem]
    stripped = stem.replace("_watermarked", "")
    if stripped != stem:
        names.append(stripped)
    for name in names:
        for ext in _IMAGE_SUFFIXES:
            candidate = clean_dir / f"{name}{ext}"
            if candidate.is_file():
                return _load_image(candidate)
    return None


def _mean_metric(metrics_list: List[Dict[str, Any]], key: str) -> Optional[float]:
    """Mean of *key* over the entries that contain it; ``None`` if none do."""
    values = [m[key] for m in metrics_list if key in m]
    return float(np.mean(values)) if values else None


def evaluate(
    input_dir: Path,
    output_root: Path,
    *,
    clean_dir: Optional[Path] = None,
    device: str = "cpu",
    gan_only: bool = False,
    lama_ckpt: Optional[Path] = None,
    lama_config: Optional[Path] = None,
    lama_repo: Optional[Path] = None,
) -> None:
    """Batch-compare baseline (masked_adaptive) and LaMa GAN watermark removal.

    For every image file directly inside *input_dir*, writes under
    *output_root*:

    - ``inpainted/``: baseline and GAN results
    - ``mask_debug/``: mask overlay, when a mask was found
    - ``compare/``: side-by-side comparison image
    - ``metrics/metrics.json``: rewritten after every image, so an
      interrupted run keeps its results

    Unreadable images are skipped with a warning instead of aborting the run.

    Args:
        input_dir: Directory containing the (watermarked) input images.
        output_root: Root directory for all outputs.
        clean_dir: Optional directory of clean references. When one is found
            for an image, PSNR/SSIM against it are recorded.
        device: Inference device forwarded to ``LamaInpainter``.
        gan_only: Skip the baseline run. Baseline outputs, baseline metrics
            and the comparison image (which needs a baseline panel) are then
            omitted.
        lama_ckpt: Optional LaMa checkpoint, forwarded to ``LamaInpainter``.
        lama_config: Optional LaMa config.yaml, forwarded to ``LamaInpainter``.
        lama_repo: Optional LaMa repo path, forwarded to ``LamaInpainter``.
    """
    if not input_dir.is_dir():
        logger.error(f"{input_dir} 不存在或不是目录")
        return
    img_files = sorted(
        f for f in input_dir.iterdir()
        if f.is_file() and f.suffix.lower() in _IMAGE_SUFFIXES
    )
    if not img_files:
        logger.error(f"{input_dir} 下没有图片文件")
        return

    # Output directories
    out_compare = output_root / "compare"
    out_inpainted = output_root / "inpainted"
    out_mask = output_root / "mask_debug"
    out_metrics = output_root / "metrics"
    for d in (out_compare, out_inpainted, out_mask, out_metrics):
        d.mkdir(parents=True, exist_ok=True)

    baseline_cfg = None if gan_only else _baseline_config()
    wm_cfg = _gan_wm_config()
    inpainter = LamaInpainter(
        device=device,
        model_ckpt_path=str(lama_ckpt) if lama_ckpt else None,
        model_config_path=str(lama_config) if lama_config else None,
        lama_repo_path=str(lama_repo) if lama_repo else None,
    )
    available = inpainter.is_available
    logger.info(f"LaMa 可用: {available}, backend: {inpainter._backend or '未加载'}")
    if not available and not gan_only:
        logger.warning("LaMa backend 不可用,GAN将回退到OpenCV inpaint")

    all_metrics: List[Dict[str, Any]] = []
    skipped: List[str] = []
    for f in img_files:
        logger.info(f"\n处理: {f.name}")
        stem = f.stem
        try:
            bgr = _load_image(f)
        except OSError as exc:
            # PIL signals unreadable files with OSError subclasses
            # (e.g. UnidentifiedImageError); skip instead of aborting the batch.
            logger.warning(f" 无法读取图片,跳过: {exc}")
            skipped.append(f.name)
            continue

        clean_img = _find_clean_reference(clean_dir, stem) if clean_dir else None

        # Watermark detection. bool() keeps the value JSON-serialisable even if
        # detect_watermark returns a NumPy bool.
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        has_wm = bool(detect_watermark(gray, ratio_threshold=0.025))
        logger.info(f" 水印检测: {'有水印' if has_wm else '无水印'}")

        # ── Baseline (skipped in gan_only mode) ──
        baseline_result: Optional[np.ndarray] = None
        baseline_time: Optional[float] = None
        if not gan_only:
            logger.info(" 运行 baseline (masked_adaptive)...")
            t0 = time.perf_counter()
            baseline_result, _ = _run_baseline(bgr, baseline_cfg)
            baseline_time = time.perf_counter() - t0
            logger.info(f" baseline 耗时: {baseline_time:.1f}s")

        # ── GAN ──
        t0 = time.perf_counter()
        gan_result, gan_debug = _run_gan(bgr, wm_cfg, inpainter)
        gan_time = time.perf_counter() - t0

        # ── Save results ──
        if baseline_result is not None:
            _save_result(stem, baseline_result, out_inpainted, "baseline")
        if gan_result.ndim == 2:
            gan_bgr = cv2.cvtColor(gan_result, cv2.COLOR_GRAY2BGR)
        else:
            gan_bgr = gan_result
        _save_result(stem, gan_bgr, out_inpainted, "gan")

        # ── Mask visualisation ──
        wm_mask = gan_debug.get("wm_mask")
        if wm_mask is not None and np.any(wm_mask):
            mask_overlay = render_watermark_mask_overlay(bgr, wm_mask)
            _save_result(stem, mask_overlay, out_mask, "mask_overlay")

        # ── Comparison image (needs a baseline panel) ──
        if baseline_result is not None:
            compare_img = _make_compare_image(bgr, baseline_result, gan_bgr, wm_mask)
            _save_result(stem, compare_img, out_compare, "compare")

        # ── Metrics (key order matches the non-gan_only layout) ──
        metrics: Dict[str, Any] = {"file": f.name, "has_watermark": has_wm}
        if baseline_time is not None:
            metrics["baseline_time_s"] = round(baseline_time, 2)
        metrics["gan_time_s"] = round(gan_time, 2)
        metrics["gan_mode"] = gan_debug.get("mode", "unknown")

        if clean_img is not None:
            try:
                if baseline_result is not None:
                    metrics["baseline_psnr"] = round(compute_psnr(baseline_result, clean_img), 2)
                    metrics["baseline_ssim"] = round(compute_ssim(baseline_result, clean_img), 4)
                metrics["gan_psnr"] = round(compute_psnr(gan_bgr, clean_img), 2)
                metrics["gan_ssim"] = round(compute_ssim(gan_bgr, clean_img), 4)
            except ValueError as exc:
                # Raised e.g. when the result and the reference differ in size;
                # record what we have rather than aborting the batch.
                logger.warning(f" 无法计算 PSNR/SSIM: {exc}")
            else:
                logger.info(
                    f" PSNR: baseline={metrics.get('baseline_psnr', 'n/a')}dB, "
                    f"GAN={metrics['gan_psnr']}dB"
                )

        all_metrics.append(metrics)
        _save_metrics_json(all_metrics, out_metrics)

    # Summary
    logger.info(f"\n{'='*50}")
    logger.info(f"评估完成,共 {len(img_files)} 张图")
    if skipped:
        logger.warning(f" 跳过 {len(skipped)} 张无法读取的图: {', '.join(skipped)}")
    logger.info(f" 对比图: {out_compare}")
    logger.info(f" 修复结果: {out_inpainted}")
    logger.info(f" 掩膜: {out_mask}")
    logger.info(f" 指标: {out_metrics}")
    # Average only over images that actually have the metric. Missing values
    # are not counted as 0, and the last image's reference status is irrelevant.
    avg_baseline_psnr = _mean_metric(all_metrics, "baseline_psnr")
    avg_gan_psnr = _mean_metric(all_metrics, "gan_psnr")
    if avg_baseline_psnr is not None or avg_gan_psnr is not None:
        baseline_txt = f"{avg_baseline_psnr:.1f}" if avg_baseline_psnr is not None else "n/a"
        gan_txt = f"{avg_gan_psnr:.1f}" if avg_gan_psnr is not None else "n/a"
        logger.info(f" 平均 PSNR: baseline={baseline_txt}dB, GAN={gan_txt}dB")
def main() -> None:
    """Command-line entry point: parse arguments and run ``evaluate``.

    The default ``--lama-ckpt`` can be overridden with the ``LAMA_CKPT``
    environment variable, so other machines do not need to edit the
    hard-coded developer path. ``--input`` and ``--clean-dir`` are checked
    up front and reported as argument errors if they are not directories.
    """
    import os

    root = Path(__file__).parent
    # An empty LAMA_CKPT is treated as unset.
    default_ckpt = Path(
        os.environ.get("LAMA_CKPT") or "/Users/zhch158/models/big-lama/models/best.ckpt"
    ).expanduser()

    parser = argparse.ArgumentParser(
        description="去水印评估:baseline vs GAN",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("--input", type=Path, default=root / "test_images" / "input",
                        help="输入图片目录")
    parser.add_argument("--output", type=Path, default=root / "output",
                        help="输出根目录")
    parser.add_argument("--clean-dir", type=Path, default=None,
                        help="clean参考图目录(用于计算PSNR/SSIM)")
    parser.add_argument("--device", type=str, default="cpu",
                        choices=["cpu", "cuda", "mps"],
                        help="推理设备")
    parser.add_argument("--gan-only", action="store_true",
                        help="仅运行GAN(跳过baseline)")
    parser.add_argument(
        "--lama-ckpt",
        type=Path,
        default=default_ckpt,
        help="LaMa 权重文件路径(默认取环境变量 LAMA_CKPT,否则使用本地已下载权重,不走下载)",
    )
    parser.add_argument(
        "--lama-config",
        type=Path,
        default=None,
        help="LaMa config.yaml 路径(默认自动从 ckpt 同级目录推断)",
    )
    parser.add_argument(
        "--lama-repo",
        type=Path,
        default=None,
        help="LaMa 仓库路径(用于导入 saicinpainting,默认自动探测)",
    )
    args = parser.parse_args()

    if not args.input.is_dir():
        parser.error(f"输入目录不存在: {args.input}")
    if args.clean_dir is not None and not args.clean_dir.is_dir():
        parser.error(f"clean参考图目录不存在: {args.clean_dir}")
    if args.lama_ckpt is not None and not args.lama_ckpt.is_file():
        logger.warning(f"LaMa 权重文件不存在: {args.lama_ckpt}")

    evaluate(
        args.input,
        args.output,
        clean_dir=args.clean_dir,
        device=args.device,
        gan_only=args.gan_only,
        lama_ckpt=args.lama_ckpt,
        lama_config=args.lama_config,
        lama_repo=args.lama_repo,
    )


if __name__ == "__main__":
    main()
|