lama_inpaint.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. """
  2. LaMa (Large Mask Inpainting) 推理模块。
  3. 封装预训练LaMa模型的加载与推理,方案选择(按优先级):
  4. 1. simple_lama_inpainting pip包(最简)
  5. 2. 本地 lama 仓库代码(big-lama checkpoint)
  6. 3. OpenCV inpainting(终极回退,不用GAN)
  7. 用法:
  8. from gan_experiments_lab.lama_inpaint import LamaInpainter
  9. inpaint = LamaInpainter(device="cpu")
  10. result = inpaint.inpaint(bgr_image, mask_bool)
  11. """
  12. from __future__ import annotations
  13. import sys
  14. from pathlib import Path
  15. from typing import Optional
  16. import cv2
  17. import numpy as np
  18. from loguru import logger
  19. def _check_simple_lama() -> bool:
  20. try:
  21. import simple_lama_inpainting # noqa: F401
  22. return True
  23. except ImportError:
  24. return False
  25. def _check_lama_repo() -> Optional[Path]:
  26. """检查本地是否有 lama 仓库并已加入 sys.path。"""
  27. candidates = [
  28. Path(__file__).parent / "lama",
  29. Path(__file__).parents[3] / "lama",
  30. Path.home() / "lama",
  31. Path("/tmp/lama"),
  32. ]
  33. for p in candidates:
  34. if (p / "saicinpainting" / "__init__.py").exists():
  35. return p
  36. return None
  37. class LamaInpainter:
  38. """LaMa inpainting 门面,自动选择可用后端。"""
  39. def __init__(
  40. self,
  41. *,
  42. device: str = "cpu",
  43. inference_size: Optional[int] = None,
  44. pad_to_multiple: int = 8,
  45. model_ckpt_path: Optional[str] = None,
  46. model_config_path: Optional[str] = None,
  47. lama_repo_path: Optional[str] = None,
  48. ):
  49. self._device = device
  50. self._inference_size = inference_size # None = 保持原尺寸
  51. self._pad_to_multiple = pad_to_multiple
  52. self._model_ckpt_path = Path(model_ckpt_path).expanduser() if model_ckpt_path else None
  53. self._model_config_path = Path(model_config_path).expanduser() if model_config_path else None
  54. self._preferred_repo_path = Path(lama_repo_path).expanduser() if lama_repo_path else None
  55. self._model = None
  56. self._backend = None # "simple_lama" | "lama_repo" | "opencv"
  57. self._lama_repo_path: Optional[Path] = None
  58. @property
  59. def is_available(self) -> bool:
  60. if self._backend is not None:
  61. return self._backend != "opencv"
  62. # 显式提供本地权重/仓库时,优先走 lama_repo 流程
  63. if self._model_ckpt_path or self._preferred_repo_path:
  64. self._backend = "lama_repo"
  65. return True
  66. if _check_simple_lama():
  67. self._backend = "simple_lama"
  68. return True
  69. if _check_lama_repo():
  70. self._backend = "lama_repo"
  71. return True
  72. return False
  73. def load(self) -> bool:
  74. """加载模型,返回是否成功。"""
  75. if self._model is not None:
  76. return True
  77. if _check_simple_lama() and not self._model_ckpt_path:
  78. return self._load_simple_lama()
  79. repo = self._preferred_repo_path or _check_lama_repo()
  80. if repo or self._model_ckpt_path:
  81. if self._load_lama_repo(repo):
  82. return True
  83. # lama_repo 失败后,若环境有 simple_lama,退回 simple_lama 而非 OpenCV
  84. if _check_simple_lama():
  85. return self._load_simple_lama()
  86. return False
  87. logger.warning("LaMa backends 都不可用,将回退 OpenCV inpainting")
  88. self._backend = "opencv"
  89. return False
  90. def _load_simple_lama(self) -> bool:
  91. try:
  92. from simple_lama_inpainting import SimpleLama
  93. self._model = SimpleLama(device=self._device)
  94. self._backend = "simple_lama"
  95. logger.info(f"LaMa (simple_lama_inpainting) 已加载, device={self._device}")
  96. return True
  97. except Exception as e:
  98. logger.warning(f"simple_lama_inpainting 加载失败: {e}")
  99. return False
  100. def _load_lama_repo(self, repo_path: Optional[Path]) -> bool:
  101. try:
  102. if repo_path and str(repo_path) not in sys.path:
  103. sys.path.insert(0, str(repo_path))
  104. from omegaconf import OmegaConf
  105. from saicinpainting.training.trainers import load_checkpoint
  106. if self._model_ckpt_path:
  107. ckpt_path = self._model_ckpt_path
  108. elif repo_path:
  109. ckpt_path = repo_path / "big-lama" / "models" / "best.ckpt"
  110. else:
  111. logger.warning("未提供 LaMa 权重路径,且未找到 lama 仓库目录")
  112. return False
  113. config_path = self._model_config_path
  114. if config_path is None:
  115. # 优先使用权重目录的上一级 config.yaml(与官方 big-lama 目录结构一致)
  116. if ckpt_path.parent.name == "models":
  117. sibling_cfg = ckpt_path.parent.parent / "config.yaml"
  118. if sibling_cfg.exists():
  119. config_path = sibling_cfg
  120. if config_path is None and repo_path:
  121. config_path = repo_path / "big-lama" / "config.yaml"
  122. if not config_path.exists() or not ckpt_path.exists():
  123. logger.warning(
  124. f"lama 模型文件缺失: config={config_path}, ckpt={ckpt_path}"
  125. )
  126. return False
  127. conf = OmegaConf.load(str(config_path))
  128. conf.training_model.predict_only = True
  129. conf.visualizer.kind = "noop"
  130. model = load_checkpoint(conf, str(ckpt_path), strict=False, map_location="cpu")
  131. model.eval()
  132. if self._device != "cpu":
  133. model.cuda()
  134. self._model = model
  135. self._lama_repo_path = repo_path
  136. self._backend = "lama_repo"
  137. logger.info(f"LaMa (lama_repo) 已加载, device={self._device}")
  138. return True
  139. except Exception as e:
  140. logger.warning(f"lama_repo 加载失败: {e}")
  141. return False
  142. def inpaint(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
  143. """
  144. 修复图像。
  145. Args:
  146. image: BGR ndarray (H, W, 3), uint8
  147. mask: bool ndarray (H, W), True=需要修复的水印区域
  148. Returns:
  149. BGR ndarray (H, W, 3), uint8, or None
  150. """
  151. if not self._model:
  152. if not self.load():
  153. return self._opencv_inpaint(image, mask)
  154. if self._backend == "simple_lama":
  155. return self._inpaint_simple_lama(image, mask)
  156. elif self._backend == "lama_repo":
  157. return self._inpaint_lama_repo(image, mask)
  158. else:
  159. return self._opencv_inpaint(image, mask)
  160. def _inpaint_simple_lama(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
  161. try:
  162. rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  163. mask_u8 = mask.astype(np.uint8) * 255
  164. # 按需 resize
  165. if self._inference_size:
  166. rgb, mask_u8, orig_size = self._resize_to_inference(rgb, mask_u8)
  167. result_rgb = self._model(rgb, mask_u8)
  168. if not isinstance(result_rgb, np.ndarray):
  169. result_rgb = np.asarray(result_rgb)
  170. if self._inference_size:
  171. result_rgb = cv2.resize(result_rgb, (orig_size[1], orig_size[0]))
  172. return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
  173. except Exception as e:
  174. logger.warning(f"simple_lama 推理失败: {e}")
  175. return None
  176. def _inpaint_lama_repo(self, image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
  177. try:
  178. import torch
  179. from saicinpainting.evaluation.data import pad_tensor_to_modulo
  180. rgb = cv2.cvtColor(image.astype(np.float32) / 255.0, cv2.COLOR_BGR2RGB)
  181. mask_f = mask.astype(np.float32)
  182. orig_h, orig_w = rgb.shape[:2]
  183. # resize
  184. if self._inference_size:
  185. rgb, mask_f, (orig_w, orig_h) = self._resize_image_mask(rgb, mask_f)
  186. img_t = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
  187. mask_t = torch.from_numpy(mask_f).unsqueeze(0).unsqueeze(0)
  188. img_t = pad_tensor_to_modulo(img_t, self._pad_to_multiple)
  189. mask_t = pad_tensor_to_modulo(mask_t, self._pad_to_multiple)
  190. if self._device != "cpu":
  191. img_t = img_t.cuda()
  192. mask_t = mask_t.cuda()
  193. with torch.no_grad():
  194. batch = {"image": img_t, "mask": mask_t}
  195. assert self._model is not None
  196. output = self._model(batch)
  197. pred = output["inpainted"] if isinstance(output, dict) else batch["inpainted"]
  198. result = pred[0].permute(1, 2, 0).cpu().numpy()
  199. # 裁掉 pad
  200. result = result[:orig_h, :orig_w, :]
  201. result = np.clip(result, 0, 1)
  202. result_u8 = (result * 255).astype(np.uint8)
  203. return cv2.cvtColor(result_u8, cv2.COLOR_RGB2BGR)
  204. except Exception as e:
  205. logger.warning(f"lama_repo 推理失败: {e}")
  206. return None
  207. def _resize_to_inference(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
  208. h, w = rgb.shape[:2]
  209. size = self._inference_size or min(h, w)
  210. scale = size / min(h, w)
  211. new_w, new_h = int(w * scale), int(h * scale)
  212. rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
  213. mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
  214. return rgb_rs, mask_rs, (w, h)
  215. def _resize_image_mask(self, rgb: np.ndarray, mask: np.ndarray) -> tuple:
  216. h, w = rgb.shape[:2]
  217. size = self._inference_size or min(h, w)
  218. scale = size / min(h, w)
  219. new_w, new_h = int(w * scale), int(h * scale)
  220. rgb_rs = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
  221. mask_rs = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
  222. return rgb_rs, mask_rs, (w, h)
  223. def _opencv_inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
  224. """OpenCV Telea inpainting 回退(非GAN)。"""
  225. logger.info("使用 OpenCV inpainting 回退")
  226. mask_u8 = mask.astype(np.uint8) * 255
  227. return cv2.inpaint(image, mask_u8, inpaintRadius=5, flags=cv2.INPAINT_TELEA)
  228. if __name__ == "__main__":
  229. # 快速功能测试
  230. print("LaMa 后端检测:")
  231. print(f" simple_lama_inpainting: {_check_simple_lama()}")
  232. repo = _check_lama_repo()
  233. print(f" lama_repo: {repo}")
  234. inpaint = LamaInpainter(device="cpu")
  235. print(f" is_available: {inpaint.is_available}")