lama_inpaint.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. """
  2. LaMa (Large Mask Inpainting) 推理模块。
  3. 封装预训练LaMa模型的加载与推理,方案选择(按优先级):
  4. 1. simple_lama_inpainting pip包(最简)
  5. 2. 本地 lama 仓库代码(big-lama checkpoint)
  6. 3. OpenCV inpainting(终极回退,不用GAN)
  7. 用法:
  8. from gan_experiments_lab.lama_inpaint import LamaInpainter
  9. inpaint = LamaInpainter(device="cpu")
  10. result = inpaint.inpaint(bgr_image, mask_bool)
  11. """
  12. from __future__ import annotations
  13. import sys
  14. from pathlib import Path
  15. from typing import Optional
  16. import cv2
  17. import numpy as np
  18. from loguru import logger
  19. def _check_simple_lama() -> bool:
  20. try:
  21. import simple_lama_inpainting # noqa: F401
  22. return True
  23. except ImportError:
  24. return False
  25. def _check_lama_repo() -> Optional[Path]:
  26. """检查本地是否有 lama 仓库并已加入 sys.path。"""
  27. candidates = [
  28. Path(__file__).parent / "lama",
  29. Path(__file__).parents[2] / "lama",
  30. Path.home() / "lama",
  31. Path("/tmp/lama"),
  32. ]
  33. for p in candidates:
  34. if (p / "saicinpainting" / "__init__.py").exists():
  35. return p
  36. return None
  37. class LamaInpainter:
  38. """LaMa inpainting 门面,自动选择可用后端。"""
  39. def __init__(
  40. self,
  41. *,
  42. device: str = "cpu",
  43. inference_size: Optional[int] = None,
  44. pad_to_multiple: int = 8,
  45. ):
  46. self._device = device
  47. self._inference_size = inference_size # None = 保持原尺寸
  48. self._pad_to_multiple = pad_to_multiple
  49. self._model = None
  50. self._backend = None # "simple_lama" | "lama_repo" | "opencv"
  51. self._lama_repo_path: Optional[Path] = None
  52. @property
  53. def is_available(self) -> bool:
  54. if self._backend is not None:
  55. return self._backend != "opencv"
  56. if _check_simple_lama():
  57. self._backend = "simple_lama"
  58. return True
  59. if _check_lama_repo():
  60. self._backend = "lama_repo"
  61. return True
  62. return False
  63. def load(self) -> bool:
  64. """加载模型,返回是否成功。"""
  65. if self._model is not None:
  66. return True
  67. if _check_simple_lama():
  68. return self._load_simple_lama()
  69. repo = _check_lama_repo()
  70. if repo:
  71. return self._load_lama_repo(repo)
  72. logger.warning("LaMa backends 都不可用,将回退 OpenCV inpainting")
  73. self._backend = "opencv"
  74. return False
  75. def _load_simple_lama(self) -> bool:
  76. try:
  77. from simple_lama_inpainting import SimpleLama
  78. self._model = SimpleLama(device=self._device)
  79. self._backend = "simple_lama"
  80. logger.info(f"LaMa (simple_lama_inpainting) 已加载, device={self._device}")
  81. return True
  82. except Exception as e:
  83. logger.warning(f"simple_lama_inpainting 加载失败: {e}")
  84. return False
  85. def _load_lama_repo(self, repo_path: Path) -> bool:
  86. try:
  87. if str(repo_path) not in sys.path:
  88. sys.path.insert(0, str(repo_path))
  89. from omegaconf import OmegaConf
  90. from saicinpainting.training.trainers import load_checkpoint
  91. config_path = repo_path / "big-lama" / "config.yaml"
  92. ckpt_path = repo_path / "big-lama" / "models" / "best.ckpt"
  93. if not config_path.exists() or not ckpt_path.exists():
  94. logger.warning(
  95. f"lama 模型文件缺失。请下载: "
  96. f"wget https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.zip && "
  97. f"unzip big-lama.zip -d {repo_path}"
  98. )
  99. return False
  100. conf = OmegaConf.load(str(config_path))
  101. conf.training_model.predict_only = True
  102. conf.visualizer.kind = "noop"
  103. model = load_checkpoint(conf, str(ckpt_path), strict=False, map_location="cpu")
  104. model.eval()
  105. if self._device != "cpu":
  106. model.cuda()
  107. self._model = model
  108. self._lama_repo_path = repo_path
  109. self._backend = "lama_repo"
  110. logger.info(f"LaMa (lama_repo) 已加载, device={self._device}")
  111. return True
  112. except Exception as e:
  113. logger.warning(f"lama_repo 加载失败: {e}")
  114. return False
  115. def inpaint(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
  116. """
  117. 修复图像。
  118. Args:
  119. image: BGR ndarray (H, W, 3), uint8
  120. mask: bool ndarray (H, W), True=需要修复的水印区域
  121. Returns:
  122. BGR ndarray (H, W, 3), uint8, or None
  123. """
  124. if not self._model:
  125. if not self.load():
  126. return self._opencv_inpaint(image, mask)
  127. if self._backend == "simple_lama":
  128. return self._inpaint_simple_lama(image, mask)
  129. elif self._backend == "lama_repo":
  130. return self._inpaint_lama_repo(image, mask)
  131. else:
  132. return self._opencv_inpaint(image, mask)
  133. def _inpaint_simple_lama(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
  134. try:
  135. rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  136. mask_u8 = mask.astype(np.uint8) * 255
  137. # 按需 resize
  138. if self._inference_size:
  139. rgb, mask_u8, orig_size = self._resize_to_inference(rgb, mask_u8)
  140. result_rgb = self._model(rgb, mask_u8)
  141. if self._inference_size:
  142. result_rgb = cv2.resize(result_rgb, (orig_size[1], orig_size[0]))
  143. return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
  144. except Exception as e:
  145. logger.warning(f"simple_lama 推理失败: {e}")
  146. return None
  147. def _inpaint_lama_repo(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
  148. try:
  149. import torch
  150. import torch.nn.functional as F
  151. from saicinpainting.evaluation.data import pad_tensor_to_modulo
  152. rgb = cv2.cvtColor(image.astype(np.float32) / 255.0, cv2.COLOR_BGR2RGB)
  153. mask_f = mask.astype(np.float32)
  154. orig_h, orig_w = rgb.shape[:2]
  155. # resize
  156. if self._inference_size:
  157. rgb, mask_f, (orig_w, orig_h) = self._resize_image_mask(rgb, mask_f)
  158. img_t = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
  159. mask_t = torch.from_numpy(mask_f).unsqueeze(0).unsqueeze(0)
  160. img_t = pad_tensor_to_modulo(img_t, self._pad_to_multiple)
  161. mask_t = pad_tensor_to_modulo(mask_t, self._pad_to_multiple)
  162. if self._device != "cpu":
  163. img_t = img_t.cuda()
  164. mask_t = mask_t.cuda()
  165. with torch.no_grad():
  166. output = self._model(img_t, mask_t)
  167. # output shape: (B, C, H, W)
  168. result = output[0].permute(1, 2, 0).cpu().numpy()
  169. # 裁掉 pad
  170. result = result[:orig_h, :orig_w, :]
  171. result = np.clip(result, 0, 1)
  172. result_u8 = (result * 255).astype(np.uint8)
  173. return cv2.cvtColor(result_u8, cv2.COLOR_RGB2BGR)
  174. except Exception as e:
  175. logger.warning(f"lama_repo 推理失败: {e}")
  176. return None
  177. def _resize_to_inference(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
  178. h, w = rgb.shape[:2]
  179. size = self._inference_size or min(h, w)
  180. scale = size / min(h, w)
  181. new_w, new_h = int(w * scale), int(h * scale)
  182. rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
  183. mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
  184. return rgb_rs, mask_rs, (w, h)
  185. def _resize_image_mask(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
  186. h, w = rgb.shape[:2]
  187. size = self._inference_size or min(h, w)
  188. scale = size / min(h, w)
  189. new_w, new_h = int(w * scale), int(h * scale)
  190. rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
  191. mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
  192. return rgb_rs, mask_rs, (w, h)
  193. def _opencv_inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
  194. """OpenCV Telea inpainting 回退(非GAN)。"""
  195. logger.info("使用 OpenCV inpainting 回退")
  196. mask_u8 = mask.astype(np.uint8) * 255
  197. return cv2.inpaint(image, mask_u8, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
  198. if __name__ == "__main__":
  199. # 快速功能测试
  200. print("LaMa 后端检测:")
  201. print(f" simple_lama_inpainting: {_check_simple_lama()}")
  202. repo = _check_lama_repo()
  203. print(f" lama_repo: {repo}")
  204. inpaint = LamaInpainter(device="cpu")
  205. print(f" is_available: {inpaint.is_available}")