sam_processer.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Dict, List, Optional, Union, Tuple
  16. import numpy as np
  17. import PIL
  18. from copy import deepcopy
  19. from .....utils.lazy_loader import LazyLoader
  20. # NOTE: LazyLoader is used to avoid conflicts between ultra-infer and Paddle
  21. paddle = LazyLoader("lazy_paddle", globals(), "paddle")
  22. T = LazyLoader("T", globals(), "paddle.vision.transforms")
  23. F = LazyLoader("F", globals(), "paddle.nn.functional")
  24. def _get_preprocess_shape(
  25. oldh: int, oldw: int, long_side_length: int
  26. ) -> Tuple[int, int]:
  27. """Compute the output size given input size and target long side length."""
  28. scale = long_side_length * 1.0 / max(oldh, oldw)
  29. newh, neww = oldh * scale, oldw * scale
  30. neww = int(neww + 0.5)
  31. newh = int(newh + 0.5)
  32. return (newh, neww)
  33. class SAMProcessor(object):
  34. def __init__(
  35. self,
  36. size: Optional[Union[List[int], int]] = None,
  37. image_mean: Union[float, List[float]] = [123.675, 116.28, 103.53],
  38. image_std: Union[float, List[float]] = [58.395, 57.12, 57.375],
  39. **kwargs,
  40. ) -> None:
  41. size = size if size is not None else 1024
  42. self.size = size
  43. if isinstance(image_mean, float):
  44. image_mean = [image_mean] * 3
  45. if isinstance(image_std, float):
  46. image_std = [image_std] * 3
  47. self.image_mean = image_mean
  48. self.image_std = image_std
  49. self.image_processor = SamImageProcessor(
  50. self.size, self.image_mean, self.image_std
  51. )
  52. self.prompt_processor = SamPromptProcessor(self.size)
  53. def preprocess(
  54. self,
  55. images,
  56. *,
  57. point_prompt=None,
  58. box_prompt=None,
  59. **kwargs,
  60. ):
  61. if point_prompt is not None and box_prompt is not None:
  62. raise ValueError(
  63. "SAM can only use either points or boxes as prompt, not both at the same time."
  64. )
  65. if point_prompt is None and box_prompt is None:
  66. raise ValueError(
  67. "SAM must use either points or boxes as prompt, now both is None."
  68. )
  69. point_prompt = (
  70. np.array(point_prompt).reshape(-1, 2) if point_prompt is not None else None
  71. )
  72. box_prompt = (
  73. np.array(box_prompt).reshape(-1, 4) if box_prompt is not None else None
  74. )
  75. if point_prompt is not None and point_prompt.size > 2:
  76. raise ValueError(
  77. "SAM now only support one point for using point promot, your input format should be like [[x, y]] only."
  78. )
  79. image_seg = self.image_processor(images)
  80. self.original_size = self.image_processor.original_size
  81. self.input_size = self.image_processor.input_size
  82. prompt = self.prompt_processor(
  83. self.original_size,
  84. point_coords=point_prompt,
  85. box=box_prompt,
  86. )
  87. return image_seg, prompt
  88. def postprocess(self, low_res_masks, mask_threshold: float = 0.0):
  89. if isinstance(low_res_masks, list):
  90. assert len(low_res_masks) == 1
  91. low_res_masks = low_res_masks[0]
  92. masks = F.interpolate(
  93. paddle.to_tensor(low_res_masks),
  94. (self.size, self.size),
  95. mode="bilinear",
  96. align_corners=False,
  97. )
  98. masks = masks[..., : self.input_size[0], : self.input_size[1]]
  99. masks = F.interpolate(
  100. masks, self.original_size, mode="bilinear", align_corners=False
  101. )
  102. masks = (masks > mask_threshold).numpy().astype(np.int8)
  103. return [masks]
  104. class SamPromptProcessor(object):
  105. """Constructs a Sam prompt processor."""
  106. def __init__(
  107. self,
  108. size: int = 1024,
  109. ):
  110. self.size = size
  111. def apply_coords(
  112. self, coords: np.ndarray, original_size: Tuple[int, ...]
  113. ) -> np.ndarray:
  114. """Expects a numpy array of length 2 in the final dimension. Requires the
  115. original image size in (H, W) format.
  116. """
  117. old_h, old_w = original_size
  118. new_h, new_w = _get_preprocess_shape(
  119. original_size[0], original_size[1], self.size
  120. )
  121. coords = deepcopy(coords).astype(float)
  122. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  123. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  124. return coords
  125. def apply_boxes(
  126. self, boxes: np.ndarray, original_size: Tuple[int, ...]
  127. ) -> np.ndarray:
  128. """Expects a numpy array shape Nx4. Requires the original image size
  129. in (H, W) format.
  130. """
  131. boxes = self.apply_coords(boxes.reshape([-1, 2, 2]), original_size)
  132. return boxes.reshape([-1, 4])
  133. def __call__(
  134. self,
  135. original_size,
  136. point_coords=None,
  137. box=None,
  138. **kwargs,
  139. ):
  140. if point_coords is not None and box is not None:
  141. raise ValueError(
  142. "SAM can only use either points or boxes as prompt, not both at the same time."
  143. )
  144. if point_coords is not None:
  145. point_coords = self.apply_coords(point_coords, original_size)
  146. point_coords = point_coords[None, ...]
  147. return point_coords.astype(np.float32)
  148. if box is not None:
  149. box = self.apply_boxes(box, original_size)
  150. return box.astype(np.float32)
  151. class SamImageProcessor(object):
  152. """Constructs a Sam image processor."""
  153. def __init__(
  154. self,
  155. size: Union[List[int], int] = None,
  156. image_mean: Union[float, List[float]] = [0.5, 0.5, 0.5],
  157. image_std: Union[float, List[float]] = [0.5, 0.5, 0.5],
  158. **kwargs,
  159. ) -> None:
  160. size = size if size is not None else 1024
  161. self.size = size
  162. if isinstance(image_mean, float):
  163. image_mean = [image_mean] * 3
  164. if isinstance(image_std, float):
  165. image_std = [image_std] * 3
  166. self.image_mean = image_mean
  167. self.image_std = image_std
  168. self.original_size = None
  169. self.input_size = None
  170. def apply_image(self, image: np.ndarray) -> np.ndarray:
  171. """Expects a numpy array with shape HxWxC in uint8 format."""
  172. target_size = _get_preprocess_shape(image.shape[0], image.shape[1], self.size)
  173. if isinstance(image, np.ndarray):
  174. image = PIL.Image.fromarray(image)
  175. return np.array(T.resize(image, target_size))
  176. def __call__(self, images, **kwargs):
  177. if not isinstance(images, (list, tuple)):
  178. images = [images]
  179. return self.preprocess(images)
  180. def preprocess(
  181. self,
  182. images,
  183. ):
  184. """Preprocess an image or a batch of images with a same shape."""
  185. size = self.size
  186. input_image = [self.apply_image(image) for image in images]
  187. input_image_paddle = paddle.to_tensor(input_image).cast("int32")
  188. input_image_paddle = input_image_paddle.transpose([0, 3, 1, 2])
  189. original_image_size = images[0].shape[:2]
  190. self.original_size = original_image_size
  191. self.input_size = tuple(input_image_paddle.shape[-2:])
  192. mean = paddle.to_tensor(self.image_mean).reshape([-1, 1, 1])
  193. std = paddle.to_tensor(self.image_std).reshape([-1, 1, 1])
  194. input_image_paddle = (input_image_paddle.astype(std.dtype) - mean) / std
  195. h, w = input_image_paddle.shape[-2:]
  196. padh = self.size - h
  197. padw = self.size - w
  198. input_image = F.pad(input_image_paddle, (0, padw, 0, padh))
  199. return input_image.numpy()