""" LaMa (Large Mask Inpainting) 推理模块。 封装预训练LaMa模型的加载与推理,方案选择(按优先级): 1. simple_lama_inpainting pip包(最简) 2. 本地 lama 仓库代码(big-lama checkpoint) 3. OpenCV inpainting(终极回退,不用GAN) 用法: from gan_experiments_lab.lama_inpaint import LamaInpainter inpaint = LamaInpainter(device="cpu") result = inpaint.inpaint(bgr_image, mask_bool) """ from __future__ import annotations import sys from pathlib import Path from typing import Optional import cv2 import numpy as np from loguru import logger def _check_simple_lama() -> bool: try: import simple_lama_inpainting # noqa: F401 return True except ImportError: return False def _check_lama_repo() -> Optional[Path]: """检查本地是否有 lama 仓库并已加入 sys.path。""" candidates = [ Path(__file__).parent / "lama", Path(__file__).parents[2] / "lama", Path.home() / "lama", Path("/tmp/lama"), ] for p in candidates: if (p / "saicinpainting" / "__init__.py").exists(): return p return None class LamaInpainter: """LaMa inpainting 门面,自动选择可用后端。""" def __init__( self, *, device: str = "cpu", inference_size: Optional[int] = None, pad_to_multiple: int = 8, ): self._device = device self._inference_size = inference_size # None = 保持原尺寸 self._pad_to_multiple = pad_to_multiple self._model = None self._backend = None # "simple_lama" | "lama_repo" | "opencv" self._lama_repo_path: Optional[Path] = None @property def is_available(self) -> bool: if self._backend is not None: return self._backend != "opencv" if _check_simple_lama(): self._backend = "simple_lama" return True if _check_lama_repo(): self._backend = "lama_repo" return True return False def load(self) -> bool: """加载模型,返回是否成功。""" if self._model is not None: return True if _check_simple_lama(): return self._load_simple_lama() repo = _check_lama_repo() if repo: return self._load_lama_repo(repo) logger.warning("LaMa backends 都不可用,将回退 OpenCV inpainting") self._backend = "opencv" return False def _load_simple_lama(self) -> bool: try: from simple_lama_inpainting import SimpleLama self._model = SimpleLama(device=self._device) self._backend = "simple_lama" logger.info(f"LaMa (simple_lama_inpainting) 已加载, device={self._device}") return True except Exception as e: logger.warning(f"simple_lama_inpainting 加载失败: {e}") return False def _load_lama_repo(self, repo_path: Path) -> bool: try: if str(repo_path) not in sys.path: sys.path.insert(0, str(repo_path)) from omegaconf import OmegaConf from saicinpainting.training.trainers import load_checkpoint config_path = repo_path / "big-lama" / "config.yaml" ckpt_path = repo_path / "big-lama" / "models" / "best.ckpt" if not config_path.exists() or not ckpt_path.exists(): logger.warning( f"lama 模型文件缺失。请下载: " f"wget https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.zip && " f"unzip big-lama.zip -d {repo_path}" ) return False conf = OmegaConf.load(str(config_path)) conf.training_model.predict_only = True conf.visualizer.kind = "noop" model = load_checkpoint(conf, str(ckpt_path), strict=False, map_location="cpu") model.eval() if self._device != "cpu": model.cuda() self._model = model self._lama_repo_path = repo_path self._backend = "lama_repo" logger.info(f"LaMa (lama_repo) 已加载, device={self._device}") return True except Exception as e: logger.warning(f"lama_repo 加载失败: {e}") return False def inpaint(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]: """ 修复图像。 Args: image: BGR ndarray (H, W, 3), uint8 mask: bool ndarray (H, W), True=需要修复的水印区域 Returns: BGR ndarray (H, W, 3), uint8, or None """ if not self._model: if not self.load(): return self._opencv_inpaint(image, mask) if self._backend == "simple_lama": return self._inpaint_simple_lama(image, mask) elif self._backend == "lama_repo": return self._inpaint_lama_repo(image, mask) else: return self._opencv_inpaint(image, mask) def _inpaint_simple_lama(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]: try: rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) mask_u8 = mask.astype(np.uint8) * 255 # 按需 resize if self._inference_size: rgb, mask_u8, orig_size = self._resize_to_inference(rgb, mask_u8) result_rgb = self._model(rgb, mask_u8) if self._inference_size: result_rgb = cv2.resize(result_rgb, (orig_size[1], orig_size[0])) return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) except Exception as e: logger.warning(f"simple_lama 推理失败: {e}") return None def _inpaint_lama_repo(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]: try: import torch import torch.nn.functional as F from saicinpainting.evaluation.data import pad_tensor_to_modulo rgb = cv2.cvtColor(image.astype(np.float32) / 255.0, cv2.COLOR_BGR2RGB) mask_f = mask.astype(np.float32) orig_h, orig_w = rgb.shape[:2] # resize if self._inference_size: rgb, mask_f, (orig_w, orig_h) = self._resize_image_mask(rgb, mask_f) img_t = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) mask_t = torch.from_numpy(mask_f).unsqueeze(0).unsqueeze(0) img_t = pad_tensor_to_modulo(img_t, self._pad_to_multiple) mask_t = pad_tensor_to_modulo(mask_t, self._pad_to_multiple) if self._device != "cpu": img_t = img_t.cuda() mask_t = mask_t.cuda() with torch.no_grad(): output = self._model(img_t, mask_t) # output shape: (B, C, H, W) result = output[0].permute(1, 2, 0).cpu().numpy() # 裁掉 pad result = result[:orig_h, :orig_w, :] result = np.clip(result, 0, 1) result_u8 = (result * 255).astype(np.uint8) return cv2.cvtColor(result_u8, cv2.COLOR_RGB2BGR) except Exception as e: logger.warning(f"lama_repo 推理失败: {e}") return None def _resize_to_inference(self, rgb: np.ndarray, mask: np.ndarray) -> tuple: h, w = rgb.shape[:2] size = self._inference_size or min(h, w) scale = size / min(h, w) new_w, new_h = int(w * scale), int(h * scale) rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC) mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST) return rgb_rs, mask_rs, (w, h) def _resize_image_mask(self, rgb: np.ndarray, mask: np.ndarray) -> tuple: h, w = rgb.shape[:2] size = self._inference_size or min(h, w) scale = size / min(h, w) new_w, new_h = int(w * scale), int(h * scale) rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC) mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST) return rgb_rs, mask_rs, (w, h) def _opencv_inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray: """OpenCV Telea inpainting 回退(非GAN)。""" logger.info("使用 OpenCV inpainting 回退") mask_u8 = mask.astype(np.uint8) * 255 return cv2.inpaint(image, mask_u8, inpaintRadius=5, flags=cv2.INPAINT_TELEA) if __name__ == "__main__": # 快速功能测试 print("LaMa 后端检测:") print(f" simple_lama_inpainting: {_check_simple_lama()}") repo = _check_lama_repo() print(f" lama_repo: {repo}") inpaint = LamaInpainter(device="cpu") print(f" is_available: {inpaint.is_available}")