sam_processer.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from copy import deepcopy
  15. from typing import List, Optional, Tuple, Union
  16. import numpy as np
  17. import PIL
  18. from ....utils.benchmark import benchmark
  19. def _get_preprocess_shape(
  20. oldh: int, oldw: int, long_side_length: int
  21. ) -> Tuple[int, int]:
  22. """Compute the output size given input size and target long side length."""
  23. scale = long_side_length * 1.0 / max(oldh, oldw)
  24. newh, neww = oldh * scale, oldw * scale
  25. neww = int(neww + 0.5)
  26. newh = int(newh + 0.5)
  27. return (newh, neww)
  28. class SAMProcessor(object):
  29. def __init__(
  30. self,
  31. size: Optional[Union[List[int], int]] = None,
  32. image_mean: Union[float, List[float]] = [123.675, 116.28, 103.53],
  33. image_std: Union[float, List[float]] = [58.395, 57.12, 57.375],
  34. **kwargs,
  35. ) -> None:
  36. size = size if size is not None else 1024
  37. self.size = size
  38. if isinstance(image_mean, float):
  39. image_mean = [image_mean] * 3
  40. if isinstance(image_std, float):
  41. image_std = [image_std] * 3
  42. self.image_mean = image_mean
  43. self.image_std = image_std
  44. self.image_processor = SamImageProcessor(
  45. self.size, self.image_mean, self.image_std
  46. )
  47. self.prompt_processor = SamPromptProcessor(self.size)
  48. def preprocess(
  49. self,
  50. images,
  51. *,
  52. point_prompt=None,
  53. box_prompt=None,
  54. **kwargs,
  55. ):
  56. if point_prompt is not None and box_prompt is not None:
  57. raise ValueError(
  58. "SAM can only use either points or boxes as prompt, not both at the same time."
  59. )
  60. if point_prompt is None and box_prompt is None:
  61. raise ValueError(
  62. "SAM must use either points or boxes as prompt, now both is None."
  63. )
  64. point_prompt = (
  65. np.array(point_prompt).reshape(-1, 2) if point_prompt is not None else None
  66. )
  67. box_prompt = (
  68. np.array(box_prompt).reshape(-1, 4) if box_prompt is not None else None
  69. )
  70. if point_prompt is not None and point_prompt.size > 2:
  71. raise ValueError(
  72. "SAM now only support one point for using point promot, your input format should be like [[x, y]] only."
  73. )
  74. image_seg = self.image_processor(images)
  75. self.original_size = self.image_processor.original_size
  76. self.input_size = self.image_processor.input_size
  77. prompt = self.prompt_processor(
  78. self.original_size,
  79. point_coords=point_prompt,
  80. box=box_prompt,
  81. )
  82. return image_seg, prompt
  83. def postprocess(self, low_res_masks, mask_threshold: float = 0.0):
  84. import paddle
  85. import paddle.nn.functional as F
  86. if isinstance(low_res_masks, list):
  87. assert len(low_res_masks) == 1
  88. low_res_masks = low_res_masks[0]
  89. masks = F.interpolate(
  90. paddle.to_tensor(low_res_masks),
  91. (self.size, self.size),
  92. mode="bilinear",
  93. align_corners=False,
  94. )
  95. masks = masks[..., : self.input_size[0], : self.input_size[1]]
  96. masks = F.interpolate(
  97. masks, self.original_size, mode="bilinear", align_corners=False
  98. )
  99. masks = (masks > mask_threshold).numpy().astype(np.int8)
  100. return [masks]
  101. @benchmark.timeit
  102. class SamPromptProcessor(object):
  103. """Constructs a Sam prompt processor."""
  104. def __init__(
  105. self,
  106. size: int = 1024,
  107. ):
  108. self.size = size
  109. def apply_coords(
  110. self, coords: np.ndarray, original_size: Tuple[int, ...]
  111. ) -> np.ndarray:
  112. """Expects a numpy array of length 2 in the final dimension. Requires the
  113. original image size in (H, W) format.
  114. """
  115. old_h, old_w = original_size
  116. new_h, new_w = _get_preprocess_shape(
  117. original_size[0], original_size[1], self.size
  118. )
  119. coords = deepcopy(coords).astype(float)
  120. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  121. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  122. return coords
  123. def apply_boxes(
  124. self, boxes: np.ndarray, original_size: Tuple[int, ...]
  125. ) -> np.ndarray:
  126. """Expects a numpy array shape Nx4. Requires the original image size
  127. in (H, W) format.
  128. """
  129. boxes = self.apply_coords(boxes.reshape([-1, 2, 2]), original_size)
  130. return boxes.reshape([-1, 4])
  131. def __call__(
  132. self,
  133. original_size,
  134. point_coords=None,
  135. box=None,
  136. **kwargs,
  137. ):
  138. if point_coords is not None and box is not None:
  139. raise ValueError(
  140. "SAM can only use either points or boxes as prompt, not both at the same time."
  141. )
  142. if point_coords is not None:
  143. point_coords = self.apply_coords(point_coords, original_size)
  144. point_coords = point_coords[None, ...]
  145. return point_coords.astype(np.float32)
  146. if box is not None:
  147. box = self.apply_boxes(box, original_size)
  148. return box.astype(np.float32)
  149. @benchmark.timeit
  150. class SamImageProcessor(object):
  151. """Constructs a Sam image processor."""
  152. def __init__(
  153. self,
  154. size: Union[List[int], int] = None,
  155. image_mean: Union[float, List[float]] = [0.5, 0.5, 0.5],
  156. image_std: Union[float, List[float]] = [0.5, 0.5, 0.5],
  157. **kwargs,
  158. ) -> None:
  159. size = size if size is not None else 1024
  160. self.size = size
  161. if isinstance(image_mean, float):
  162. image_mean = [image_mean] * 3
  163. if isinstance(image_std, float):
  164. image_std = [image_std] * 3
  165. self.image_mean = image_mean
  166. self.image_std = image_std
  167. self.original_size = None
  168. self.input_size = None
  169. def apply_image(self, image: np.ndarray) -> np.ndarray:
  170. """Expects a numpy array with shape HxWxC in uint8 format."""
  171. import paddle.vision.transforms as T
  172. target_size = _get_preprocess_shape(image.shape[0], image.shape[1], self.size)
  173. if isinstance(image, np.ndarray):
  174. image = PIL.Image.fromarray(image)
  175. return np.array(T.resize(image, target_size))
  176. def __call__(self, images, **kwargs):
  177. if not isinstance(images, (list, tuple)):
  178. images = [images]
  179. return self.preprocess(images)
  180. def preprocess(
  181. self,
  182. images,
  183. ):
  184. """Preprocess an image or a batch of images with a same shape."""
  185. import paddle
  186. import paddle.nn.functional as F
  187. input_image = [self.apply_image(image) for image in images]
  188. input_image_paddle = paddle.to_tensor(input_image).cast("int32")
  189. input_image_paddle = input_image_paddle.transpose([0, 3, 1, 2])
  190. original_image_size = images[0].shape[:2]
  191. self.original_size = original_image_size
  192. self.input_size = tuple(input_image_paddle.shape[-2:])
  193. mean = paddle.to_tensor(self.image_mean).reshape([-1, 1, 1])
  194. std = paddle.to_tensor(self.image_std).reshape([-1, 1, 1])
  195. input_image_paddle = (input_image_paddle.astype(std.dtype) - mean) / std
  196. h, w = input_image_paddle.shape[-2:]
  197. padh = self.size - h
  198. padw = self.size - w
  199. input_image = F.pad(input_image_paddle, (0, padw, 0, padh))
  200. return input_image.numpy()