sam_processer.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from copy import deepcopy
  15. from typing import List, Optional, Tuple, Union
  16. import numpy as np
  17. import PIL
  18. from .....utils.deps import class_requires_deps
  19. from ....utils.benchmark import benchmark
  20. def _get_preprocess_shape(
  21. oldh: int, oldw: int, long_side_length: int
  22. ) -> Tuple[int, int]:
  23. """Compute the output size given input size and target long side length."""
  24. scale = long_side_length * 1.0 / max(oldh, oldw)
  25. newh, neww = oldh * scale, oldw * scale
  26. neww = int(neww + 0.5)
  27. newh = int(newh + 0.5)
  28. return (newh, neww)
  29. @class_requires_deps("paddlepaddle")
  30. class SAMProcessor(object):
  31. def __init__(
  32. self,
  33. size: Optional[Union[List[int], int]] = None,
  34. image_mean: Union[float, List[float]] = [123.675, 116.28, 103.53],
  35. image_std: Union[float, List[float]] = [58.395, 57.12, 57.375],
  36. **kwargs,
  37. ) -> None:
  38. size = size if size is not None else 1024
  39. self.size = size
  40. if isinstance(image_mean, float):
  41. image_mean = [image_mean] * 3
  42. if isinstance(image_std, float):
  43. image_std = [image_std] * 3
  44. self.image_mean = image_mean
  45. self.image_std = image_std
  46. self.image_processor = SamImageProcessor(
  47. self.size, self.image_mean, self.image_std
  48. )
  49. self.prompt_processor = SamPromptProcessor(self.size)
  50. def preprocess(
  51. self,
  52. images,
  53. *,
  54. point_prompt=None,
  55. box_prompt=None,
  56. **kwargs,
  57. ):
  58. if point_prompt is not None and box_prompt is not None:
  59. raise ValueError(
  60. "SAM can only use either points or boxes as prompt, not both at the same time."
  61. )
  62. if point_prompt is None and box_prompt is None:
  63. raise ValueError(
  64. "SAM must use either points or boxes as prompt, now both is None."
  65. )
  66. point_prompt = (
  67. np.array(point_prompt).reshape(-1, 2) if point_prompt is not None else None
  68. )
  69. box_prompt = (
  70. np.array(box_prompt).reshape(-1, 4) if box_prompt is not None else None
  71. )
  72. if point_prompt is not None and point_prompt.size > 2:
  73. raise ValueError(
  74. "SAM now only support one point for using point promot, your input format should be like [[x, y]] only."
  75. )
  76. image_seg = self.image_processor(images)
  77. self.original_size = self.image_processor.original_size
  78. self.input_size = self.image_processor.input_size
  79. prompt = self.prompt_processor(
  80. self.original_size,
  81. point_coords=point_prompt,
  82. box=box_prompt,
  83. )
  84. return image_seg, prompt
  85. def postprocess(self, low_res_masks, mask_threshold: float = 0.0):
  86. import paddle
  87. import paddle.nn.functional as F
  88. if isinstance(low_res_masks, list):
  89. assert len(low_res_masks) == 1
  90. low_res_masks = low_res_masks[0]
  91. masks = F.interpolate(
  92. paddle.to_tensor(low_res_masks),
  93. (self.size, self.size),
  94. mode="bilinear",
  95. align_corners=False,
  96. )
  97. masks = masks[..., : self.input_size[0], : self.input_size[1]]
  98. masks = F.interpolate(
  99. masks, self.original_size, mode="bilinear", align_corners=False
  100. )
  101. masks = (masks > mask_threshold).numpy().astype(np.int8)
  102. return [masks]
  103. @benchmark.timeit
  104. class SamPromptProcessor(object):
  105. """Constructs a Sam prompt processor."""
  106. def __init__(
  107. self,
  108. size: int = 1024,
  109. ):
  110. self.size = size
  111. def apply_coords(
  112. self, coords: np.ndarray, original_size: Tuple[int, ...]
  113. ) -> np.ndarray:
  114. """Expects a numpy array of length 2 in the final dimension. Requires the
  115. original image size in (H, W) format.
  116. """
  117. old_h, old_w = original_size
  118. new_h, new_w = _get_preprocess_shape(
  119. original_size[0], original_size[1], self.size
  120. )
  121. coords = deepcopy(coords).astype(float)
  122. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  123. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  124. return coords
  125. def apply_boxes(
  126. self, boxes: np.ndarray, original_size: Tuple[int, ...]
  127. ) -> np.ndarray:
  128. """Expects a numpy array shape Nx4. Requires the original image size
  129. in (H, W) format.
  130. """
  131. boxes = self.apply_coords(boxes.reshape([-1, 2, 2]), original_size)
  132. return boxes.reshape([-1, 4])
  133. def __call__(
  134. self,
  135. original_size,
  136. point_coords=None,
  137. box=None,
  138. **kwargs,
  139. ):
  140. if point_coords is not None and box is not None:
  141. raise ValueError(
  142. "SAM can only use either points or boxes as prompt, not both at the same time."
  143. )
  144. if point_coords is not None:
  145. point_coords = self.apply_coords(point_coords, original_size)
  146. point_coords = point_coords[None, ...]
  147. return point_coords.astype(np.float32)
  148. if box is not None:
  149. box = self.apply_boxes(box, original_size)
  150. return box.astype(np.float32)
  151. @benchmark.timeit
  152. @class_requires_deps("paddlepaddle")
  153. class SamImageProcessor(object):
  154. """Constructs a Sam image processor."""
  155. def __init__(
  156. self,
  157. size: Union[List[int], int] = None,
  158. image_mean: Union[float, List[float]] = [0.5, 0.5, 0.5],
  159. image_std: Union[float, List[float]] = [0.5, 0.5, 0.5],
  160. **kwargs,
  161. ) -> None:
  162. size = size if size is not None else 1024
  163. self.size = size
  164. if isinstance(image_mean, float):
  165. image_mean = [image_mean] * 3
  166. if isinstance(image_std, float):
  167. image_std = [image_std] * 3
  168. self.image_mean = image_mean
  169. self.image_std = image_std
  170. self.original_size = None
  171. self.input_size = None
  172. def apply_image(self, image: np.ndarray) -> np.ndarray:
  173. """Expects a numpy array with shape HxWxC in uint8 format."""
  174. import paddle.vision.transforms as T
  175. target_size = _get_preprocess_shape(image.shape[0], image.shape[1], self.size)
  176. if isinstance(image, np.ndarray):
  177. image = PIL.Image.fromarray(image)
  178. return np.array(T.resize(image, target_size))
  179. def __call__(self, images, **kwargs):
  180. if not isinstance(images, (list, tuple)):
  181. images = [images]
  182. return self.preprocess(images)
  183. def preprocess(
  184. self,
  185. images,
  186. ):
  187. """Preprocess an image or a batch of images with a same shape."""
  188. import paddle
  189. import paddle.nn.functional as F
  190. input_image = [self.apply_image(image) for image in images]
  191. input_image_paddle = paddle.to_tensor(input_image).cast("int32")
  192. input_image_paddle = input_image_paddle.transpose([0, 3, 1, 2])
  193. original_image_size = images[0].shape[:2]
  194. self.original_size = original_image_size
  195. self.input_size = tuple(input_image_paddle.shape[-2:])
  196. mean = paddle.to_tensor(self.image_mean).reshape([-1, 1, 1])
  197. std = paddle.to_tensor(self.image_std).reshape([-1, 1, 1])
  198. input_image_paddle = (input_image_paddle.astype(std.dtype) - mean) / std
  199. h, w = input_image_paddle.shape[-2:]
  200. padh = self.size - h
  201. padw = self.size - w
  202. input_image = F.pad(input_image_paddle, (0, padw, 0, padh))
  203. return input_image.numpy()