| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- """
- LaMa (Large Mask Inpainting) 推理模块。
- 封装预训练LaMa模型的加载与推理,方案选择(按优先级):
- 1. simple_lama_inpainting pip包(最简)
- 2. 本地 lama 仓库代码(big-lama checkpoint)
- 3. OpenCV inpainting(终极回退,不用GAN)
- 用法:
- from gan_experiments_lab.lama_inpaint import LamaInpainter
- inpaint = LamaInpainter(device="cpu")
- result = inpaint.inpaint(bgr_image, mask_bool)
- """
- from __future__ import annotations
- import sys
- from pathlib import Path
- from typing import Optional
- import cv2
- import numpy as np
- from loguru import logger
def _check_simple_lama() -> bool:
    """Report whether the ``simple_lama_inpainting`` package can be imported.

    Returns:
        True when the import succeeds, False when it raises ``ImportError``.
    """
    try:
        import simple_lama_inpainting  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def _check_lama_repo() -> Optional[Path]:
    """Locate a local checkout of the LaMa repository.

    A directory counts as a checkout when it contains
    ``saicinpainting/__init__.py``. Candidates are probed in this order:
    ``lama`` next to this file, ``lama`` two levels above this file's
    directory, ``~/lama`` and ``/tmp/lama``.

    This function only searches; it does not modify ``sys.path``.

    Returns:
        The first matching directory, or None when no candidate matches.
    """
    here = Path(__file__)
    candidates = [here.parent / "lama"]
    # A shallow or relative __file__ has fewer than three parents; indexing
    # parents[2] unconditionally would raise IndexError.
    if len(here.parents) > 2:
        candidates.append(here.parents[2] / "lama")
    try:
        candidates.append(Path.home() / "lama")
    except RuntimeError:
        # The home directory cannot be resolved in this environment.
        pass
    # NOTE(security): /tmp is world-writable on most systems. A directory
    # found here is treated as importable code by the caller, so any local
    # user could plant one. Kept for backward compatibility.
    candidates.append(Path("/tmp/lama"))

    for candidate in candidates:
        if (candidate / "saicinpainting" / "__init__.py").is_file():
            return candidate
    return None
