pipeline.py 3.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np

from ...utils.pp_option import PaddlePredictorOption
from ...utils.hpi import HPIConfig
from ..base import BasePipeline
from ...models.open_vocabulary_segmentation.results import SAMSegResult
# Type alias for a scalar numeric value: either an int or a float.
Number = Union[int, float]
  21. class OpenVocabularySegmentationPipeline(BasePipeline):
  22. """Open Vocabulary Segmentation pipeline"""
  23. entities = "open_vocabulary_segmentation"
  24. def __init__(
  25. self,
  26. config: Dict,
  27. device: str = None,
  28. pp_option: PaddlePredictorOption = None,
  29. use_hpip: bool = False,
  30. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  31. ) -> None:
  32. """
  33. Initializes the class with given configurations and options.
  34. Args:
  35. config (Dict): Configuration dictionary containing model and other parameters.
  36. device (str): The device to run the prediction on. Default is None.
  37. pp_option (PaddlePredictorOption): Options for PaddlePaddle predictor. Default is None.
  38. use_hpip (bool, optional): Whether to use the high-performance
  39. inference plugin (HPIP) by default. Defaults to False.
  40. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  41. The default high-performance inference configuration dictionary.
  42. Defaults to None.
  43. """
  44. super().__init__(
  45. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  46. )
  47. # create box-prompted SAM-H
  48. box_prompted_model_cfg = config.get("SubModules", {}).get(
  49. "BoxPromptSegmentation",
  50. {"model_config_error": "config error for doc_ori_classify_model!"},
  51. )
  52. self.box_prompted_model = self.create_model(box_prompted_model_cfg)
  53. # create point-prompted SAM-H
  54. point_prompted_model_cfg = config.get("SubModules", {}).get(
  55. "PointPromptSegmentation",
  56. {"model_config_error": "config error for doc_ori_classify_model!"},
  57. )
  58. self.point_prompted_model = self.create_model(point_prompted_model_cfg)
  59. def predict(
  60. self,
  61. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  62. prompt: Union[List[List[float]], np.ndarray],
  63. prompt_type: str = "box",
  64. **kwargs
  65. ) -> SAMSegResult:
  66. """Predicts image segmentation results for the given input.
  67. Args:
  68. input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
  69. prompt (list[list[float]] | np.ndarray): The prompt for the input image(s).
  70. prompt_type (str): The type of prompt, either 'box' or 'point'. Default is 'box'.
  71. **kwargs: Additional keyword arguments that can be passed to the function.
  72. Returns:
  73. SAMSegResult: The predicted SAM segmentation results.
  74. """
  75. if prompt_type == "box":
  76. yield from self.box_prompted_model(input, prompts={"box_prompt": prompt})
  77. elif prompt_type == "point":
  78. yield from self.point_prompted_model(
  79. input, prompts={"point_prompt": prompt}
  80. )
  81. else:
  82. raise ValueError(
  83. "Invalid prompt type. Only 'box' and 'point' are supported"
  84. )