Prechádzať zdrojové kódy

feat(新增水印评估与合成模块): 添加evaluate.py用于对比baseline与LaMa GAN方法的水印去除效果,新增lama_inpaint.py实现LaMa模型的推理,新增watermark_synthesis.py用于合成水印并生成相应的mask,提升水印处理的评估与合成能力。

zhch158_admin 3 dní pred
rodič
commit
eb694a01bb

+ 439 - 0
ocr_tools/gan_experiments_lab/evaluate.py

@@ -0,0 +1,439 @@
+"""
+去水印评估脚本:对比 baseline (masked_adaptive) 与 LaMa GAN 方法。
+
+用法:
+    cd ocr_platform/ocr_tools/gan_experiments_lab
+
+    # 对 test_images/input/ 下所有图片做对比
+    python evaluate.py
+
+    # 指定输入/输出目录
+    python evaluate.py --input ./test_images/synthetic/ --output ./output/synthetic_compare
+
+    # 有clean参考图时计算 PSNR/SSIM
+    python evaluate.py --input ./test_images/synthetic/ --clean-dir ./test_images/clean/
+
+生成物:
+    output/compare/     — 三联对比图 (原图 | baseline | GAN)
+    output/inpainted/   — GAN 修复结果
+    output/mask_debug/  — 掩膜可视化
+    output/metrics/     — 评估指标 JSON
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+import time
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+
+# Put the directory two levels above this file (expected to be the ocr_platform root)
+# on sys.path so that the `ocr_utils` package imported below can be resolved.
+_repo_root = Path(__file__).parents[2]
+if str(_repo_root) not in sys.path:
+    sys.path.insert(0, str(_repo_root))
+
+from loguru import logger
+from PIL import Image
+
+from ocr_utils.watermark import (
+    WatermarkProcessor,
+    build_watermark_mask,
+    detect_watermark,
+    merge_watermark_config,
+    render_watermark_mask_overlay,
+)
+from lama_inpaint import LamaInpainter
+
+# ── 评估指标 ────────────────────────────────────────────────────
+
+
+def _to_gray(img: np.ndarray) -> np.ndarray:
+    if img.ndim == 3:
+        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float64)
+    return img.astype(np.float64)
+
+
+def compute_psnr(img1: np.ndarray, img2: np.ndarray) -> float:
+    g1, g2 = _to_gray(img1), _to_gray(img2)
+    mse = np.mean((g1 - g2) ** 2)
+    if mse < 1e-10:
+        return 100.0
+    return float(20 * np.log10(255.0 / np.sqrt(mse)))
+
+
+def compute_ssim(img1: np.ndarray, img2: np.ndarray) -> float:
+    """简易 SSIM 实现(灰度,8x8 block)。"""
+    from math import exp, pi, sqrt
+
+    g1, g2 = _to_gray(img1), _to_gray(img2)
+    k1, k2 = 0.01, 0.03
+    l = 255.0
+    c1, c2 = (k1 * l) ** 2, (k2 * l) ** 2
+
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel)
+    window /= window.sum()
+
+    mu1 = cv2.filter2D(g1, -1, window, borderType=cv2.BORDER_REFLECT)
+    mu2 = cv2.filter2D(g2, -1, window, borderType=cv2.BORDER_REFLECT)
+    mu1_sq = mu1 * mu1
+    mu2_sq = mu2 * mu2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(g1 * g1, -1, window, borderType=cv2.BORDER_REFLECT) - mu1_sq
+    sigma2_sq = cv2.filter2D(g2 * g2, -1, window, borderType=cv2.BORDER_REFLECT) - mu2_sq
+    sigma12 = cv2.filter2D(g1 * g2, -1, window, borderType=cv2.BORDER_REFLECT) - mu1_mu2
+
+    num = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
+    denom = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
+    ssim_map = num / (denom + 1e-10)
+    return float(ssim_map.mean())
+
+
+# ── 水印配置 ────────────────────────────────────────────────────
+
+
+def _baseline_config() -> Dict[str, Any]:
+    return merge_watermark_config("page", {
+        "method": "masked_adaptive",
+        "threshold": 175,
+        "contrast_enhancement": {"enabled": True, "method": "text_restore", "text_black_target": 85},
+    })
+
+
+def _gan_wm_config() -> Dict[str, Any]:
+    return merge_watermark_config("page", {"method": "masked_adaptive", "threshold": 175})
+
+
+# ── 单图处理 ────────────────────────────────────────────────────
+
+
+def _load_image(path: Path) -> np.ndarray:
+    """加载图片为 BGR ndarray。"""
+    pil = Image.open(str(path)).convert("RGB")
+    np_img = np.array(pil)
+    return cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
+
+
+def _run_baseline(bgr: np.ndarray, cfg: Dict[str, Any]) -> Tuple[np.ndarray, Dict[str, Any]]:
+    """运行 masked_adaptive 方法。"""
+    proc = WatermarkProcessor(cfg, scope="page")
+    debug: Dict[str, Any] = {}
+    result, stages = proc.process(bgr, apply_removal=True, removal_debug=debug)
+    return np.asarray(result), debug
+
+
+def _run_gan(
+    bgr: np.ndarray,
+    wm_cfg: Dict[str, Any],
+    inpainter: LamaInpainter,
+) -> Tuple[np.ndarray, Dict[str, Any]]:
+    """
+    使用GAN修复水印区域。
+
+    1. 用 build_watermark_mask 检测水印区域
+    2. 用 LaMa 修复
+    3. 失败则回退 baseline
+    """
+    debug: Dict[str, Any] = {"mode": "gan"}
+
+    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
+    mask_cfg = wm_cfg.get("mask", {})
+    wm_mask, mask_debug = build_watermark_mask(gray, bgr=bgr, **mask_cfg)
+
+    debug.update({k: v for k, v in mask_debug.items()
+                  if not isinstance(v, np.ndarray)})
+    debug["wm_mask"] = wm_mask
+
+    if not np.any(wm_mask):
+        logger.info("  未检测到水印区域,跳过GAN")
+        debug["mode"] = "gan_no_mask"
+        clean_gray, _ = _run_baseline(bgr, wm_cfg)
+        return clean_gray, debug
+
+    logger.info(f"  水印区域: {wm_mask.sum()} 像素 "
+                f"({100 * wm_mask.sum() / wm_mask.size:.2f}%)")
+
+    t0 = time.perf_counter()
+    result = inpainter.inpaint(bgr, wm_mask)
+    elapsed = time.perf_counter() - t0
+
+    if result is not None:
+        debug["mode"] = "gan"
+        debug["gan_success"] = True
+        debug["gan_inference_time_s"] = round(elapsed, 2)
+        logger.info(f"  GAN修复成功 ({elapsed:.1f}s)")
+        # 对修复结果做对比度增强
+        from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
+        ce_cfg = wm_cfg.get("contrast_enhancement")
+        result_gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
+        result_gray = apply_contrast_enhancement_config(result_gray, ce_cfg)
+        return result_gray, debug
+
+    # GAN 失败,回退
+    logger.warning("  GAN修复失败,回退 baseline")
+    debug["mode"] = "gan_fallback"
+    debug["fallback_reason"] = "gan_inference_failed"
+    clean_gray, fallback_debug = _run_baseline(bgr, wm_cfg)
+    debug["fallback_debug"] = fallback_debug
+    return clean_gray, debug
+
+
+# ── 输出 ──────────────────────────────────────────────────────────
+
+
+def _make_compare_image(
+    bgr: np.ndarray,
+    baseline_gray: np.ndarray,
+    gan_result: np.ndarray,
+    wm_mask: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """生成四联对比图。"""
+    h, w = baseline_gray.shape[:2] if baseline_gray.ndim == 2 else baseline_gray.shape
+
+    def _to_bgr(arr: np.ndarray) -> np.ndarray:
+        if arr.ndim == 2:
+            return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
+        return arr
+
+    def _resize(arr: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
+        if arr.shape[0] != target_h or arr.shape[1] != target_w:
+            return cv2.resize(arr, (target_w, target_h))
+        return arr
+
+    # GAN 结果可能是灰度或BGR
+    gan_bgr = _to_bgr(gan_result) if gan_result.ndim == 2 else gan_result
+    if gan_result.ndim == 3 and gan_result.shape[2] == 3:
+        gan_bgr = gan_result  # 已经是BGR
+
+    # 统一尺寸
+    ref_h, ref_w = bgr.shape[:2]
+    baseline_bgr = _to_bgr(baseline_gray) if baseline_gray.ndim == 2 else baseline_gray
+    baseline_bgr = _resize(baseline_bgr, ref_h, ref_w)
+    gan_bgr = _resize(gan_bgr, ref_h, ref_w)
+
+    panels = [bgr, baseline_bgr, gan_bgr]
+
+    # 如果有mask,叠加到原图上作为第四联
+    if wm_mask is not None and np.any(wm_mask):
+        mask_overlay = render_watermark_mask_overlay(bgr, wm_mask)
+        panels.append(mask_overlay)
+
+    # 添加标签
+    labels = ["Original", "Baseline (masked_adaptive)", "GAN (LaMa)"]
+    if len(panels) == 4:
+        labels.append("Watermark Mask")
+
+    labeled = []
+    for panel, label in zip(panels, labels):
+        h_p = panel.shape[0]
+        # 底部加标签条
+        bar = np.ones((36, panel.shape[1], 3), dtype=np.uint8) * 240
+        cv2.putText(bar, label, (12, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1)
+        labeled.append(np.vstack([panel, bar]))
+
+    # 水平拼接
+    max_h = max(p.shape[0] for p in labeled)
+    for i in range(len(labeled)):
+        if labeled[i].shape[0] < max_h:
+            pad = np.ones((max_h - labeled[i].shape[0], labeled[i].shape[1], 3), dtype=np.uint8) * 255
+            labeled[i] = np.vstack([labeled[i], pad])
+
+    return np.hstack(labeled)
+
+
+def _save_result(
+    stem: str,
+    result: np.ndarray,
+    output_dir: Path,
+    prefix: str = "",
+) -> Path:
+    """保存结果图片。"""
+    p = output_dir / f"{stem}_{prefix}.png"
+    if result.ndim == 2:
+        cv2.imwrite(str(p), result)
+    else:
+        cv2.imwrite(str(p), result)
+    return p
+
+
+def _save_metrics_json(
+    metrics_list: List[Dict[str, Any]],
+    output_dir: Path,
+) -> None:
+    output_dir.mkdir(parents=True, exist_ok=True)
+    p = output_dir / "metrics.json"
+    p.write_text(json.dumps(metrics_list, ensure_ascii=False, indent=2), encoding="utf-8")
+    logger.info(f"评估指标: {p}")
+
+
+# ── 主函数 ────────────────────────────────────────────────────────
+
+
+def evaluate(
+    input_dir: Path,
+    output_root: Path,
+    *,
+    clean_dir: Optional[Path] = None,
+    device: str = "cpu",
+    gan_only: bool = False,
+) -> None:
+    """批量评估。"""
+    img_files = sorted([
+        f for f in input_dir.iterdir()
+        if f.suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
+    ])
+    if not img_files:
+        logger.error(f"{input_dir} 下没有图片文件")
+        return
+
+    # 输出目录
+    out_compare = output_root / "compare"
+    out_inpainted = output_root / "inpainted"
+    out_mask = output_root / "mask_debug"
+    out_metrics = output_root / "metrics"
+    for d in [out_compare, out_inpainted, out_mask, out_metrics]:
+        d.mkdir(parents=True, exist_ok=True)
+
+    baseline_cfg = _baseline_config()
+    wm_cfg = _gan_wm_config()
+
+    inpainter = LamaInpainter(device=device)
+    available = inpainter.is_available
+    logger.info(f"LaMa 可用: {available}, backend: {inpainter._backend or '未加载'}")
+
+    if not available and not gan_only:
+        logger.warning("LaMa backend 不可用,GAN将回退到OpenCV inpaint")
+
+    all_metrics: List[Dict[str, Any]] = []
+
+    for f in img_files:
+        logger.info(f"\n处理: {f.name}")
+        stem = f.stem
+        bgr = _load_image(f)
+
+        # 检查是否有对应 clean 参考图
+        clean_img: Optional[np.ndarray] = None
+        if clean_dir:
+            for ext in (".png", ".jpg", ".jpeg"):
+                clean_path = clean_dir / f"{stem}{ext}"
+                if clean_path.exists():
+                    clean_img = _load_image(clean_path)
+                    break
+            if clean_img is None:
+                # 尝试移除 _watermarked 后缀
+                clean_name = stem.replace("_watermarked", "")
+                for ext in (".png", ".jpg", ".jpeg"):
+                    clean_path = clean_dir / f"{clean_name}{ext}"
+                    if clean_path.exists():
+                        clean_img = _load_image(clean_path)
+                        break
+
+        # 检测水印
+        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
+        has_wm = detect_watermark(gray, ratio_threshold=0.025)
+        logger.info(f"  水印检测: {'有水印' if has_wm else '无水印'}")
+
+        # ── Baseline ──
+        logger.info("  运行 baseline (masked_adaptive)...")
+        t0 = time.perf_counter()
+        baseline_result, baseline_debug = _run_baseline(bgr, baseline_cfg)
+        baseline_time = time.perf_counter() - t0
+        logger.info(f"  baseline 耗时: {baseline_time:.1f}s")
+
+        # ── GAN ──
+        t0 = time.perf_counter()
+        gan_result, gan_debug = _run_gan(bgr, wm_cfg, inpainter)
+        gan_time = time.perf_counter() - t0
+
+        # ── 保存结果 ──
+        _save_result(stem, baseline_result, out_inpainted, "baseline")
+        gan_save = gan_result
+        if gan_result.ndim == 2:
+            gan_save_bgr = cv2.cvtColor(gan_result, cv2.COLOR_GRAY2BGR)
+        else:
+            gan_save_bgr = gan_result
+        _save_result(stem, gan_save_bgr, out_inpainted, "gan")
+
+        # ── 掩膜可视化 ──
+        wm_mask = gan_debug.get("wm_mask")
+        if wm_mask is not None and np.any(wm_mask):
+            mask_overlay = render_watermark_mask_overlay(bgr, wm_mask)
+            _save_result(stem, mask_overlay, out_mask, "mask_overlay")
+
+        # ── 对比图 ──
+        compare_img = _make_compare_image(bgr, baseline_result, gan_save_bgr, wm_mask)
+        _save_result(stem, compare_img, out_compare, "compare")
+
+        # ── 评估指标 ──
+        metrics: Dict[str, Any] = {
+            "file": f.name,
+            "has_watermark": has_wm,
+            "baseline_time_s": round(baseline_time, 2),
+            "gan_time_s": round(gan_time, 2),
+            "gan_mode": gan_debug.get("mode", "unknown"),
+        }
+        if clean_img is not None:
+            # baseline vs clean
+            metrics["baseline_psnr"] = round(compute_psnr(baseline_result, clean_img), 2)
+            metrics["baseline_ssim"] = round(compute_ssim(baseline_result, clean_img), 4)
+            # gan vs clean
+            metrics["gan_psnr"] = round(compute_psnr(gan_save_bgr, clean_img), 2)
+            metrics["gan_ssim"] = round(compute_ssim(gan_save_bgr, clean_img), 4)
+            logger.info(
+                f"  PSNR: baseline={metrics['baseline_psnr']}dB, "
+                f"GAN={metrics['gan_psnr']}dB"
+            )
+        all_metrics.append(metrics)
+
+    _save_metrics_json(all_metrics, out_metrics)
+
+    # 汇总
+    logger.info(f"\n{'='*50}")
+    logger.info(f"评估完成,共 {len(img_files)} 张图")
+    logger.info(f"  对比图:   {out_compare}")
+    logger.info(f"  修复结果: {out_inpainted}")
+    logger.info(f"  掩膜:     {out_mask}")
+    logger.info(f"  指标:     {out_metrics}")
+
+    if clean_img is not None:
+        avg_baseline_psnr = np.mean([m.get("baseline_psnr", 0) for m in all_metrics])
+        avg_gan_psnr = np.mean([m.get("gan_psnr", 0) for m in all_metrics])
+        logger.info(f"  平均 PSNR: baseline={avg_baseline_psnr:.1f}dB, GAN={avg_gan_psnr:.1f}dB")
+
+
+def main():
+    root = Path(__file__).parent
+
+    parser = argparse.ArgumentParser(
+        description="去水印评估:baseline vs GAN",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument("--input", type=Path, default=root / "test_images" / "input",
+                        help="输入图片目录")
+    parser.add_argument("--output", type=Path, default=root / "output",
+                        help="输出根目录")
+    parser.add_argument("--clean-dir", type=Path, default=None,
+                        help="clean参考图目录(用于计算PSNR/SSIM)")
+    parser.add_argument("--device", type=str, default="cpu",
+                        choices=["cpu", "cuda", "mps"],
+                        help="推理设备")
+    parser.add_argument("--gan-only", action="store_true",
+                        help="仅运行GAN(跳过baseline)")
+    args = parser.parse_args()
+
+    evaluate(
+        args.input,
+        args.output,
+        clean_dir=args.clean_dir,
+        device=args.device,
+        gan_only=args.gan_only,
+    )
+
+
+if __name__ == "__main__":
+    main()

+ 245 - 0
ocr_tools/gan_experiments_lab/lama_inpaint.py

@@ -0,0 +1,245 @@
+"""
+LaMa (Large Mask Inpainting) 推理模块。
+
+封装预训练LaMa模型的加载与推理,方案选择(按优先级):
+1. simple_lama_inpainting  pip包(最简)
+2. 本地 lama 仓库代码(big-lama checkpoint)
+3. OpenCV inpainting(终极回退,不用GAN)
+
+用法:
+    from gan_experiments_lab.lama_inpaint import LamaInpainter
+    inpaint = LamaInpainter(device="cpu")
+    result = inpaint.inpaint(bgr_image, mask_bool)
+"""
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from typing import Optional
+
+import cv2
+import numpy as np
+from loguru import logger
+
+
+def _check_simple_lama() -> bool:
+    try:
+        import simple_lama_inpainting  # noqa: F401
+        return True
+    except ImportError:
+        return False
+
+
+def _check_lama_repo() -> Optional[Path]:
+    """检查本地是否有 lama 仓库并已加入 sys.path。"""
+    candidates = [
+        Path(__file__).parent / "lama",
+        Path(__file__).parents[2] / "lama",
+        Path.home() / "lama",
+        Path("/tmp/lama"),
+    ]
+    for p in candidates:
+        if (p / "saicinpainting" / "__init__.py").exists():
+            return p
+    return None
+
+
+class LamaInpainter:
+    """LaMa inpainting 门面,自动选择可用后端。"""
+
+    def __init__(
+        self,
+        *,
+        device: str = "cpu",
+        inference_size: Optional[int] = None,
+        pad_to_multiple: int = 8,
+    ):
+        self._device = device
+        self._inference_size = inference_size  # None = 保持原尺寸
+        self._pad_to_multiple = pad_to_multiple
+        self._model = None
+        self._backend = None  # "simple_lama" | "lama_repo" | "opencv"
+        self._lama_repo_path: Optional[Path] = None
+
+    @property
+    def is_available(self) -> bool:
+        if self._backend is not None:
+            return self._backend != "opencv"
+        if _check_simple_lama():
+            self._backend = "simple_lama"
+            return True
+        if _check_lama_repo():
+            self._backend = "lama_repo"
+            return True
+        return False
+
+    def load(self) -> bool:
+        """加载模型,返回是否成功。"""
+        if self._model is not None:
+            return True
+
+        if _check_simple_lama():
+            return self._load_simple_lama()
+        repo = _check_lama_repo()
+        if repo:
+            return self._load_lama_repo(repo)
+
+        logger.warning("LaMa backends 都不可用,将回退 OpenCV inpainting")
+        self._backend = "opencv"
+        return False
+
+    def _load_simple_lama(self) -> bool:
+        try:
+            from simple_lama_inpainting import SimpleLama
+            self._model = SimpleLama(device=self._device)
+            self._backend = "simple_lama"
+            logger.info(f"LaMa (simple_lama_inpainting) 已加载, device={self._device}")
+            return True
+        except Exception as e:
+            logger.warning(f"simple_lama_inpainting 加载失败: {e}")
+            return False
+
+    def _load_lama_repo(self, repo_path: Path) -> bool:
+        try:
+            if str(repo_path) not in sys.path:
+                sys.path.insert(0, str(repo_path))
+
+            from omegaconf import OmegaConf
+            from saicinpainting.training.trainers import load_checkpoint
+
+            config_path = repo_path / "big-lama" / "config.yaml"
+            ckpt_path = repo_path / "big-lama" / "models" / "best.ckpt"
+
+            if not config_path.exists() or not ckpt_path.exists():
+                logger.warning(
+                    f"lama 模型文件缺失。请下载: "
+                    f"wget https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.zip && "
+                    f"unzip big-lama.zip -d {repo_path}"
+                )
+                return False
+
+            conf = OmegaConf.load(str(config_path))
+            conf.training_model.predict_only = True
+            conf.visualizer.kind = "noop"
+
+            model = load_checkpoint(conf, str(ckpt_path), strict=False, map_location="cpu")
+            model.eval()
+            if self._device != "cpu":
+                model.cuda()
+            self._model = model
+            self._lama_repo_path = repo_path
+            self._backend = "lama_repo"
+            logger.info(f"LaMa (lama_repo) 已加载, device={self._device}")
+            return True
+        except Exception as e:
+            logger.warning(f"lama_repo 加载失败: {e}")
+            return False
+
+    def inpaint(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
+        """
+        修复图像。
+
+        Args:
+            image: BGR ndarray (H, W, 3), uint8
+            mask: bool ndarray (H, W), True=需要修复的水印区域
+
+        Returns:
+            BGR ndarray (H, W, 3), uint8, or None
+        """
+        if not self._model:
+            if not self.load():
+                return self._opencv_inpaint(image, mask)
+
+        if self._backend == "simple_lama":
+            return self._inpaint_simple_lama(image, mask)
+        elif self._backend == "lama_repo":
+            return self._inpaint_lama_repo(image, mask)
+        else:
+            return self._opencv_inpaint(image, mask)
+
+    def _inpaint_simple_lama(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
+        try:
+            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            mask_u8 = mask.astype(np.uint8) * 255
+            # 按需 resize
+            if self._inference_size:
+                rgb, mask_u8, orig_size = self._resize_to_inference(rgb, mask_u8)
+            result_rgb = self._model(rgb, mask_u8)
+            if self._inference_size:
+                result_rgb = cv2.resize(result_rgb, (orig_size[1], orig_size[0]))
+            return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
+        except Exception as e:
+            logger.warning(f"simple_lama 推理失败: {e}")
+            return None
+
+    def _inpaint_lama_repo(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
+        try:
+            import torch
+            import torch.nn.functional as F
+            from saicinpainting.evaluation.data import pad_tensor_to_modulo
+
+            rgb = cv2.cvtColor(image.astype(np.float32) / 255.0, cv2.COLOR_BGR2RGB)
+            mask_f = mask.astype(np.float32)
+            orig_h, orig_w = rgb.shape[:2]
+
+            # resize
+            if self._inference_size:
+                rgb, mask_f, (orig_w, orig_h) = self._resize_image_mask(rgb, mask_f)
+
+            img_t = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
+            mask_t = torch.from_numpy(mask_f).unsqueeze(0).unsqueeze(0)
+
+            img_t = pad_tensor_to_modulo(img_t, self._pad_to_multiple)
+            mask_t = pad_tensor_to_modulo(mask_t, self._pad_to_multiple)
+
+            if self._device != "cpu":
+                img_t = img_t.cuda()
+                mask_t = mask_t.cuda()
+
+            with torch.no_grad():
+                output = self._model(img_t, mask_t)
+                # output shape: (B, C, H, W)
+                result = output[0].permute(1, 2, 0).cpu().numpy()
+                # 裁掉 pad
+                result = result[:orig_h, :orig_w, :]
+
+            result = np.clip(result, 0, 1)
+            result_u8 = (result * 255).astype(np.uint8)
+            return cv2.cvtColor(result_u8, cv2.COLOR_RGB2BGR)
+        except Exception as e:
+            logger.warning(f"lama_repo 推理失败: {e}")
+            return None
+
+    def _resize_to_inference(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
+        h, w = rgb.shape[:2]
+        size = self._inference_size or min(h, w)
+        scale = size / min(h, w)
+        new_w, new_h = int(w * scale), int(h * scale)
+        rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
+        mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
+        return rgb_rs, mask_rs, (w, h)
+
+    def _resize_image_mask(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
+        h, w = rgb.shape[:2]
+        size = self._inference_size or min(h, w)
+        scale = size / min(h, w)
+        new_w, new_h = int(w * scale), int(h * scale)
+        rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
+        mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
+        return rgb_rs, mask_rs, (w, h)
+
+    def _opencv_inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
+        """OpenCV Telea inpainting 回退(非GAN)。"""
+        logger.info("使用 OpenCV inpainting 回退")
+        mask_u8 = mask.astype(np.uint8) * 255
+        return cv2.inpaint(image, mask_u8, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
+
+
+if __name__ == "__main__":
+    # 快速功能测试
+    print("LaMa 后端检测:")
+    print(f"  simple_lama_inpainting: {_check_simple_lama()}")
+    repo = _check_lama_repo()
+    print(f"  lama_repo:              {repo}")
+    inpaint = LamaInpainter(device="cpu")
+    print(f"  is_available:           {inpaint.is_available}")

BIN
ocr_tools/gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png


+ 222 - 0
ocr_tools/gan_experiments_lab/watermark_synthesis.py

@@ -0,0 +1,222 @@
+"""
+水印合成脚本:在clean图片上叠加斜向浅色文字水印,输出带水印图 + 精确mask。
+
+用法:
+    python watermark_synthesis.py                          # 默认参数演示
+    python watermark_synthesis.py --input ./test_images/clean/   # 指定输入目录
+    python watermark_synthesis.py --text "SAMPLE" --opacity 0.15 --angle 45
+"""
+from __future__ import annotations
+
+import argparse
+import math
+from pathlib import Path
+from typing import Optional
+
+import cv2
+import numpy as np
+from loguru import logger
+from PIL import Image, ImageDraw, ImageFont
+
+
+def _find_font() -> str:
+    """查找可用中文字体,找不到返回默认字体。"""
+    candidates = [
+        "/System/Library/Fonts/PingFang.ttc",
+        "/System/Library/Fonts/STHeiti Light.ttc",
+        "/System/Library/Fonts/Hiragino Sans GB.ttc",
+        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
+        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
+    ]
+    for fp in candidates:
+        if Path(fp).exists():
+            return fp
+    logger.warning("未找到中文字体,使用PIL默认字体")
+    return ""
+
+
+def _text_size_to_font_size(text_height_px: int) -> int:
+    """根据目标文字像素高度估算 font_size。"""
+    return int(text_height_px * 1.15)
+
+
+def _render_watermark_tile(
+    pil_img: Image.Image,
+    text: str,
+    font_path: str,
+    font_size: int,
+    opacity: float,
+    angle_deg: float,
+    spacing_x: int,
+    spacing_y: int,
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    在图上平铺斜向水印文字,返回 (watermarked_np, mask_np)。
+
+    mask_np: H×W bool, True=水印像素位置。
+    """
+    w, h = pil_img.size
+    text_height = int(font_size / 1.15)
+    gray_value = int(255 * (1 - opacity))
+
+    # 创建水印文字mask(稍大画布以覆盖旋转后区域)
+    diag = int(math.sqrt(w * w + h * h)) + text_height * 4
+    tile_w = diag
+    tile_h = diag
+
+    tile = Image.new("L", (tile_w, tile_h), 0)
+    draw = ImageDraw.Draw(tile)
+    font = ImageFont.truetype(font_path, font_size) if font_path else ImageFont.load_default()
+
+    # 步长取spacing + 文字大小,确保均匀分布
+    step_x = text_height + spacing_x
+    step_y = text_height + spacing_y
+
+    for y in range(0, tile_h, step_y):
+        for x in range(0, tile_w, step_x):
+            draw.text((x, y), text, fill=255, font=font)
+
+    # 旋转
+    tile_rot = tile.rotate(angle_deg, expand=False, fillcolor=0)
+
+    # 裁剪到原图大小(中心对齐)
+    cx, cy = tile_rot.size[0] // 2, tile_rot.size[1] // 2
+    left = cx - w // 2
+    top = cy - h // 2
+    watermark_tile = tile_rot.crop((left, top, left + w, top + h))
+
+    mask_np = np.array(watermark_tile) > 0
+
+    # 叠加到原图
+    base = np.array(pil_img.convert("RGB"))
+    alpha = opacity
+    result = base.copy()
+    result[mask_np] = (
+        base[mask_np].astype(np.float32) * (1 - alpha)
+        + np.array([gray_value, gray_value, gray_value], dtype=np.float32) * alpha
+    ).astype(np.uint8)
+
+    return result, mask_np
+
+
+def synthesize_watermark(
+    input_path: Path,
+    output_dir: Path,
+    *,
+    text: str = "SAMPLE",
+    font_path: str = "",
+    text_height_px: int = 36,
+    opacity: float = 0.12,
+    angle_deg: float = 45.0,
+    spacing_x: int = 180,
+    spacing_y: int = 180,
+    save_mask: bool = True,
+) -> Path:
+    """
+    在输入图片上合成水印,输出到 output_dir。
+
+    Returns:
+        合成后的图片路径
+    """
+    output_dir.mkdir(parents=True, exist_ok=True)
+    pil_img = Image.open(str(input_path)).convert("RGB")
+
+    fp = font_path or _find_font()
+    font_size = _text_size_to_font_size(text_height_px)
+
+    logger.info(
+        f"合成水印: {input_path.name} | "
+        f"text='{text}' font_size={font_size} opacity={opacity} angle={angle_deg}"
+    )
+
+    result_np, mask_np = _render_watermark_tile(
+        pil_img, text, fp, font_size, opacity, angle_deg, spacing_x, spacing_y
+    )
+
+    out_name = f"{input_path.stem}_watermarked{input_path.suffix}"
+    out_path = output_dir / out_name
+    Image.fromarray(result_np).save(str(out_path))
+    logger.info(f"  水印图: {out_path}")
+
+    if save_mask:
+        mask_path = output_dir / f"{input_path.stem}_mask.png"
+        cv2.imwrite(str(mask_path), (mask_np.astype(np.uint8) * 255))
+        logger.info(f"  mask:   {mask_path}")
+
+    return out_path
+
+
+def main():
+    parser = argparse.ArgumentParser(description="水印合成工具")
+    parser.add_argument("--input", type=Path, default=None,
+                        help="输入图片或目录(默认: test_images/clean/)")
+    parser.add_argument("--output", type=Path, default=None,
+                        help="输出目录(默认: test_images/synthetic/)")
+    parser.add_argument("--text", type=str, default="行内内部使用",
+                        help="水印文字内容")
+    parser.add_argument("--text-height", type=int, default=48,
+                        help="文字像素高度(默认48)")
+    parser.add_argument("--opacity", type=float, default=0.10,
+                        help="水印透明度 0~1(默认0.10)")
+    parser.add_argument("--angle", type=float, default=45.0,
+                        help="水印倾斜角度(默认45°)")
+    parser.add_argument("--spacing-x", type=int, default=250,
+                        help="水印文字水平间距(默认250px)")
+    parser.add_argument("--spacing-y", type=int, default=250,
+                        help="水印文字垂直间距(默认250px)")
+    parser.add_argument("--font", type=str, default="",
+                        help="字体文件路径")
+    parser.add_argument("--no-mask", action="store_true",
+                        help="不保存mask")
+    parser.add_argument("--demo", action="store_true",
+                        help="使用input目录下第一张测试图生成演示图")
+    args = parser.parse_args()
+
+    root = Path(__file__).parent
+    input_dir = args.input or (root / "test_images" / "clean")
+    output_dir = args.output or (root / "test_images" / "synthetic")
+
+    if args.demo:
+        # 无clean图时,直接用input目录的水印图再加一层合成水印做演示
+        img_files = sorted(root.glob("test_images/input/*"))
+        if not img_files:
+            logger.error("test_images/input/ 下没有测试图片,请放入图片后重试")
+            return
+        input_dir = root / "test_images" / "input"
+        output_dir = root / "test_images" / "synthetic"
+
+    input_dir = Path(input_dir)
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if input_dir.is_dir():
+        img_files = sorted([
+            f for f in input_dir.iterdir()
+            if f.suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
+        ])
+    elif input_dir.is_file():
+        img_files = [input_dir]
+    else:
+        logger.error(f"输入路径不存在: {input_dir}")
+        return
+
+    if not img_files:
+        logger.warning(f"{input_dir} 下没有图片文件")
+        return
+
+    for f in img_files:
+        synthesize_watermark(
+            f, output_dir,
+            text=args.text,
+            font_path=args.font,
+            text_height_px=args.text_height,
+            opacity=args.opacity,
+            angle_deg=args.angle,
+            spacing_x=args.spacing_x,
+            spacing_y=args.spacing_y,
+            save_mask=not args.no_mask,
+        )
+
+
+if __name__ == "__main__":
+    main()