class LamaInpainter:
    """Facade over LaMa inpainting that picks whichever backend is usable.

    Backends, in priority order:

    1. ``"simple_lama"`` - the ``simple_lama_inpainting`` pip package,
    2. ``"lama_repo"`` - a local LaMa checkout with the ``big-lama`` checkpoint,
    3. ``"opencv"`` - OpenCV Telea inpainting (non-GAN fallback).

    The model is loaded lazily by the first :meth:`inpaint` call, or eagerly
    through :meth:`load`.
    """

    def __init__(
        self,
        *,
        device: str = "cpu",
        inference_size: Optional[int] = None,
        pad_to_multiple: int = 8,
    ):
        """Store the configuration; no model is loaded here.

        Args:
            device: Device string handed to ``SimpleLama`` and to torch
                ``.to()`` for the lama_repo backend.
            inference_size: When set, the short image side is resized to this
                many pixels before inference and the result is resized back.
                None keeps the original size.
            pad_to_multiple: The lama_repo backend pads its input tensors up
                to a multiple of this value.
        """
        self._device = device
        self._inference_size = inference_size  # None = keep the original size
        self._pad_to_multiple = pad_to_multiple
        self._model = None
        self._backend = None  # "simple_lama" | "lama_repo" | "opencv"
        self._lama_repo_path: Optional[Path] = None

    @property
    def is_available(self) -> bool:
        """Whether a LaMa backend (not the OpenCV fallback) looks usable.

        When no backend has been chosen yet, this probes for one and remembers
        its name; the model itself is not loaded here.
        """
        if self._backend is not None:
            return self._backend != "opencv"
        if _check_simple_lama():
            self._backend = "simple_lama"
            return True
        if _check_lama_repo() is not None:
            self._backend = "lama_repo"
            return True
        return False

    def load(self) -> bool:
        """Load the first LaMa backend that works.

        Returns:
            True when a LaMa model is loaded (or already was). False when no
            LaMa backend could be loaded; the backend is then set to
            ``"opencv"``.
        """
        if self._model is not None:
            return True
        # Try backends in priority order. A backend that is installed but
        # fails to load must not prevent the next one from being tried.
        if _check_simple_lama() and self._load_simple_lama():
            return True
        repo = _check_lama_repo()
        if repo is not None and self._load_lama_repo(repo):
            return True
        logger.warning("LaMa backends 都不可用,将回退 OpenCV inpainting")
        self._backend = "opencv"
        return False

    def _load_simple_lama(self) -> bool:
        """Instantiate ``SimpleLama``; return False (and log) on any failure."""
        try:
            from simple_lama_inpainting import SimpleLama

            self._model = SimpleLama(device=self._device)
            self._backend = "simple_lama"
            logger.info(f"LaMa (simple_lama_inpainting) 已加载, device={self._device}")
            return True
        except Exception as e:
            # Best effort: any failure here just means "try the next backend".
            logger.warning(f"simple_lama_inpainting 加载失败: {e}")
            return False

    def _load_lama_repo(self, repo_path: Path) -> bool:
        """Load the big-lama checkpoint from a local LaMa checkout.

        Args:
            repo_path: Root of the checkout. It is put on ``sys.path`` so that
                ``saicinpainting`` can be imported, and must contain
                ``big-lama/config.yaml`` and ``big-lama/models/best.ckpt``.

        Returns:
            True on success; False (after logging) when files are missing or
            loading raises.
        """
        try:
            if str(repo_path) not in sys.path:
                sys.path.insert(0, str(repo_path))
            from omegaconf import OmegaConf
            from saicinpainting.training.trainers import load_checkpoint

            config_path = repo_path / "big-lama" / "config.yaml"
            ckpt_path = repo_path / "big-lama" / "models" / "best.ckpt"
            if not config_path.exists() or not ckpt_path.exists():
                logger.warning(
                    f"lama 模型文件缺失。请下载: "
                    f"wget https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.zip && "
                    f"unzip big-lama.zip -d {repo_path}"
                )
                return False
            conf = OmegaConf.load(str(config_path))
            conf.training_model.predict_only = True
            conf.visualizer.kind = "noop"
            model = load_checkpoint(conf, str(ckpt_path), strict=False, map_location="cpu")
            model.eval()
            # .to() honours the exact device string ("cuda:1", "mps", ...)
            # instead of forcing the default CUDA device.
            model.to(self._device)
            self._model = model
            self._lama_repo_path = repo_path
            self._backend = "lama_repo"
            logger.info(f"LaMa (lama_repo) 已加载, device={self._device}")
            return True
        except Exception as e:
            # Best effort: any failure here just means "fall back".
            logger.warning(f"lama_repo 加载失败: {e}")
            return False

    def inpaint(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
        """Inpaint the masked region of an image.

        Args:
            image: BGR ndarray (H, W, 3), uint8.
            mask: ndarray (H, W); non-zero / True marks pixels to repair.

        Returns:
            BGR ndarray (H, W, 3), uint8. None when a LaMa backend was
            selected but its inference raised.
        """
        if self._model is None:
            # Once the backend is pinned to "opencv", skip the implicit reload
            # so per-image calls do not repeat the probing and the warning.
            # An explicit load() call still retries.
            if self._backend == "opencv" or not self.load():
                return self._opencv_inpaint(image, mask)
        if self._backend == "simple_lama":
            return self._inpaint_simple_lama(image, mask)
        if self._backend == "lama_repo":
            return self._inpaint_lama_repo(image, mask)
        return self._opencv_inpaint(image, mask)

    @staticmethod
    def _mask_to_u8(mask: np.ndarray) -> np.ndarray:
        """Convert any mask to uint8 with 255 where the mask is non-zero.

        Going through bool first keeps a mask that is already 0/255 uint8 from
        overflowing when multiplied by 255.
        """
        return np.asarray(mask).astype(bool).astype(np.uint8) * 255

    def _scaled_size(self, h: int, w: int) -> tuple:
        """Return ``(new_w, new_h)`` with the short side at the inference size.

        ``round`` is used instead of truncation so that floating-point error
        cannot leave the short side one pixel short; both sides are kept at
        least 1 pixel.
        """
        short_side = min(h, w)
        scale = (self._inference_size or short_side) / short_side
        return max(1, int(round(w * scale))), max(1, int(round(h * scale)))

    def _inpaint_simple_lama(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
        """Run the ``simple_lama`` backend; return None (after logging) on error."""
        try:
            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask_u8 = self._mask_to_u8(mask)
            orig_h, orig_w = rgb.shape[:2]
            # Resize on demand.
            if self._inference_size:
                rgb, mask_u8, _ = self._resize_to_inference(rgb, mask_u8)
            in_h, in_w = rgb.shape[:2]

            # The model may hand back a PIL image, and one padded beyond the
            # input size; convert to ndarray and crop to what was fed in.
            raw = np.asarray(self._model(rgb, mask_u8))
            result_rgb = np.ascontiguousarray(raw[:in_h, :in_w])

            if (in_h, in_w) != (orig_h, orig_w):
                # cv2.resize takes dsize as (width, height).
                result_rgb = cv2.resize(result_rgb, (orig_w, orig_h))
            return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.warning(f"simple_lama 推理失败: {e}")
            return None

    def _inpaint_lama_repo(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
        """Run the ``lama_repo`` backend; return None (after logging) on error."""
        try:
            import torch
            from saicinpainting.evaluation.data import pad_tensor_to_modulo

            rgb = cv2.cvtColor(image.astype(np.float32) / 255.0, cv2.COLOR_BGR2RGB)
            # Normalise to {0.0, 1.0} so a 0/255 mask does not become 255.0.
            mask_f = np.asarray(mask).astype(bool).astype(np.float32)
            orig_h, orig_w = rgb.shape[:2]
            # Resize on demand.
            if self._inference_size:
                rgb, mask_f, _ = self._resize_image_mask(rgb, mask_f)
            in_h, in_w = rgb.shape[:2]

            img_t = torch.from_numpy(np.ascontiguousarray(rgb)).permute(2, 0, 1).unsqueeze(0)
            mask_t = torch.from_numpy(np.ascontiguousarray(mask_f)).unsqueeze(0).unsqueeze(0)
            img_t = pad_tensor_to_modulo(img_t, self._pad_to_multiple).to(self._device)
            mask_t = pad_tensor_to_modulo(mask_t, self._pad_to_multiple).to(self._device)

            with torch.no_grad():
                # The LaMa training module takes a batch dict and returns it
                # with the blended result under "inpainted" (B, C, H, W).
                batch = self._model({"image": img_t, "mask": mask_t})
            result = batch["inpainted"][0].permute(1, 2, 0).cpu().numpy()

            # Crop the padding using the size that was actually fed in.
            result = np.ascontiguousarray(result[:in_h, :in_w, :])
            if (in_h, in_w) != (orig_h, orig_w):
                # Back to the caller's resolution; dsize is (width, height).
                result = cv2.resize(result, (orig_w, orig_h))
            result = np.clip(result, 0, 1)
            result_u8 = (result * 255).astype(np.uint8)
            return cv2.cvtColor(result_u8, cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.warning(f"lama_repo 推理失败: {e}")
            return None

    def _resize_image_mask(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
        """Resize image and mask so the short side equals the inference size.

        Returns:
            ``(resized_rgb, resized_mask, (orig_w, orig_h))``. Note that the
            original size is given in width-first order.
        """
        h, w = rgb.shape[:2]
        new_w, new_h = self._scaled_size(h, w)
        rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        # Nearest-neighbour keeps the mask strictly binary.
        mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        return rgb_rs, mask_rs, (w, h)

    def _resize_to_inference(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
        """Alias of :meth:`_resize_image_mask`, kept for existing call sites."""
        return self._resize_image_mask(rgb, mask)

    def _opencv_inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Fallback: OpenCV Telea inpainting (non-GAN)."""
        logger.info("使用 OpenCV inpainting 回退")
        mask_u8 = self._mask_to_u8(mask)
        return cv2.inpaint(image, mask_u8, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
if __name__ == "__main__":
    # Quick smoke test: report which LaMa backends can be found on this machine.
    print("LaMa 后端检测:")
    has_simple_lama = _check_simple_lama()
    print(f" simple_lama_inpainting: {has_simple_lama}")
    repo_dir = _check_lama_repo()
    print(f" lama_repo: {repo_dir}")
    painter = LamaInpainter(device="cpu")
    print(f" is_available: {painter.is_available}")
|