predictor.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple
  15. import numpy as np
  16. from ....utils.func_register import FuncRegister
  17. from ....modules.semantic_segmentation.model_list import MODELS
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..common import (
  21. ResizeByShort,
  22. Normalize,
  23. ToCHWImage,
  24. ToBatch,
  25. StaticInfer,
  26. )
  27. from .processors import Resize, SegPostProcess
  28. from ..base import BasicPredictor
  29. from .result import SegResult
  30. class SegPredictor(BasicPredictor):
  31. """SegPredictor that inherits from BasicPredictor."""
  32. entities = MODELS
  33. _FUNC_MAP = {}
  34. register = FuncRegister(_FUNC_MAP)
  35. def __init__(
  36. self,
  37. target_size: Union[int, Tuple[int], None] = None,
  38. *args: List,
  39. **kwargs: Dict,
  40. ) -> None:
  41. """Initializes SegPredictor.
  42. Args:
  43. target_size: Image size used for inference.
  44. *args: Arbitrary positional arguments passed to the superclass.
  45. **kwargs: Arbitrary keyword arguments passed to the superclass.
  46. """
  47. super().__init__(*args, **kwargs)
  48. self.target_size = target_size
  49. self.preprocessors, self.infer, self.postprocessers = self._build()
  50. def _build_batch_sampler(self) -> ImageBatchSampler:
  51. """Builds and returns an ImageBatchSampler instance.
  52. Returns:
  53. ImageBatchSampler: An instance of ImageBatchSampler.
  54. """
  55. return ImageBatchSampler()
  56. def _get_result_class(self) -> type:
  57. """Returns the result class, SegResult.
  58. Returns:
  59. type: The SegResult class.
  60. """
  61. return SegResult
  62. def _build(self) -> Tuple:
  63. """Build the preprocessors, inference engine, and postprocessors based on the configuration.
  64. Returns:
  65. tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
  66. """
  67. preprocessors = {"Read": ReadImage(format="RGB")}
  68. preprocessors["ToCHW"] = ToCHWImage()
  69. for cfg in self.config["Deploy"]["transforms"]:
  70. tf_key = cfg.pop("type")
  71. func = self._FUNC_MAP[tf_key]
  72. args = cfg
  73. name, op = func(self, **args) if args else func(self)
  74. preprocessors[name] = op
  75. preprocessors["ToBatch"] = ToBatch()
  76. if "Resize" not in preprocessors:
  77. _, op = self._FUNC_MAP["Resize"](self, target_size=-1)
  78. preprocessors["Resize"] = op
  79. if self.target_size is not None:
  80. _, op = self._FUNC_MAP["Resize"](self, target_size=self.target_size)
  81. preprocessors["Resize"] = op
  82. infer = StaticInfer(
  83. model_dir=self.model_dir,
  84. model_prefix=self.MODEL_FILE_PREFIX,
  85. option=self.pp_option,
  86. )
  87. postprocessers = SegPostProcess()
  88. return preprocessors, infer, postprocessers
  89. def process(
  90. self,
  91. batch_data: List[Union[str, np.ndarray]],
  92. target_size: Union[int, Tuple[int], None] = None,
  93. ) -> Dict[str, Any]:
  94. """
  95. Process a batch of data through the preprocessing, inference, and postprocessing.
  96. Args:
  97. batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
  98. target_size: Image size used for inference.
  99. Returns:
  100. dict: A dictionary containing the input path, raw image, and predicted segmentation maps for every instance of the batch. Keys include 'input_path', 'input_img', and 'pred'.
  101. """
  102. batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
  103. batch_imgs = self.preprocessors["Resize"](
  104. imgs=batch_raw_imgs, target_size=target_size
  105. )
  106. batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
  107. batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
  108. x = self.preprocessors["ToBatch"](imgs=batch_imgs)
  109. batch_preds = self.infer(x=x)
  110. if len(batch_data) > 1:
  111. batch_preds = np.split(batch_preds[0], len(batch_data), axis=0)
  112. # postprocess
  113. batch_preds = self.postprocessers(batch_preds, batch_raw_imgs)
  114. return {
  115. "input_path": batch_data,
  116. "input_img": batch_raw_imgs,
  117. "pred": batch_preds,
  118. }
  119. @register("Normalize")
  120. def build_normalize(
  121. self,
  122. mean=0.5,
  123. std=0.5,
  124. ):
  125. op = Normalize(mean=mean, std=std)
  126. return "Normalize", op
  127. @register("Resize")
  128. def build_resize(
  129. self,
  130. target_size=-1,
  131. keep_ratio=True,
  132. size_divisor=32,
  133. interp="LINEAR",
  134. ):
  135. op = Resize(
  136. target_size=target_size,
  137. keep_ratio=keep_ratio,
  138. size_divisor=size_divisor,
  139. interp=interp,
  140. )
  141. return "Resize", op