|
|
@@ -0,0 +1,439 @@
|
|
|
+"""
|
|
|
+去水印评估脚本:对比 baseline (masked_adaptive) 与 LaMa GAN 方法。
|
|
|
+
|
|
|
+用法:
|
|
|
+ cd ocr_platform/ocr_tools/gan_experiments_lab
|
|
|
+
|
|
|
+ # 对 test_images/input/ 下所有图片做对比
|
|
|
+ python evaluate.py
|
|
|
+
|
|
|
+ # 指定输入/输出目录
|
|
|
+ python evaluate.py --input ./test_images/synthetic/ --output ./output/synthetic_compare
|
|
|
+
|
|
|
+ # 有clean参考图时计算 PSNR/SSIM
|
|
|
+ python evaluate.py --input ./test_images/synthetic/ --clean-dir ./test_images/clean/
|
|
|
+
|
|
|
+生成物:
|
|
|
+ output/compare/ — 三联对比图 (原图 | baseline | GAN)
|
|
|
+ output/inpainted/ — GAN 修复结果
|
|
|
+ output/mask_debug/ — 掩膜可视化
|
|
|
+ output/metrics/ — 评估指标 JSON
|
|
|
+"""
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any, Dict, List, Optional, Tuple
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+
|
|
|
# Put the ocr_platform root directory on sys.path so that ocr_utils can be
# imported. The path is resolved first: a relative __file__ (e.g. "evaluate.py"
# when the script is launched from its own directory on older interpreters)
# has fewer than three parents, and parents[2] would raise IndexError.
_repo_root = Path(__file__).resolve().parents[2]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
|
|
|
+
|
|
|
+from loguru import logger
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
+from ocr_utils.watermark import (
|
|
|
+ WatermarkProcessor,
|
|
|
+ build_watermark_mask,
|
|
|
+ detect_watermark,
|
|
|
+ merge_watermark_config,
|
|
|
+ render_watermark_mask_overlay,
|
|
|
+)
|
|
|
+from lama_inpaint import LamaInpainter
|
|
|
+
|
|
|
+# ── 评估指标 ────────────────────────────────────────────────────
|
|
|
+
|
|
|
+
|
|
|
def _to_gray(img: np.ndarray) -> np.ndarray:
    """Return ``img`` as a single-channel float64 array.

    Args:
        img: A 2-D array (returned as-is apart from the dtype cast) or a 3-D
            array. A 3-D array with one channel has its channel axis dropped;
            any other 3-D array is converted with ``cv2.COLOR_BGR2GRAY``, i.e.
            it is interpreted as BGR.

    Returns:
        A float64 array with the same height and width as ``img``.
    """
    if img.ndim == 3:
        if img.shape[2] == 1:
            # cv2.cvtColor(BGR2GRAY) rejects single-channel input, so an
            # (H, W, 1) array is flattened directly instead.
            return img[:, :, 0].astype(np.float64)
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float64)
    return img.astype(np.float64)
|
|
|
+
|
|
|
+
|
|
|
def compute_psnr(img1: np.ndarray, img2: np.ndarray) -> float:
    """Return the peak signal-to-noise ratio, in dB, between two images.

    Both inputs are converted to grayscale float64 via ``_to_gray`` and
    compared with a peak value of 255.

    Args:
        img1: First image.
        img2: Second image; must have the same height and width as ``img1``.

    Returns:
        The PSNR in dB, or 100.0 when the mean squared error is below 1e-10
        (the images are treated as identical).

    Raises:
        ValueError: If the grayscale versions differ in shape.
    """
    g1, g2 = _to_gray(img1), _to_gray(img2)
    # Check explicitly: NumPy would silently broadcast compatible-but-different
    # shapes (e.g. (H, W) against (1, W)) and return a meaningless score.
    if g1.shape != g2.shape:
        raise ValueError(
            f"compute_psnr: image shapes differ: {g1.shape} vs {g2.shape}"
        )
    mse = np.mean((g1 - g2) ** 2)
    if mse < 1e-10:
        return 100.0
    return float(20 * np.log10(255.0 / np.sqrt(mse)))
