sam_processer.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Dict, List, Optional, Union, Tuple
  16. import numpy as np
  17. import PIL
  18. from copy import deepcopy
  19. from .....utils.lazy_loader import LazyLoader
  20. from ....utils.benchmark import benchmark
  21. # NOTE: LazyLoader is used to avoid conflicts between ultra-infer and Paddle
  22. paddle = LazyLoader("lazy_paddle", globals(), "paddle")
  23. T = LazyLoader("T", globals(), "paddle.vision.transforms")
  24. F = LazyLoader("F", globals(), "paddle.nn.functional")
  25. def _get_preprocess_shape(
  26. oldh: int, oldw: int, long_side_length: int
  27. ) -> Tuple[int, int]:
  28. """Compute the output size given input size and target long side length."""
  29. scale = long_side_length * 1.0 / max(oldh, oldw)
  30. newh, neww = oldh * scale, oldw * scale
  31. neww = int(neww + 0.5)
  32. newh = int(newh + 0.5)
  33. return (newh, neww)
  34. class SAMProcessor(object):
  35. def __init__(
  36. self,
  37. size: Optional[Union[List[int], int]] = None,
  38. image_mean: Union[float, List[float]] = [123.675, 116.28, 103.53],
  39. image_std: Union[float, List[float]] = [58.395, 57.12, 57.375],
  40. **kwargs,
  41. ) -> None:
  42. size = size if size is not None else 1024
  43. self.size = size
  44. if isinstance(image_mean, float):
  45. image_mean = [image_mean] * 3
  46. if isinstance(image_std, float):
  47. image_std = [image_std] * 3
  48. self.image_mean = image_mean
  49. self.image_std = image_std
  50. self.image_processor = SamImageProcessor(
  51. self.size, self.image_mean, self.image_std
  52. )
  53. self.prompt_processor = SamPromptProcessor(self.size)
  54. def preprocess(
  55. self,
  56. images,
  57. *,
  58. point_prompt=None,
  59. box_prompt=None,
  60. **kwargs,
  61. ):
  62. if point_prompt is not None and box_prompt is not None:
  63. raise ValueError(
  64. "SAM can only use either points or boxes as prompt, not both at the same time."
  65. )
  66. if point_prompt is None and box_prompt is None:
  67. raise ValueError(
  68. "SAM must use either points or boxes as prompt, now both is None."
  69. )
  70. point_prompt = (
  71. np.array(point_prompt).reshape(-1, 2) if point_prompt is not None else None
  72. )
  73. box_prompt = (
  74. np.array(box_prompt).reshape(-1, 4) if box_prompt is not None else None
  75. )
  76. if point_prompt is not None and point_prompt.size > 2:
  77. raise ValueError(
  78. "SAM now only support one point for using point promot, your input format should be like [[x, y]] only."
  79. )
  80. image_seg = self.image_processor(images)
  81. self.original_size = self.image_processor.original_size
  82. self.input_size = self.image_processor.input_size
  83. prompt = self.prompt_processor(
  84. self.original_size,
  85. point_coords=point_prompt,
  86. box=box_prompt,
  87. )
  88. return image_seg, prompt
  89. def postprocess(self, low_res_masks, mask_threshold: float = 0.0):
  90. if isinstance(low_res_masks, list):
  91. assert len(low_res_masks) == 1
  92. low_res_masks = low_res_masks[0]
  93. masks = F.interpolate(
  94. paddle.to_tensor(low_res_masks),
  95. (self.size, self.size),
  96. mode="bilinear",
  97. align_corners=False,
  98. )
  99. masks = masks[..., : self.input_size[0], : self.input_size[1]]
  100. masks = F.interpolate(
  101. masks, self.original_size, mode="bilinear", align_corners=False
  102. )
  103. masks = (masks > mask_threshold).numpy().astype(np.int8)
  104. return [masks]
  105. @benchmark.timeit
  106. class SamPromptProcessor(object):
  107. """Constructs a Sam prompt processor."""
  108. def __init__(
  109. self,
  110. size: int = 1024,
  111. ):
  112. self.size = size
  113. def apply_coords(
  114. self, coords: np.ndarray, original_size: Tuple[int, ...]
  115. ) -> np.ndarray:
  116. """Expects a numpy array of length 2 in the final dimension. Requires the
  117. original image size in (H, W) format.
  118. """
  119. old_h, old_w = original_size
  120. new_h, new_w = _get_preprocess_shape(
  121. original_size[0], original_size[1], self.size
  122. )
  123. coords = deepcopy(coords).astype(float)
  124. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  125. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  126. return coords
  127. def apply_boxes(
  128. self, boxes: np.ndarray, original_size: Tuple[int, ...]
  129. ) -> np.ndarray:
  130. """Expects a numpy array shape Nx4. Requires the original image size
  131. in (H, W) format.
  132. """
  133. boxes = self.apply_coords(boxes.reshape([-1, 2, 2]), original_size)
  134. return boxes.reshape([-1, 4])
  135. def __call__(
  136. self,
  137. original_size,
  138. point_coords=None,
  139. box=None,
  140. **kwargs,
  141. ):
  142. if point_coords is not None and box is not None:
  143. raise ValueError(
  144. "SAM can only use either points or boxes as prompt, not both at the same time."
  145. )
  146. if point_coords is not None:
  147. point_coords = self.apply_coords(point_coords, original_size)
  148. point_coords = point_coords[None, ...]
  149. return point_coords.astype(np.float32)
  150. if box is not None:
  151. box = self.apply_boxes(box, original_size)
  152. return box.astype(np.float32)
  153. @benchmark.timeit
  154. class SamImageProcessor(object):
  155. """Constructs a Sam image processor."""
  156. def __init__(
  157. self,
  158. size: Union[List[int], int] = None,
  159. image_mean: Union[float, List[float]] = [0.5, 0.5, 0.5],
  160. image_std: Union[float, List[float]] = [0.5, 0.5, 0.5],
  161. **kwargs,
  162. ) -> None:
  163. size = size if size is not None else 1024
  164. self.size = size
  165. if isinstance(image_mean, float):
  166. image_mean = [image_mean] * 3
  167. if isinstance(image_std, float):
  168. image_std = [image_std] * 3
  169. self.image_mean = image_mean
  170. self.image_std = image_std
  171. self.original_size = None
  172. self.input_size = None
  173. def apply_image(self, image: np.ndarray) -> np.ndarray:
  174. """Expects a numpy array with shape HxWxC in uint8 format."""
  175. target_size = _get_preprocess_shape(image.shape[0], image.shape[1], self.size)
  176. if isinstance(image, np.ndarray):
  177. image = PIL.Image.fromarray(image)
  178. return np.array(T.resize(image, target_size))
  179. def __call__(self, images, **kwargs):
  180. if not isinstance(images, (list, tuple)):
  181. images = [images]
  182. return self.preprocess(images)
  183. def preprocess(
  184. self,
  185. images,
  186. ):
  187. """Preprocess an image or a batch of images with a same shape."""
  188. size = self.size
  189. input_image = [self.apply_image(image) for image in images]
  190. input_image_paddle = paddle.to_tensor(input_image).cast("int32")
  191. input_image_paddle = input_image_paddle.transpose([0, 3, 1, 2])
  192. original_image_size = images[0].shape[:2]
  193. self.original_size = original_image_size
  194. self.input_size = tuple(input_image_paddle.shape[-2:])
  195. mean = paddle.to_tensor(self.image_mean).reshape([-1, 1, 1])
  196. std = paddle.to_tensor(self.image_std).reshape([-1, 1, 1])
  197. input_image_paddle = (input_image_paddle.astype(std.dtype) - mean) / std
  198. h, w = input_image_paddle.shape[-2:]
  199. padh = self.size - h
  200. padw = self.size - w
  201. input_image = F.pad(input_image_paddle, (0, padw, 0, padh))
  202. return input_image.numpy()