base_predictor.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

from pydantic import ValidationError

from ..... import constants
from .....utils import logging
from .....utils.deps import require_hpip
from .....utils.device import get_default_device, parse_device
from .....utils.flags import (
    INFER_BENCHMARK,
    INFER_BENCHMARK_ITERS,
    INFER_BENCHMARK_WARMUP,
)
from .....utils.subclass_register import AutoRegisterABCMetaClass
from ....common.batch_sampler import BaseBatchSampler
from ....utils.benchmark import ENTRY_POINT_NAME, benchmark
from ....utils.hpi import HPIConfig, HPIInfo
from ....utils.io import YAMLReader
from ....utils.pp_option import PaddlePredictorOption
from ...common import HPInfer, PaddleInfer


class PredictionWrap:
    """Wraps the prediction data and supports getting it by index."""

    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
        """Initializes the PredictionWrap with prediction data.

        Args:
            data (Dict[str, List[Any]]): A dictionary where keys are string
                identifiers and values are lists of predictions.
            num (int): The number of predictions, i.e., the length of each
                value list in the data dictionary.

        Raises:
            AssertionError: If the length of any list in `data` does not match `num`.
        """
        assert isinstance(data, dict), "data must be a dictionary"
        for k in data:
            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
        self._data = data
        self._keys = data.keys()

    def get_by_idx(self, idx: int) -> Dict[str, Any]:
        """Gets the prediction at the specified index.

        Args:
            idx (int): The index to get predictions from.

        Returns:
            Dict[str, Any]: A dictionary with the same keys as the input data,
                but with the values at the specified index.
        """
        return {key: self._data[key][idx] for key in self._keys}
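
# Example (illustrative sketch): `PredictionWrap` turns a column-oriented
# batch result into per-sample rows. Assuming a batch of two predictions:
#
#     wrap = PredictionWrap({"label": ["cat", "dog"], "score": [0.9, 0.8]}, num=2)
#     wrap.get_by_idx(1)  # -> {"label": "dog", "score": 0.8}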


class BasePredictor(
    ABC,
    metaclass=AutoRegisterABCMetaClass,
):
    MODEL_FILE_PREFIX = constants.MODEL_FILE_PREFIX

    __is_base = True

    def __init__(
        self,
        model_dir: str,
        config: Optional[Dict[str, Any]] = None,
        *,
        device: Optional[str] = None,
        batch_size: int = 1,
        pp_option: Optional[PaddlePredictorOption] = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
        """Initializes the BasePredictor.

        Args:
            model_dir (str): The directory where the model files are stored.
            config (Optional[Dict[str, Any]], optional): The model configuration
                dictionary. Defaults to None.
            device (Optional[str], optional): The device to run the inference
                engine on. Defaults to None.
            batch_size (int, optional): The batch size to predict.
                Defaults to 1.
            pp_option (Optional[PaddlePredictorOption], optional): The inference
                engine options. Defaults to None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The high-performance inference configuration dictionary.
                Defaults to None.
        """
        super().__init__()

        self.model_dir = Path(model_dir)
        self.config = config if config else self.load_config(self.model_dir)
        self.batch_sampler = self._build_batch_sampler()
        self.result_class = self._get_result_class()

        # Alias `predict()` to `__call__()`.
        self.predict = self.__call__
        self.batch_sampler.batch_size = batch_size
        self._use_hpip = use_hpip
        if not use_hpip:
            if hpi_config is not None:
                logging.warning(
                    "`hpi_config` will be ignored when not using the high-performance inference plugin."
                )
            self._pp_option = self._prepare_pp_option(pp_option, device)
        else:
            require_hpip()
            if pp_option is not None:
                logging.warning(
                    "`pp_option` will be ignored when using the high-performance inference plugin."
                )
            self._hpi_config = self._prepare_hpi_config(hpi_config, device)

        logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
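
    # Example (illustrative sketch, not part of this module): concrete
    # subclasses register themselves via `AutoRegisterABCMetaClass` and are
    # constructed like this, assuming a hypothetical registered subclass
    # `SomeConcretePredictor`:
    #
    #     predictor = SomeConcretePredictor(
    #         model_dir="path/to/model",
    #         device="gpu:0",
    #         batch_size=4,
    #     )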

    @property
    def config_path(self) -> str:
        """
        Get the path to the configuration file.

        Returns:
            str: The path to the configuration file.
        """
        return self.get_config_path(self.model_dir)

    @property
    def model_name(self) -> str:
        """
        Get the model name.

        Returns:
            str: The model name.
        """
        return self.config["Global"]["model_name"]

    @property
    def pp_option(self) -> PaddlePredictorOption:
        if not hasattr(self, "_pp_option"):
            raise AttributeError(f"{repr(self)} has no attribute 'pp_option'.")
        return self._pp_option

    @property
    def hpi_config(self) -> HPIConfig:
        if not hasattr(self, "_hpi_config"):
            raise AttributeError(f"{repr(self)} has no attribute 'hpi_config'.")
        return self._hpi_config

    @property
    def use_hpip(self) -> bool:
        return self._use_hpip
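
    # Note (usage caveat, following from `__init__` above): `_pp_option` is
    # only set when `use_hpip=False`, and `_hpi_config` only when
    # `use_hpip=True`; accessing the other property raises `AttributeError`.
    #
    #     predictor = SomeConcretePredictor(model_dir="m", use_hpip=True)  # hypothetical subclass
    #     predictor.hpi_config  # OK
    #     predictor.pp_option   # raises AttributeError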

    def __call__(
        self,
        input: Any,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """
        Predict with the input data.

        Args:
            input (Any): The input data to be predicted.
            batch_size (Optional[int], optional): The batch size to use.
                Defaults to None.
            **kwargs (Dict[str, Any]): Additional keyword arguments to set up the predictor.

        Returns:
            Iterator[Any]: An iterator yielding the prediction output.
        """
        self.set_predictor(batch_size)

        if INFER_BENCHMARK:
            # TODO(zhang-prog): Get metadata of input data
            @benchmark.timeit_with_options(name=ENTRY_POINT_NAME)
            def _apply(input, **kwargs):
                return list(self.apply(input, **kwargs))

            if isinstance(input, list):
                raise TypeError("`input` cannot be a list in benchmark mode")
            input = [input] * batch_size

            if not (INFER_BENCHMARK_WARMUP > 0 or INFER_BENCHMARK_ITERS > 0):
                raise RuntimeError(
                    "At least one of `INFER_BENCHMARK_WARMUP` and `INFER_BENCHMARK_ITERS` must be greater than zero"
                )
            if INFER_BENCHMARK_WARMUP > 0:
                benchmark.start_warmup()
                for _ in range(INFER_BENCHMARK_WARMUP):
                    output = _apply(input, **kwargs)
                    benchmark.collect(batch_size)
                benchmark.stop_warmup()
            if INFER_BENCHMARK_ITERS > 0:
                for _ in range(INFER_BENCHMARK_ITERS):
                    output = _apply(input, **kwargs)
                    benchmark.collect(batch_size)
            yield output[0]
        else:
            yield from self.apply(input, **kwargs)
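
    # Example (illustrative sketch): `__call__` returns a lazy iterator, so
    # results are produced batch by batch as they are consumed. Assuming a
    # hypothetical registered subclass:
    #
    #     for result in predictor(["a.png", "b.png"], batch_size=2):
    #         ...  # each `result` is an instance of `self.result_class`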

    def set_predictor(
        self,
        batch_size: Optional[int] = None,
    ) -> None:
        """
        Sets the predictor configuration.

        Args:
            batch_size (Optional[int], optional): The batch size to use. Defaults to None.

        Returns:
            None
        """
        if batch_size:
            self.batch_sampler.batch_size = batch_size

    def get_hpi_info(self):
        if "Hpi" not in self.config:
            return None
        try:
            return HPIInfo.model_validate(self.config["Hpi"])
        except ValidationError as e:
            logging.exception("The HPI info in the model config file is invalid.")
            raise RuntimeError(f"Invalid HPI info: {str(e)}") from e

    def create_static_infer(self):
        if not self._use_hpip:
            return PaddleInfer(self.model_dir, self.MODEL_FILE_PREFIX, self._pp_option)
        else:
            return HPInfer(
                self.model_dir,
                self.MODEL_FILE_PREFIX,
                self._hpi_config,
            )

    def apply(self, input: Any, **kwargs) -> Iterator[Any]:
        """
        Predicts with the input data and yields prediction results.

        Args:
            input (Any): The input data to be predicted.

        Yields:
            Iterator[Any]: An iterator yielding prediction results.
        """
        if INFER_BENCHMARK:
            if not isinstance(input, list):
                raise TypeError("In benchmark mode, `input` must be a list")
            batches = list(self.batch_sampler(input))
            if len(batches) != 1 or len(batches[0]) != len(input):
                raise ValueError("Unexpected number of instances")
        else:
            batches = self.batch_sampler(input)
        for batch_data in batches:
            prediction = self.process(batch_data, **kwargs)
            prediction = PredictionWrap(prediction, len(batch_data))
            for idx in range(len(batch_data)):
                yield self.result_class(prediction.get_by_idx(idx))

    @abstractmethod
    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
        """Processes the batch data sampled from the BatchSampler and returns the prediction result.

        Args:
            batch_data (List[Any]): The batch data sampled from the BatchSampler.

        Returns:
            Dict[str, List[Any]]: The prediction result.
        """
        raise NotImplementedError
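
    # Example (illustrative sketch of a subclass override, not part of this
    # module): `process` must return a column-oriented dict whose value lists
    # all have the same length as `batch_data`, so that `PredictionWrap` can
    # slice it per sample, e.g.:
    #
    #     def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
    #         outputs = [self._run_model(x) for x in batch_data]  # hypothetical helper
    #         return {
    #             "input": batch_data,
    #             "output": outputs,
    #         }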

    @classmethod
    def get_config_path(cls, model_dir) -> str:
        """Gets the path to the configuration file for the given model directory.

        Args:
            model_dir (Path): The directory where the static model files are stored.

        Returns:
            Path: The path to the configuration file.
        """
        return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"

    @classmethod
    def load_config(cls, model_dir) -> Dict:
        """Loads the configuration from the specified model directory.

        Args:
            model_dir (Path): The directory where the static model files are stored.

        Returns:
            dict: The loaded configuration dictionary.
        """
        yaml_reader = YAMLReader()
        return yaml_reader.read(cls.get_config_path(model_dir))
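
    # Example (illustrative, assuming `constants.MODEL_FILE_PREFIX` is
    # "inference"; check `constants` for the actual value):
    #
    #     BasePredictor.get_config_path(Path("path/to/model"))
    #     # -> Path("path/to/model/inference.yml")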

    @abstractmethod
    def _build_batch_sampler(self) -> BaseBatchSampler:
        """Build batch sampler.

        Returns:
            BaseBatchSampler: batch sampler object.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_result_class(self) -> type:
        """Get the result class.

        Returns:
            type: The result class.
        """
        raise NotImplementedError

    def _prepare_pp_option(
        self,
        pp_option: Optional[PaddlePredictorOption],
        device: Optional[str],
    ) -> PaddlePredictorOption:
        if pp_option is None or device is not None:
            device_info = self._get_device_info(device)
        else:
            device_info = None
        if pp_option is None:
            pp_option = PaddlePredictorOption(model_name=self.model_name)
        if device_info:
            pp_option.device_type = device_info[0]
            pp_option.device_id = device_info[1]
        hpi_info = self.get_hpi_info()
        if hpi_info is not None:
            hpi_info = hpi_info.model_dump(exclude_unset=True)
            if pp_option.trt_dynamic_shapes is None:
                trt_dynamic_shapes = (
                    hpi_info.get("backend_configs", {})
                    .get("paddle_infer", {})
                    .get("trt_dynamic_shapes", None)
                )
                if trt_dynamic_shapes is not None:
                    logging.debug(
                        "TensorRT dynamic shapes set to %s", trt_dynamic_shapes
                    )
                    pp_option.trt_dynamic_shapes = trt_dynamic_shapes
            if pp_option.trt_dynamic_shape_input_data is None:
                trt_dynamic_shape_input_data = (
                    hpi_info.get("backend_configs", {})
                    .get("paddle_infer", {})
                    .get("trt_dynamic_shape_input_data", None)
                )
                if trt_dynamic_shape_input_data is not None:
                    logging.debug(
                        "TensorRT dynamic shape input data set to %s",
                        trt_dynamic_shape_input_data,
                    )
                    pp_option.trt_dynamic_shape_input_data = (
                        trt_dynamic_shape_input_data
                    )
        return pp_option
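
    # Note (descriptive summary of the precedence implemented above): an
    # explicit `device` argument overrides the device already set on a given
    # `pp_option`, and TensorRT dynamic-shape settings from the model config's
    # HPI info are used only as fallbacks when `pp_option` leaves them unset.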

    def _prepare_hpi_config(
        self,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]],
        device: Optional[str],
    ) -> HPIConfig:
        if hpi_config is None:
            hpi_config = {}
        elif isinstance(hpi_config, HPIConfig):
            hpi_config = hpi_config.model_dump(exclude_unset=True)
        else:
            hpi_config = deepcopy(hpi_config)

        if "model_name" not in hpi_config:
            hpi_config["model_name"] = self.model_name

        if device is not None or "device_type" not in hpi_config:
            device_type, device_id = self._get_device_info(device)
            hpi_config["device_type"] = device_type
            if device is not None or "device_id" not in hpi_config:
                hpi_config["device_id"] = device_id

        if "hpi_info" not in hpi_config:
            hpi_info = self.get_hpi_info()
            if hpi_info is not None:
                hpi_config["hpi_info"] = hpi_info

        hpi_config = HPIConfig.model_validate(hpi_config)

        return hpi_config

    # Should this be static?
    def _get_device_info(self, device):
        if device is None:
            device = get_default_device()
        device_type, device_ids = parse_device(device)
        if device_ids is not None:
            device_id = device_ids[0]
        else:
            device_id = None
        if device_ids and len(device_ids) > 1:
            logging.debug("Got multiple device IDs. Using the first one: %d", device_id)
        return device_type, device_id
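
    # Example (illustrative; the exact device-string format is defined by
    # `utils.device.parse_device`): a spec such as "gpu:0,1" is expected to
    # parse into a device type and an ID list, of which only the first ID
    # is kept:
    #
    #     self._get_device_info("gpu:0,1")  # -> ("gpu", 0)
    #     self._get_device_info("cpu")      # -> ("cpu", None)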