# evaluate.py
  1. """
  2. 去水印评估脚本:对比 baseline (masked_adaptive) 与 LaMa GAN 方法。
  3. 用法:
  4. cd ocr_platform/ocr_tools/gan_experiments_lab
  5. # 对 test_images/input/ 下所有图片做对比
  6. python evaluate.py
  7. # 指定输入/输出目录
  8. python evaluate.py --input ./test_images/synthetic/ --output ./output/synthetic_compare
  9. # 有clean参考图时计算 PSNR/SSIM
  10. python evaluate.py --input ./test_images/synthetic/ --clean-dir ./test_images/clean/
  11. 生成物:
  12. output/compare/ — 三联对比图 (原图 | baseline | GAN)
  13. output/inpainted/ — GAN 修复结果
  14. output/mask_debug/ — 掩膜可视化
  15. output/metrics/ — 评估指标 JSON
  16. """
  17. from __future__ import annotations
  18. import argparse
  19. import json
  20. import sys
  21. import time
  22. from pathlib import Path
  23. from typing import Any, Dict, List, Optional, Tuple
  24. import cv2
  25. import numpy as np
  26. # 将 ocr_platform 根目录加入 sys.path,以便导入 ocr_utils
  27. _repo_root = Path(__file__).parents[2]
  28. if str(_repo_root) not in sys.path:
  29. sys.path.insert(0, str(_repo_root))
  30. from loguru import logger
  31. from PIL import Image
  32. from ocr_utils.watermark import (
  33. WatermarkProcessor,
  34. build_watermark_mask,
  35. detect_watermark,
  36. merge_watermark_config,
  37. render_watermark_mask_overlay,
  38. )
  39. from lama_inpaint import LamaInpainter
  40. # ── 评估指标 ────────────────────────────────────────────────────
  41. def _to_gray(img: np.ndarray) -> np.ndarray:
  42. if img.ndim == 3:
  43. return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float64)
  44. return img.astype(np.float64)
  45. def compute_psnr(img1: np.ndarray, img2: np.ndarray) -> float:
  46. g1, g2 = _to_gray(img1), _to_gray(img2)
  47. mse = np.mean((g1 - g2) ** 2)
  48. if mse < 1e-10:
  49. return 100.0
  50. return float(20 * np.log10(255.0 / np.sqrt(mse)))
  51. def compute_ssim(img1: np.ndarray, img2: np.ndarray) -> float:
  52. """简易 SSIM 实现(灰度,8x8 block)。"""
  53. from math import exp, pi, sqrt
  54. g1, g2 = _to_gray(img1), _to_gray(img2)
  55. k1, k2 = 0.01, 0.03
  56. l = 255.0
  57. c1, c2 = (k1 * l) ** 2, (k2 * l) ** 2
  58. kernel = cv2.getGaussianKernel(11, 1.5)
  59. window = np.outer(kernel, kernel)
  60. window /= window.sum()
  61. mu1 = cv2.filter2D(g1, -1, window, borderType=cv2.BORDER_REFLECT)
  62. mu2 = cv2.filter2D(g2, -1, window, borderType=cv2.BORDER_REFLECT)
  63. mu1_sq = mu1 * mu1
  64. mu2_sq = mu2 * mu2
  65. mu1_mu2 = mu1 * mu2
  66. sigma1_sq = cv2.filter2D(g1 * g1, -1, window, borderType=cv2.BORDER_REFLECT) - mu1_sq
  67. sigma2_sq = cv2.filter2D(g2 * g2, -1, window, borderType=cv2.BORDER_REFLECT) - mu2_sq
  68. sigma12 = cv2.filter2D(g1 * g2, -1, window, borderType=cv2.BORDER_REFLECT) - mu1_mu2
  69. num = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
  70. denom = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
  71. ssim_map = num / (denom + 1e-10)
  72. return float(ssim_map.mean())
  73. # ── 水印配置 ────────────────────────────────────────────────────
  74. def _baseline_config() -> Dict[str, Any]:
  75. return merge_watermark_config("page", {
  76. "method": "masked_adaptive",
  77. "threshold": 175,
  78. "contrast_enhancement": {"enabled": True, "method": "text_restore", "text_black_target": 85},
  79. })
  80. def _gan_wm_config() -> Dict[str, Any]:
  81. return merge_watermark_config("page", {"method": "masked_adaptive", "threshold": 175})
  82. # ── 单图处理 ────────────────────────────────────────────────────
  83. def _load_image(path: Path) -> np.ndarray:
  84. """加载图片为 BGR ndarray。"""
  85. pil = Image.open(str(path)).convert("RGB")
  86. np_img = np.array(pil)
  87. return cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
  88. def _run_baseline(bgr: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
  89. """运行 masked_adaptive 方法。"""
  90. proc = WatermarkProcessor(cfg, scope="page")
  91. debug: Dict[str, Any] = {}
  92. result, stages = proc.process(bgr, apply_removal=True, removal_debug=debug)
  93. return np.asarray(result), debug
  94. def _run_gan(
  95. bgr: np.ndarray,
  96. wm_cfg: Dict[str, Any],
  97. inpainter: LamaInpainter,
  98. ) -> Tuple[np.ndarray, Dict[str, Any]]:
  99. """
  100. 使用GAN修复水印区域。
  101. 1. 用 build_watermark_mask 检测水印区域
  102. 2. 用 LaMa 修复
  103. 3. 失败则回退 baseline
  104. """
  105. debug: Dict[str, Any] = {"mode": "gan"}
  106. gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
  107. mask_cfg = wm_cfg.get("mask", {})
  108. wm_mask, mask_debug = build_watermark_mask(gray, bgr=bgr, **mask_cfg)
  109. debug.update({k: v for k, v in mask_debug.items()
  110. if not isinstance(v, np.ndarray)})
  111. debug["wm_mask"] = wm_mask
  112. if not np.any(wm_mask):
  113. logger.info(" 未检测到水印区域,跳过GAN")
  114. debug["mode"] = "gan_no_mask"
  115. clean_gray, _ = _run_baseline(bgr, wm_cfg)
  116. return clean_gray, debug
  117. logger.info(f" 水印区域: {wm_mask.sum()} 像素 "
  118. f"({100 * wm_mask.sum() / wm_mask.size:.2f}%)")
  119. t0 = time.perf_counter()
  120. result = inpainter.inpaint(bgr, wm_mask)
  121. elapsed = time.perf_counter() - t0
  122. if result is not None:
  123. debug["mode"] = "gan"
  124. debug["gan_success"] = True
  125. debug["gan_inference_time_s"] = round(elapsed, 2)
  126. logger.info(f" GAN修复成功 ({elapsed:.1f}s)")
  127. # 对修复结果做对比度增强
  128. from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
  129. ce_cfg = wm_cfg.get("contrast_enhancement")
  130. result_gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
  131. result_gray = apply_contrast_enhancement_config(result_gray, ce_cfg)
  132. return result_gray, debug
  133. # GAN 失败,回退
  134. logger.warning(" GAN修复失败,回退 baseline")
  135. debug["mode"] = "gan_fallback"
  136. debug["fallback_reason"] = "gan_inference_failed"
  137. clean_gray, fallback_debug = _run_baseline(bgr, wm_cfg)
  138. debug["fallback_debug"] = fallback_debug
  139. return clean_gray, debug
  140. # ── 输出 ──────────────────────────────────────────────────────────
  141. def _make_compare_image(
  142. bgr: np.ndarray,
  143. baseline_gray: np.ndarray,
  144. gan_result: np.ndarray,
  145. wm_mask: Optional[np.ndarray] = None,
  146. ) -> np.ndarray:
  147. """生成四联对比图。"""
  148. h, w = baseline_gray.shape[:2] if baseline_gray.ndim == 2 else baseline_gray.shape
  149. def _to_bgr(arr: np.ndarray) -> np.ndarray:
  150. if arr.ndim == 2:
  151. return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
  152. return arr
  153. def _resize(arr: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
  154. if arr.shape[0] != target_h or arr.shape[1] != target_w:
  155. return cv2.resize(arr, (target_w, target_h))
  156. return arr
  157. # GAN 结果可能是灰度或BGR
  158. gan_bgr = _to_bgr(gan_result) if gan_result.ndim == 2 else gan_result
  159. if gan_result.ndim == 3 and gan_result.shape[2] == 3:
  160. gan_bgr = gan_result # 已经是BGR
  161. # 统一尺寸
  162. ref_h, ref_w = bgr.shape[:2]
  163. baseline_bgr = _to_bgr(baseline_gray) if baseline_gray.ndim == 2 else baseline_gray
  164. baseline_bgr = _resize(baseline_bgr, ref_h, ref_w)
  165. gan_bgr = _resize(gan_bgr, ref_h, ref_w)
  166. panels = [bgr, baseline_bgr, gan_bgr]
  167. # 如果有mask,叠加到原图上作为第四联
  168. if wm_mask is not None and np.any(wm_mask):
  169. mask_overlay = render_watermark_mask_overlay(bgr, wm_mask)
  170. panels.append(mask_overlay)
  171. # 添加标签
  172. labels = ["Original", "Baseline (masked_adaptive)", "GAN (LaMa)"]
  173. if len(panels) == 4:
  174. labels.append("Watermark Mask")
  175. labeled = []
  176. for panel, label in zip(panels, labels):
  177. h_p = panel.shape[0]
  178. # 底部加标签条
  179. bar = np.ones((36, panel.shape[1], 3), dtype=np.uint8) * 240
  180. cv2.putText(bar, label, (12, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1)
  181. labeled.append(np.vstack([panel, bar]))
  182. # 水平拼接
  183. max_h = max(p.shape[0] for p in labeled)
  184. for i in range(len(labeled)):
  185. if labeled[i].shape[0] < max_h:
  186. pad = np.ones((max_h - labeled[i].shape[0], labeled[i].shape[1], 3), dtype=np.uint8) * 255
  187. labeled[i] = np.vstack([labeled[i], pad])
  188. return np.hstack(labeled)
  189. def _save_result(
  190. stem: str,
  191. result: np.ndarray,
  192. output_dir: Path,
  193. prefix: str = "",
  194. ) -> Path:
  195. """保存结果图片。"""
  196. p = output_dir / f"{stem}_{prefix}.png"
  197. if result.ndim == 2:
  198. cv2.imwrite(str(p), result)
  199. else:
  200. cv2.imwrite(str(p), result)
  201. return p
  202. def _save_metrics_json(
  203. metrics_list: List[Dict[str, Any]],
  204. output_dir: Path,
  205. ) -> None:
  206. output_dir.mkdir(parents=True, exist_ok=True)
  207. p = output_dir / "metrics.json"
  208. p.write_text(json.dumps(metrics_list, ensure_ascii=False, indent=2), encoding="utf-8")
  209. logger.info(f"评估指标: {p}")
  210. # ── 主函数 ────────────────────────────────────────────────────────
  211. def evaluate(
  212. input_dir: Path,
  213. output_root: Path,
  214. *,
  215. clean_dir: Optional[Path] = None,
  216. device: str = "cpu",
  217. gan_only: bool = False,
  218. ) -> None:
  219. """批量评估。"""
  220. img_files = sorted([
  221. f for f in input_dir.iterdir()
  222. if f.suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
  223. ])
  224. if not img_files:
  225. logger.error(f"{input_dir} 下没有图片文件")
  226. return
  227. # 输出目录
  228. out_compare = output_root / "compare"
  229. out_inpainted = output_root / "inpainted"
  230. out_mask = output_root / "mask_debug"
  231. out_metrics = output_root / "metrics"
  232. for d in [out_compare, out_inpainted, out_mask, out_metrics]:
  233. d.mkdir(parents=True, exist_ok=True)
  234. baseline_cfg = _baseline_config()
  235. wm_cfg = _gan_wm_config()
  236. inpainter = LamaInpainter(device=device)
  237. available = inpainter.is_available
  238. logger.info(f"LaMa 可用: {available}, backend: {inpainter._backend or '未加载'}")
  239. if not available and not gan_only:
  240. logger.warning("LaMa backend 不可用,GAN将回退到OpenCV inpaint")
  241. all_metrics: List[Dict[str, Any]] = []
  242. for f in img_files:
  243. logger.info(f"\n处理: {f.name}")
  244. stem = f.stem
  245. bgr = _load_image(f)
  246. # 检查是否有对应 clean 参考图
  247. clean_img: Optional[np.ndarray] = None
  248. if clean_dir:
  249. for ext in (".png", ".jpg", ".jpeg"):
  250. clean_path = clean_dir / f"{stem}{ext}"
  251. if clean_path.exists():
  252. clean_img = _load_image(clean_path)
  253. break
  254. if clean_img is None:
  255. # 尝试移除 _watermarked 后缀
  256. clean_name = stem.replace("_watermarked", "")
  257. for ext in (".png", ".jpg", ".jpeg"):
  258. clean_path = clean_dir / f"{clean_name}{ext}"
  259. if clean_path.exists():
  260. clean_img = _load_image(clean_path)
  261. break
  262. # 检测水印
  263. gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
  264. has_wm = detect_watermark(gray, ratio_threshold=0.025)
  265. logger.info(f" 水印检测: {'有水印' if has_wm else '无水印'}")
  266. # ── Baseline ──
  267. logger.info(" 运行 baseline (masked_adaptive)...")
  268. t0 = time.perf_counter()
  269. baseline_result, baseline_debug = _run_baseline(bgr, baseline_cfg)
  270. baseline_time = time.perf_counter() - t0
  271. logger.info(f" baseline 耗时: {baseline_time:.1f}s")
  272. # ── GAN ──
  273. t0 = time.perf_counter()
  274. gan_result, gan_debug = _run_gan(bgr, wm_cfg, inpainter)
  275. gan_time = time.perf_counter() - t0
  276. # ── 保存结果 ──
  277. _save_result(stem, baseline_result, out_inpainted, "baseline")
  278. gan_save = gan_result
  279. if gan_result.ndim == 2:
  280. gan_save_bgr = cv2.cvtColor(gan_result, cv2.COLOR_GRAY2BGR)
  281. else:
  282. gan_save_bgr = gan_result
  283. _save_result(stem, gan_save_bgr, out_inpainted, "gan")
  284. # ── 掩膜可视化 ──
  285. wm_mask = gan_debug.get("wm_mask")
  286. if wm_mask is not None and np.any(wm_mask):
  287. mask_overlay = render_watermark_mask_overlay(bgr, wm_mask)
  288. _save_result(stem, mask_overlay, out_mask, "mask_overlay")
  289. # ── 对比图 ──
  290. compare_img = _make_compare_image(bgr, baseline_result, gan_save_bgr, wm_mask)
  291. _save_result(stem, compare_img, out_compare, "compare")
  292. # ── 评估指标 ──
  293. metrics: Dict[str, Any] = {
  294. "file": f.name,
  295. "has_watermark": has_wm,
  296. "baseline_time_s": round(baseline_time, 2),
  297. "gan_time_s": round(gan_time, 2),
  298. "gan_mode": gan_debug.get("mode", "unknown"),
  299. }
  300. if clean_img is not None:
  301. # baseline vs clean
  302. metrics["baseline_psnr"] = round(compute_psnr(baseline_result, clean_img), 2)
  303. metrics["baseline_ssim"] = round(compute_ssim(baseline_result, clean_img), 4)
  304. # gan vs clean
  305. metrics["gan_psnr"] = round(compute_psnr(gan_save_bgr, clean_img), 2)
  306. metrics["gan_ssim"] = round(compute_ssim(gan_save_bgr, clean_img), 4)
  307. logger.info(
  308. f" PSNR: baseline={metrics['baseline_psnr']}dB, "
  309. f"GAN={metrics['gan_psnr']}dB"
  310. )
  311. all_metrics.append(metrics)
  312. _save_metrics_json(all_metrics, out_metrics)
  313. # 汇总
  314. logger.info(f"\n{'='*50}")
  315. logger.info(f"评估完成,共 {len(img_files)} 张图")
  316. logger.info(f" 对比图: {out_compare}")
  317. logger.info(f" 修复结果: {out_inpainted}")
  318. logger.info(f" 掩膜: {out_mask}")
  319. logger.info(f" 指标: {out_metrics}")
  320. if clean_img is not None:
  321. avg_baseline_psnr = np.mean([m.get("baseline_psnr", 0) for m in all_metrics])
  322. avg_gan_psnr = np.mean([m.get("gan_psnr", 0) for m in all_metrics])
  323. logger.info(f" 平均 PSNR: baseline={avg_baseline_psnr:.1f}dB, GAN={avg_gan_psnr:.1f}dB")
  324. def main():
  325. root = Path(__file__).parent
  326. parser = argparse.ArgumentParser(
  327. description="去水印评估:baseline vs GAN",
  328. formatter_class=argparse.RawDescriptionHelpFormatter,
  329. epilog=__doc__,
  330. )
  331. parser.add_argument("--input", type=Path, default=root / "test_images" / "input",
  332. help="输入图片目录")
  333. parser.add_argument("--output", type=Path, default=root / "output",
  334. help="输出根目录")
  335. parser.add_argument("--clean-dir", type=Path, default=None,
  336. help="clean参考图目录(用于计算PSNR/SSIM)")
  337. parser.add_argument("--device", type=str, default="cpu",
  338. choices=["cpu", "cuda", "mps"],
  339. help="推理设备")
  340. parser.add_argument("--gan-only", action="store_true",
  341. help="仅运行GAN(跳过baseline)")
  342. args = parser.parse_args()
  343. evaluate(
  344. args.input,
  345. args.output,
  346. clean_dir=args.clean_dir,
  347. device=args.device,
  348. gan_only=args.gan_only,
  349. )
  350. if __name__ == "__main__":
  351. main()