# base.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. from os import PathLike
  16. from pathlib import Path
  17. from typing import Any, Dict, Final, Iterator, Optional, TypedDict, Union
  18. import ultra_infer as ui
  19. from paddlex_hpi._config import HPIConfig
  20. from typing_extensions import assert_never
  21. from ultra_infer.model import BaseUltraInferModel
  22. from paddlex.inference.common.reader import ReadImage, ReadTS
  23. from paddlex.inference.models import BasePredictor
  24. from paddlex.inference.utils.new_ir_blocklist import NEWIR_BLOCKLIST
  25. from paddlex.utils import device as device_helper
  26. from paddlex.utils import logging
  27. from paddlex.utils.subclass_register import AutoRegisterABCMetaClass
# Top-level key in the model config under which HPI settings are stored.
HPI_CONFIG_KEY: Final[str] = "Hpi"
class HPIParams(TypedDict, total=False):
    """Optional high-performance-inference parameters.

    Keys (all optional, since ``total=False``):
        config: Overrides merged on top of the HPI section of the model
            config file.
    """

    config: Dict[str, Any]
class HPPredictor(BasePredictor, metaclass=AutoRegisterABCMetaClass):
    """Base predictor that runs PaddleX models through the UltraInfer runtime.

    Builds an UltraInfer model from the on-disk model files (Paddle or ONNX
    format), with the backend and its configuration resolved from the model
    config's HPI section merged with caller-supplied ``hpi_params``.
    """

    # Marks this class as an abstract base so the auto-registration
    # metaclass does not register it as a concrete predictor.
    __is_base = True

    def __init__(
        self,
        model_dir: Union[str, PathLike],
        config: Optional[Dict[str, Any]] = None,
        device: Optional[str] = None,
        batch_size: int = 1,
        use_onnx_model: Optional[bool] = None,
        hpi_params: Optional[HPIParams] = None,
    ) -> None:
        """Initialize the predictor and eagerly build the underlying model.

        Args:
            model_dir: Directory containing the model files.
            config: Parsed model configuration; loaded by the superclass
                when omitted.
            device: Device string (e.g. ``"cpu"``, ``"gpu:0"``). Defaults to
                the framework's default device.
            batch_size: Batch size applied to the batch sampler.
            use_onnx_model: True to force the ONNX model, False to force the
                Paddle model, None to auto-select from the files present.
            hpi_params: Extra HPI settings; ``hpi_params["config"]``
                overrides the config file's HPI section.
        """
        super().__init__(model_dir=model_dir, config=config)
        self._device = device or device_helper.get_default_device()
        self.batch_sampler.batch_size = batch_size
        self._onnx_format = use_onnx_model
        # May flip `self._onnx_format` (or raise) based on which model
        # files actually exist in `model_dir`.
        self._check_and_choose_model_format()
        self._hpi_params = hpi_params or {}
        self._hpi_config = self._get_hpi_config()
        self._ui_model = self.build_ui_model()
        self._data_reader = self._build_data_reader()

    def __call__(
        self,
        input: Any,
        batch_size: Optional[int] = None,
        device: Optional[str] = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """Run prediction on `input`, yielding per-batch results."""
        self.set_predictor(batch_size, device)
        yield from self.apply(input, **kwargs)

    @property
    def model_path(self) -> Path:
        # The model file extension depends on the chosen model format.
        if self._onnx_format:
            return self.model_dir / f"{self.MODEL_FILE_PREFIX}.onnx"
        else:
            return self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdmodel"

    @property
    def params_path(self) -> Union[Path, None]:
        # ONNX models have no separate Paddle params file.
        if self._onnx_format:
            return None
        else:
            return self.model_dir / f"{self.MODEL_FILE_PREFIX}.pdiparams"

    def set_predictor(
        self, batch_size: Optional[int] = None, device: Optional[str] = None
    ) -> None:
        """Update runtime settings; only the batch size may change after init.

        Raises:
            RuntimeError: If `device` differs from the device chosen at
                construction time.
        """
        if device and device != self._device:
            raise RuntimeError("Currently, changing devices is not supported.")
        if batch_size:
            self.batch_sampler.batch_size = batch_size

    def build_ui_model(self) -> BaseUltraInferModel:
        """Create the UltraInfer model using the resolved runtime option."""
        option = self._create_ui_option()
        return self._build_ui_model(option)

    def _get_hpi_config(self) -> HPIConfig:
        """Merge the config file's HPI section with caller-supplied overrides."""
        if HPI_CONFIG_KEY not in self.config:
            logging.debug("Key %r not found in the config", HPI_CONFIG_KEY)
        # Overrides from `hpi_params["config"]` take precedence over the
        # config file's HPI section.
        hpi_config = HPIConfig.model_validate(
            {
                **self.config.get(HPI_CONFIG_KEY, {}),
                **self._hpi_params.get("config", {}),
            }
        )
        return hpi_config

    def _create_ui_option(self) -> ui.RuntimeOption:
        """Build the UltraInfer runtime option for the configured device and backend."""
        option = ui.RuntimeOption()
        # HACK: Disable new IR for models that are known to have issues with the
        # new IR.
        if self.model_name in NEWIR_BLOCKLIST:
            option.paddle_infer_option.enable_new_ir = False
        device_type, device_ids = device_helper.parse_device(self._device)
        if device_type == "cpu":
            pass
        elif device_type == "gpu":
            if device_ids is None:
                device_ids = [0]
            # Only a single GPU is supported; warn rather than fail when
            # more than one was requested.
            if len(device_ids) > 1:
                logging.warning(
                    "Multiple devices are specified (%s), but only the first one will be used.",
                    self._device,
                )
            option.use_gpu(device_ids[0])
        else:
            # Exhaustiveness check: fails static type checking if a new
            # device type is added without handling it here.
            assert_never(device_type)
        backend, backend_config = self._hpi_config.get_backend_and_config(
            model_name=self.model_name,
            device_type=device_type,
            onnx_format=self._onnx_format,
        )
        logging.info("Backend: %s", backend)
        logging.info("Backend config: %s", backend_config)
        backend_config.update_ui_option(option, self.model_dir)
        return option

    def _check_and_choose_model_format(self) -> None:
        """Validate model file availability and resolve the model format.

        Raises:
            RuntimeError: If the requested (or any) model file is missing.
        """
        has_onnx_model = any(self.model_dir.glob(f"{self.MODEL_FILE_PREFIX}.onnx"))
        has_pd_model = any(self.model_dir.glob(f"{self.MODEL_FILE_PREFIX}.pdmodel"))
        if self._onnx_format is None:
            # No explicit preference: prefer the Paddle model when present.
            if has_onnx_model and has_pd_model:
                logging.warning(
                    "Both ONNX and Paddle models are detected, but no preference is set. Default model (.pdmodel) will be used."
                )
            elif has_pd_model:
                logging.warning(
                    "Only Paddle model is detected. Paddle model will be used by default."
                )
            elif has_onnx_model:
                self._onnx_format = True
                logging.warning(
                    "Only ONNX model is detected. ONNX model will be used by default."
                )
            else:
                raise RuntimeError(
                    "No models are detected. Please ensure the model file exists."
                )
        elif self._onnx_format:
            if not has_onnx_model:
                raise RuntimeError(
                    "ONNX model is specified but not detected. Please ensure the ONNX model file exists."
                )
        else:
            if not has_pd_model:
                raise RuntimeError(
                    "Paddle model is specified but not detected. Please ensure the Paddle model file exists."
                )

    @abc.abstractmethod
    def _build_ui_model(self, option: ui.RuntimeOption) -> BaseUltraInferModel:
        """Construct the concrete UltraInfer model (subclass responsibility)."""
        raise NotImplementedError

    @abc.abstractmethod
    def _build_data_reader(self):
        """Construct the input data reader (subclass responsibility)."""
        raise NotImplementedError
  156. class CVPredictor(HPPredictor):
  157. def _build_data_reader(self):
  158. return ReadImage(format="BGR")
  159. class TSPredictor(HPPredictor):
  160. def _build_data_reader(self):
  161. return ReadTS()