|
|
|
+
|
|
|
+
|
|
|
def compute_ssim(img1: np.ndarray, img2: np.ndarray) -> float:
    """Return the mean structural similarity (SSIM) between two images.

    Both inputs are converted to grayscale float64 via ``_to_gray``. Local
    means, variances and covariance are estimated with an 11x11 Gaussian
    window (sigma 1.5) using reflected borders; the stabilising constants use
    K1 = 0.01, K2 = 0.03 and a dynamic range of 255.

    Args:
        img1: First image.
        img2: Second image; must have the same height and width as ``img1``.

    Returns:
        The SSIM map averaged over all pixels.

    Raises:
        ValueError: If the grayscale versions differ in shape.
    """
    g1, g2 = _to_gray(img1), _to_gray(img2)
    if g1.shape != g2.shape:
        raise ValueError(
            f"compute_ssim: image shapes differ: {g1.shape} vs {g2.shape}"
        )

    k1, k2 = 0.01, 0.03
    dynamic_range = 255.0
    c1, c2 = (k1 * dynamic_range) ** 2, (k2 * dynamic_range) ** 2

    # Normalised 2-D Gaussian window built from the separable 1-D kernel.
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel)
    window /= window.sum()

    def _local_mean(arr: np.ndarray) -> np.ndarray:
        return cv2.filter2D(arr, -1, window, borderType=cv2.BORDER_REFLECT)

    mu1 = _local_mean(g1)
    mu2 = _local_mean(g2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    # Var(X) = E[X^2] - E[X]^2 and Cov(X, Y) = E[XY] - E[X]E[Y], windowed.
    sigma1_sq = _local_mean(g1 * g1) - mu1_sq
    sigma2_sq = _local_mean(g2 * g2) - mu2_sq
    sigma12 = _local_mean(g1 * g2) - mu1_mu2

    num = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denom = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    # The small epsilon guards against a zero denominator.
    ssim_map = num / (denom + 1e-10)
    return float(ssim_map.mean())
|
|
|
+
|
|
|
+
|
|
|
+# ── 水印配置 ────────────────────────────────────────────────────
|
|
|
+
|
|
|
+
|
|
|
def _baseline_config() -> Dict[str, Any]:
    """Return the page-scope watermark config used for the baseline run.

    The overrides select the ``masked_adaptive`` method with threshold 175 and
    enable ``text_restore`` contrast enhancement (``text_black_target`` 85);
    they are combined with the "page" settings by ``merge_watermark_config``.
    """
    contrast = {"enabled": True, "method": "text_restore", "text_black_target": 85}
    overrides = {
        "method": "masked_adaptive",
        "threshold": 175,
        "contrast_enhancement": contrast,
    }
    return merge_watermark_config("page", overrides)
|
|
|
+
|
|
|
+
|
|
|
def _gan_wm_config() -> Dict[str, Any]:
    """Return the page-scope watermark config used by the GAN pipeline.

    Only ``method`` (``masked_adaptive``) and ``threshold`` (175) are
    overridden; everything else comes from ``merge_watermark_config``.
    """
    overrides = {"method": "masked_adaptive", "threshold": 175}
    return merge_watermark_config("page", overrides)
|
|
|
+
|
|
|
+
|
|
|
+# ── 单图处理 ────────────────────────────────────────────────────
|
|
|
+
|
|
|
+
|
|
|
def _load_image(path: Path) -> np.ndarray:
    """Load the image at ``path`` and return it as a 3-channel BGR ndarray.

    The file is opened with Pillow, converted to RGB and then reordered to
    BGR with OpenCV.

    Args:
        path: Location of the image file.

    Returns:
        The decoded image in BGR channel order.
    """
    # The context manager closes the underlying file as soon as the pixels
    # have been copied out, instead of leaving the handle to the garbage
    # collector while a whole directory is being processed.
    with Image.open(str(path)) as pil:
        np_img = np.array(pil.convert("RGB"))
    return cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
|
|
|
+
|
|
|
+
|
|
|
def _run_baseline(bgr: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Run a page-scope ``WatermarkProcessor`` configured by ``cfg`` on ``bgr``.

    Returns:
        A pair ``(result, debug)``: the processor output as an ndarray and the
        dict passed to the processor as ``removal_debug``.
    """
    removal_info: Dict[str, Any] = {}
    processor = WatermarkProcessor(cfg, scope="page")
    # process() returns a pair; the second element is not needed here.
    cleaned, _stages = processor.process(
        bgr,
        apply_removal=True,
        removal_debug=removal_info,
    )
    return np.asarray(cleaned), removal_info
|
|
|
+
|
|
|
+
|
|
|
def _run_gan(
    bgr: np.ndarray,
    wm_cfg: Dict[str, Any],
    inpainter: LamaInpainter,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Remove the watermark from ``bgr`` by inpainting the detected mask.

    Steps:
      1. Build the watermark mask with ``build_watermark_mask``.
      2. Inpaint the masked region with ``inpainter``.
      3. If no mask is found or inpainting returns ``None``, fall back to
         ``_run_baseline`` with the same config.

    Args:
        bgr: Input image in BGR order.
        wm_cfg: Watermark config; its ``"mask"`` entry is forwarded as keyword
            arguments to ``build_watermark_mask`` and its
            ``"contrast_enhancement"`` entry to the contrast step.
        inpainter: Object whose ``inpaint(bgr, mask)`` returns the repaired
            image or ``None`` on failure.

    Returns:
        ``(result, debug)``. On GAN success ``result`` is the inpainted image
        converted to grayscale and passed through
        ``apply_contrast_enhancement_config``; otherwise it is whatever
        ``_run_baseline`` returns. ``debug["mode"]`` is one of ``"gan"``,
        ``"gan_no_mask"`` or ``"gan_fallback"``, and ``debug["wm_mask"]``
        holds the mask.
    """
    debug: Dict[str, Any] = {"mode": "gan"}

    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    # Tolerate an explicit None under "mask" as well as a missing key.
    mask_cfg = wm_cfg.get("mask") or {}
    wm_mask, mask_debug = build_watermark_mask(gray, bgr=bgr, **mask_cfg)

    # Keep only the scalar/structured debug entries; arrays are dropped here.
    debug.update({k: v for k, v in mask_debug.items()
                  if not isinstance(v, np.ndarray)})
    debug["wm_mask"] = wm_mask

    # count_nonzero is right for both boolean and 0/255 uint8 masks, whereas
    # summing a 0/255 mask would over-report the pixel count by a factor of 255.
    mask_pixels = int(np.count_nonzero(wm_mask))
    if mask_pixels == 0:
        logger.info(" 未检测到水印区域,跳过GAN")
        debug["mode"] = "gan_no_mask"
        clean_gray, _ = _run_baseline(bgr, wm_cfg)
        return clean_gray, debug

    logger.info(f" 水印区域: {mask_pixels} 像素 "
                f"({100 * mask_pixels / wm_mask.size:.2f}%)")

    t0 = time.perf_counter()
    result = inpainter.inpaint(bgr, wm_mask)
    elapsed = time.perf_counter() - t0

    if result is not None:
        debug["gan_success"] = True
        debug["gan_inference_time_s"] = round(elapsed, 2)
        logger.info(f" GAN修复成功 ({elapsed:.1f}s)")
        # Apply contrast enhancement to the inpainted result.
        from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
        ce_cfg = wm_cfg.get("contrast_enhancement")
        result_gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
        result_gray = apply_contrast_enhancement_config(result_gray, ce_cfg)
        return result_gray, debug

    # Inpainting failed: fall back to the baseline method.
    logger.warning(" GAN修复失败,回退 baseline")
    debug["mode"] = "gan_fallback"
    debug["fallback_reason"] = "gan_inference_failed"
    clean_gray, fallback_debug = _run_baseline(bgr, wm_cfg)
    debug["fallback_debug"] = fallback_debug
    return clean_gray, debug
|
|
|
+
|
|
|
+
|
|
|
+# ── 输出 ──────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+
|
|
|
def _make_compare_image(
    bgr: np.ndarray,
    baseline_gray: np.ndarray,
    gan_result: np.ndarray,
    wm_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Build a horizontal comparison strip with a caption under each panel.

    Panels, left to right: the original image, the baseline result, the GAN
    result and, only when ``wm_mask`` contains a non-zero pixel, the mask
    overlay rendered on the original.

    Args:
        bgr: Original image; its height and width are the reference size.
        baseline_gray: Baseline result, 2-D grayscale or 3-D colour.
        gan_result: GAN result, 2-D grayscale or 3-D colour.
        wm_mask: Optional watermark mask used for the fourth panel.

    Returns:
        A single image with all panels stacked side by side.
    """

    def _to_bgr(arr: np.ndarray) -> np.ndarray:
        if arr.ndim == 2:
            return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        return arr

    def _resize(arr: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
        if arr.shape[0] != target_h or arr.shape[1] != target_w:
            return cv2.resize(arr, (target_w, target_h))
        return arr

    # Bring both results to 3 channels at the size of the original. Nothing
    # here unpacks baseline_gray.shape into two names, so a colour baseline
    # (a 3-tuple shape) no longer raises.
    ref_h, ref_w = bgr.shape[:2]
    baseline_bgr = _resize(_to_bgr(baseline_gray), ref_h, ref_w)
    gan_bgr = _resize(_to_bgr(gan_result), ref_h, ref_w)

    panels = [bgr, baseline_bgr, gan_bgr]
    labels = ["Original", "Baseline (masked_adaptive)", "GAN (LaMa)"]

    # With a non-empty mask, add its overlay on the original as a fourth panel.
    if wm_mask is not None and np.any(wm_mask):
        panels.append(render_watermark_mask_overlay(bgr, wm_mask))
        labels.append("Watermark Mask")

    labeled = []
    for panel, label in zip(panels, labels):
        # Light-grey caption bar appended below the panel.
        bar = np.full((36, panel.shape[1], 3), 240, dtype=np.uint8)
        cv2.putText(bar, label, (12, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1)
        labeled.append(np.vstack([panel, bar]))

    # Pad shorter columns with white so every column has the same height.
    max_h = max(p.shape[0] for p in labeled)
    columns = []
    for column in labeled:
        missing = max_h - column.shape[0]
        if missing > 0:
            pad = np.full((missing, column.shape[1], 3), 255, dtype=np.uint8)
            column = np.vstack([column, pad])
        columns.append(column)

    return np.hstack(columns)
|
|
|
+
|
|
|
+
|
|
|
def _save_result(
    stem: str,
    result: np.ndarray,
    output_dir: Path,
    prefix: str = "",
) -> Path:
    """Write ``result`` as a PNG into ``output_dir`` and return its path.

    Args:
        stem: Base file name without extension.
        result: Image to write (grayscale or colour).
        output_dir: Destination directory; created if it does not exist.
        prefix: Tag appended after ``stem`` (despite its name), giving
            ``<stem>_<prefix>.png``. When empty the file is ``<stem>.png``.

    Returns:
        The path the image was written to. A failed write is logged as a
        warning rather than raised.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"{stem}_{prefix}.png" if prefix else f"{stem}.png"
    p = output_dir / file_name
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(str(p), result):
        logger.warning(f"图片保存失败: {p}")
    return p
|
|
|
+
|
|
|
+
|
|
|
def _save_metrics_json(
    metrics_list: List[Dict[str, Any]],
    output_dir: Path,
) -> None:
    """Serialise ``metrics_list`` to ``<output_dir>/metrics.json``.

    The directory is created when missing. The JSON is UTF-8 encoded,
    indented by two spaces and keeps non-ASCII characters unescaped.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / "metrics.json"
    payload = json.dumps(metrics_list, ensure_ascii=False, indent=2)
    target.write_text(payload, encoding="utf-8")
    logger.info(f"评估指标: {target}")
|
|
|
+
|
|
|
+
|
|
|
+# ── 主函数 ────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+
|
|
|
# File suffixes (lower-case) accepted as input images.
_IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"})
# File suffixes tried, in order, when looking for a clean reference image.
_CLEAN_SUFFIXES = (".png", ".jpg", ".jpeg")


def _find_clean_reference(clean_dir: Path, stem: str) -> Optional[Path]:
    """Return the clean reference for ``stem`` in ``clean_dir``, or ``None``.

    The exact stem is tried first, then the stem with ``_watermarked`` removed.
    """
    for name in (stem, stem.replace("_watermarked", "")):
        for ext in _CLEAN_SUFFIXES:
            candidate = clean_dir / f"{name}{ext}"
            if candidate.exists():
                return candidate
    return None


def _add_quality_metrics(
    metrics: Dict[str, Any],
    tag: str,
    result: np.ndarray,
    clean_img: np.ndarray,
) -> None:
    """Store PSNR/SSIM of ``result`` against ``clean_img`` under ``<tag>_psnr`` / ``<tag>_ssim``."""
    if result.shape[:2] != clean_img.shape[:2]:
        # A size mismatch would make the metric functions fail and abort the
        # whole batch; skip this pair and keep going.
        logger.warning(
            f" {tag} 结果与参考图尺寸不一致,跳过 PSNR/SSIM: "
            f"{result.shape[:2]} vs {clean_img.shape[:2]}"
        )
        return
    metrics[f"{tag}_psnr"] = round(compute_psnr(result, clean_img), 2)
    metrics[f"{tag}_ssim"] = round(compute_ssim(result, clean_img), 4)


def evaluate(
    input_dir: Path,
    output_root: Path,
    *,
    clean_dir: Optional[Path] = None,
    device: str = "cpu",
    gan_only: bool = False,
) -> None:
    """Run the baseline and GAN watermark removal on every image in a directory.

    For each image the results are written below ``output_root``:
    ``inpainted/`` (method outputs), ``mask_debug/`` (mask overlay),
    ``compare/`` (side-by-side strip) and ``metrics/metrics.json``.

    Args:
        input_dir: Directory whose image files are processed in sorted order.
        output_root: Root directory for all generated files.
        clean_dir: Optional directory with clean reference images. When a
            reference is found for an image, PSNR and SSIM are computed.
        device: Device string passed to ``LamaInpainter``.
        gan_only: When true, the baseline run is skipped, together with the
            baseline output, the comparison strip and the baseline metrics,
            all of which need the baseline result.
    """
    img_files = sorted(
        f for f in input_dir.iterdir()
        if f.is_file() and f.suffix.lower() in _IMAGE_SUFFIXES
    )
    if not img_files:
        logger.error(f"{input_dir} 下没有图片文件")
        return

    # Output directories.
    out_compare = output_root / "compare"
    out_inpainted = output_root / "inpainted"
    out_mask = output_root / "mask_debug"
    out_metrics = output_root / "metrics"
    for d in [out_compare, out_inpainted, out_mask, out_metrics]:
        d.mkdir(parents=True, exist_ok=True)

    baseline_cfg = _baseline_config()
    wm_cfg = _gan_wm_config()

    inpainter = LamaInpainter(device=device)
    available = inpainter.is_available
    logger.info(f"LaMa 可用: {available}, backend: {inpainter._backend or '未加载'}")

    if not available and not gan_only:
        logger.warning("LaMa backend 不可用,GAN将回退到OpenCV inpaint")

    all_metrics: List[Dict[str, Any]] = []

    for f in img_files:
        logger.info(f"\n处理: {f.name}")
        stem = f.stem
        bgr = _load_image(f)

        # Look for a matching clean reference image.
        clean_img: Optional[np.ndarray] = None
        if clean_dir:
            clean_path = _find_clean_reference(clean_dir, stem)
            if clean_path is not None:
                clean_img = _load_image(clean_path)

        # Watermark detection.
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        has_wm = detect_watermark(gray, ratio_threshold=0.025)
        logger.info(f" 水印检测: {'有水印' if has_wm else '无水印'}")

        # Baseline (skipped in gan_only mode).
        baseline_result: Optional[np.ndarray] = None
        baseline_time = 0.0
        if not gan_only:
            logger.info(" 运行 baseline (masked_adaptive)...")
            t0 = time.perf_counter()
            baseline_result, _ = _run_baseline(bgr, baseline_cfg)
            baseline_time = time.perf_counter() - t0
            logger.info(f" baseline 耗时: {baseline_time:.1f}s")

        # GAN.
        t0 = time.perf_counter()
        gan_result, gan_debug = _run_gan(bgr, wm_cfg, inpainter)
        gan_time = time.perf_counter() - t0

        # Save the per-method results.
        if baseline_result is not None:
            _save_result(stem, baseline_result, out_inpainted, "baseline")
        if gan_result.ndim == 2:
            gan_save_bgr = cv2.cvtColor(gan_result, cv2.COLOR_GRAY2BGR)
        else:
            gan_save_bgr = gan_result
        _save_result(stem, gan_save_bgr, out_inpainted, "gan")

        # Mask visualisation.
        wm_mask = gan_debug.get("wm_mask")
        if wm_mask is not None and np.any(wm_mask):
            mask_overlay = render_watermark_mask_overlay(bgr, wm_mask)
            _save_result(stem, mask_overlay, out_mask, "mask_overlay")

        # Comparison strip (needs the baseline result).
        if baseline_result is not None:
            compare_img = _make_compare_image(bgr, baseline_result, gan_save_bgr, wm_mask)
            _save_result(stem, compare_img, out_compare, "compare")

        # Metrics. bool() keeps the value JSON-serialisable even if the
        # detector returns a NumPy boolean.
        metrics: Dict[str, Any] = {
            "file": f.name,
            "has_watermark": bool(has_wm),
        }
        if baseline_result is not None:
            metrics["baseline_time_s"] = round(baseline_time, 2)
        metrics["gan_time_s"] = round(gan_time, 2)
        metrics["gan_mode"] = gan_debug.get("mode", "unknown")

        if clean_img is not None:
            if baseline_result is not None:
                _add_quality_metrics(metrics, "baseline", baseline_result, clean_img)
            _add_quality_metrics(metrics, "gan", gan_save_bgr, clean_img)
            parts = []
            if "baseline_psnr" in metrics:
                parts.append(f"baseline={metrics['baseline_psnr']}dB")
            if "gan_psnr" in metrics:
                parts.append(f"GAN={metrics['gan_psnr']}dB")
            if parts:
                logger.info(" PSNR: " + ", ".join(parts))
        all_metrics.append(metrics)

    _save_metrics_json(all_metrics, out_metrics)

    # Summary.
    logger.info(f"\n{'='*50}")
    logger.info(f"评估完成,共 {len(img_files)} 张图")
    logger.info(f" 对比图: {out_compare}")
    logger.info(f" 修复结果: {out_inpainted}")
    logger.info(f" 掩膜: {out_mask}")
    logger.info(f" 指标: {out_metrics}")

    # Average only over images that were actually scored. Testing the last
    # iteration's clean_img, or defaulting missing scores to 0, would hide or
    # skew the averages.
    summary = []
    for tag, label in (("baseline", "baseline"), ("gan", "GAN")):
        scores = [m[f"{tag}_psnr"] for m in all_metrics if f"{tag}_psnr" in m]
        if scores:
            summary.append(f"{label}={np.mean(scores):.1f}dB")
    if summary:
        logger.info(" 平均 PSNR: " + ", ".join(summary))
|
|
|
+
|
|
|
+
|
|
|
def main():
    """Parse the command-line options and run ``evaluate`` with them.

    Default input and output directories are resolved relative to the
    directory containing this script; the module docstring is shown as the
    help epilog.
    """
    here = Path(__file__).parent

    cli = argparse.ArgumentParser(
        description="去水印评估:baseline vs GAN",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--input", dict(type=Path, default=here / "test_images" / "input",
                         help="输入图片目录")),
        ("--output", dict(type=Path, default=here / "output",
                          help="输出根目录")),
        ("--clean-dir", dict(type=Path, default=None,
                             help="clean参考图目录(用于计算PSNR/SSIM)")),
        ("--device", dict(type=str, default="cpu",
                          choices=["cpu", "cuda", "mps"],
                          help="推理设备")),
        ("--gan-only", dict(action="store_true",
                            help="仅运行GAN(跳过baseline)")),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    evaluate(
        opts.input,
        opts.output,
        clean_dir=opts.clean_dir,
        device=opts.device,
        gan_only=opts.gan_only,
    )


if __name__ == "__main__":
    main()
|