pipeline.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np

from ....utils.deps import pipeline_requires_extra
from ...models.semantic_segmentation.result import SegResult
from ...utils.benchmark import benchmark
from ...utils.hpi import HPIConfig
from ...utils.pp_option import PaddlePredictorOption
from .._parallel import AutoParallelImageSimpleInferencePipeline
from ..base import BasePipeline


@benchmark.time_methods
class _SemanticSegmentationPipeline(BasePipeline):
    """Semantic Segmentation Pipeline"""

    def __init__(
        self,
        config: Dict,
        device: str = None,
        pp_option: PaddlePredictorOption = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
        """
        Initializes the class with given configurations and options.

        Args:
            config (Dict): Configuration dictionary containing model and other parameters.
            device (str): The device to run the prediction on. Default is None.
            pp_option (PaddlePredictorOption): Options for the PaddlePaddle predictor. Default is None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin (HPIP) by default. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The default high-performance inference configuration dictionary.
                Defaults to None.
        """
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
        )

        semantic_segmentation_model_config = config["SubModules"][
            "SemanticSegmentation"
        ]
        self.semantic_segmentation_model = self.create_model(
            semantic_segmentation_model_config
        )
        self.target_size = semantic_segmentation_model_config["target_size"]

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        target_size: Union[Literal[-1], None, int, Tuple[int]] = None,
        **kwargs
    ) -> SegResult:
        """Predicts semantic segmentation results for the given input.

        Args:
            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
            target_size (Literal[-1] | None | int | tuple[int]): The image size used by the model for prediction.
                Default is None.
                If set to -1, the original image size is used.
                If set to None, the setting from the previous level is used.
                If set to an integer, the image is rescaled to (value, value).
                If set to a tuple of two integers, the image is rescaled to (height, width).
            **kwargs: Additional keyword arguments that can be passed to the function.

        Returns:
            SegResult: The predicted segmentation results.
        """
        yield from self.semantic_segmentation_model(input, target_size=target_size)
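

# For reference, `__init__` above only requires that
# `config["SubModules"]["SemanticSegmentation"]` exists and carries a
# "target_size" entry (plus whatever `create_model` needs). A hedged sketch of
# such a dict; the model-identifying fields and concrete values are
# illustrative assumptions, not library defaults:
#
#     config = {
#         "SubModules": {
#             "SemanticSegmentation": {
#                 "model_name": "PP-LiteSeg-T",  # assumed/illustrative
#                 "batch_size": 1,
#                 "target_size": None,
#             }
#         }
#     }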


@pipeline_requires_extra("cv")
class SemanticSegmentationPipeline(AutoParallelImageSimpleInferencePipeline):
    entities = "semantic_segmentation"

    @property
    def _pipeline_cls(self):
        return _SemanticSegmentationPipeline

    def _get_batch_size(self, config):
        return config["SubModules"]["SemanticSegmentation"].get("batch_size", 1)
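

# Minimal usage sketch (an assumption about the public entry point, not part
# of this module): pipelines like this one are typically obtained through
# `paddlex.create_pipeline` rather than instantiated directly. The image path
# and output directory below are placeholders.
#
#     from paddlex import create_pipeline
#
#     pipeline = create_pipeline(pipeline="semantic_segmentation")
#     # target_size=-1 keeps the original image size; an int or an
#     # (height, width) tuple rescales the input before inference.
#     for res in pipeline.predict("path/to/image.jpg", target_size=-1):
#         res.save_to_img(save_path="./output/")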