base_predictor.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

from pydantic import ValidationError

from ..... import constants
from .....utils import logging
from .....utils.deps import require_hpip
from .....utils.device import get_default_device, parse_device
from .....utils.flags import (
    INFER_BENCHMARK,
    INFER_BENCHMARK_ITERS,
    INFER_BENCHMARK_WARMUP,
    PIPELINE_BENCHMARK,
)
from .....utils.subclass_register import AutoRegisterABCMetaClass
from ....common.batch_sampler import BaseBatchSampler
from ....utils.benchmark import ENTRY_POINT_NAME, benchmark
from ....utils.hpi import HPIConfig, HPIInfo
from ....utils.io import YAMLReader
from ....utils.model_paths import get_model_paths
from ....utils.pp_option import PaddlePredictorOption
from ...common import HPInfer, PaddleInfer
from ...common.genai import GenAIClient, GenAIConfig, need_local_model


class PredictionWrap:
    """Wraps the prediction data and supports retrieval by index."""

    def __init__(self, data: Dict[str, List[Any]], num: int) -> None:
        """Initializes the PredictionWrap with prediction data.

        Args:
            data (Dict[str, List[Any]]): A dictionary where keys are string identifiers and values are lists of predictions.
            num (int): The number of predictions, i.e., the length of the value list for each key in the data dictionary.

        Raises:
            AssertionError: If the length of any list in data does not match num.
        """
        assert isinstance(data, dict), "data must be a dictionary"
        for k in data:
            assert len(data[k]) == num, f"{len(data[k])} != {num} for key {k}!"
        self._data = data
        self._keys = data.keys()

    def get_by_idx(self, idx: int) -> Dict[str, Any]:
        """Gets the prediction at the specified index.

        Args:
            idx (int): The index to get predictions from.

        Returns:
            Dict[str, Any]: A dictionary with the same keys as the input data, but with the values at the specified index.
        """
        return {key: self._data[key][idx] for key in self._keys}
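

# A minimal sketch of how `PredictionWrap` converts column-oriented batch
# output into per-sample dictionaries (values are illustrative only):
#
#     >>> wrap = PredictionWrap({"label": ["cat", "dog"], "score": [0.9, 0.8]}, num=2)
#     >>> wrap.get_by_idx(1)
#     {'label': 'dog', 'score': 0.8}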


class BasePredictor(
    ABC,
    metaclass=AutoRegisterABCMetaClass,
):
    MODEL_FILE_PREFIX = constants.MODEL_FILE_PREFIX

    __is_base = True

    def __init__(
        self,
        model_dir: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        *,
        device: Optional[str] = None,
        batch_size: int = 1,
        pp_option: Optional[PaddlePredictorOption] = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
        genai_config: Optional[GenAIConfig] = None,
        model_name: Optional[str] = None,
    ) -> None:
        """Initializes the BasePredictor.

        Args:
            model_dir (Optional[str], optional): The directory where the model
                files are stored. Defaults to None.
            config (Optional[Dict[str, Any]], optional): The model configuration
                dictionary. Defaults to None.
            device (Optional[str], optional): The device to run the inference
                engine on. Defaults to None.
            batch_size (int, optional): The batch size to predict.
                Defaults to 1.
            pp_option (Optional[PaddlePredictorOption], optional): The inference
                engine options. Defaults to None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The high-performance inference configuration dictionary.
                Defaults to None.
            genai_config (Optional[GenAIConfig], optional): The generative AI
                configuration. Defaults to None.
            model_name (Optional[str], optional): The model name.
                Defaults to None.
        """
        super().__init__()

        if need_local_model(genai_config):
            if model_dir is None:
                raise ValueError(
                    "`model_dir` should not be `None`, as a local model is needed."
                )
            self.model_dir = Path(model_dir)
            self.config = config if config else self.load_config(self.model_dir)
            self._use_local_model = True
        else:
            if model_dir is not None:
                warnings.warn("`model_dir` will be ignored, as it is not needed.")
            self.model_dir = None
            self.config = config
            self._genai_config = genai_config
            assert genai_config.server_url is not None
            self._genai_client = GenAIClient(
                backend=genai_config.backend,
                base_url=genai_config.server_url,
                max_concurrency=genai_config.max_concurrency,
                model_name=model_name,
                **(genai_config.client_kwargs or {}),
            )
            self._use_local_model = False

        if model_name:
            if self.config:
                if self.config["Global"]["model_name"] != model_name:
                    raise ValueError("`model_name` is not consistent with `config`")
            self._model_name = model_name

        self.batch_sampler = self._build_batch_sampler()
        self.result_class = self._get_result_class()

        # Alias `predict` to `__call__`.
        self.predict = self.__call__

        self.batch_sampler.batch_size = batch_size

        if self.model_dir and get_model_paths(self.model_dir, self.MODEL_FILE_PREFIX):
            self._use_hpip = use_hpip
            if not use_hpip:
                self._pp_option = self._prepare_pp_option(pp_option, device)
            else:
                require_hpip()
                self._hpi_config = self._prepare_hpi_config(hpi_config, device)
        else:
            self._use_hpip = False

        logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
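
    # A minimal construction sketch (illustrative; `MyPredictor` is a
    # hypothetical concrete subclass, and the path, device string, and model
    # name are placeholders):
    #
    #     >>> predictor = MyPredictor(model_dir="/path/to/model", device="gpu:0", batch_size=4)
    #     >>> predictor.model_name
    #     'some-model'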

    @property
    def config_path(self) -> Path:
        """
        Get the path to the configuration file.

        Returns:
            Path: The path to the configuration file.
        """
        return self.get_config_path(self.model_dir)

    @property
    def model_name(self) -> str:
        """
        Get the model name.

        Returns:
            str: The model name.
        """
        if self.config:
            return self.config["Global"]["model_name"]
        else:
            if hasattr(self, "_model_name"):
                return self._model_name
            else:
                raise AttributeError(f"{repr(self)} has no attribute 'model_name'.")

    @property
    def pp_option(self) -> PaddlePredictorOption:
        if not hasattr(self, "_pp_option"):
            raise AttributeError(f"{repr(self)} has no attribute 'pp_option'.")
        return self._pp_option

    @property
    def hpi_config(self) -> HPIConfig:
        if not hasattr(self, "_hpi_config"):
            raise AttributeError(f"{repr(self)} has no attribute 'hpi_config'.")
        return self._hpi_config

    @property
    def use_hpip(self) -> bool:
        return self._use_hpip

    @property
    def genai_config(self) -> GenAIConfig:
        if not hasattr(self, "_genai_config"):
            raise AttributeError(f"{repr(self)} has no attribute 'genai_config'.")
        return self._genai_config

    def __call__(
        self,
        input: Any,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """
        Predict with the input data.

        Args:
            input (Any): The input data to be predicted.
            batch_size (Optional[int], optional): The batch size to use. Defaults to None.
            **kwargs (Any): Additional keyword arguments forwarded to `apply`.

        Returns:
            Iterator[Any]: An iterator yielding the prediction output.
        """
        self.set_predictor(batch_size)
        if INFER_BENCHMARK:
            # TODO(zhang-prog): Get metadata of input data

            @benchmark.timeit_with_options(name=ENTRY_POINT_NAME)
            def _apply(input, **kwargs):
                return list(self.apply(input, **kwargs))

            if isinstance(input, list):
                raise TypeError("`input` cannot be a list in benchmark mode")
            input = [input] * batch_size

            if not (INFER_BENCHMARK_WARMUP > 0 or INFER_BENCHMARK_ITERS > 0):
                raise RuntimeError(
                    "At least one of `INFER_BENCHMARK_WARMUP` and `INFER_BENCHMARK_ITERS` must be greater than zero"
                )

            if INFER_BENCHMARK_WARMUP > 0:
                benchmark.start_warmup()
                for _ in range(INFER_BENCHMARK_WARMUP):
                    output = _apply(input, **kwargs)
                    benchmark.collect(batch_size)
                benchmark.stop_warmup()

            if INFER_BENCHMARK_ITERS > 0:
                for _ in range(INFER_BENCHMARK_ITERS):
                    output = _apply(input, **kwargs)
                    benchmark.collect(batch_size)

            yield output[0]
        elif PIPELINE_BENCHMARK:

            @benchmark.timeit_with_options(name=type(self).__name__ + ".apply")
            def _apply(input, **kwargs):
                return list(self.apply(input, **kwargs))

            yield from _apply(input, **kwargs)
        else:
            yield from self.apply(input, **kwargs)
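
    # Predictions are yielded lazily, so a typical call site iterates over the
    # returned generator (illustrative; the input path is a placeholder):
    #
    #     >>> for result in predictor("/path/to/image.jpg", batch_size=2):
    #     ...     print(result)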

    def set_predictor(
        self,
        batch_size: Optional[int] = None,
    ) -> None:
        """
        Sets the predictor configuration.

        Args:
            batch_size (Optional[int], optional): The batch size to use. Defaults to None.

        Returns:
            None
        """
        if batch_size:
            self.batch_sampler.batch_size = batch_size

    def get_hpi_info(self):
        """Returns the validated HPI info from the config, or None if absent."""
        if "Hpi" not in self.config:
            return None
        try:
            return HPIInfo.model_validate(self.config["Hpi"])
        except ValidationError as e:
            raise RuntimeError(f"Invalid HPI info: {str(e)}") from e

    def create_static_infer(self):
        """Creates the static inference object: PaddleInfer by default, or HPInfer when the high-performance inference plugin is used."""
        if not self._use_hpip:
            return PaddleInfer(
                self.model_name, self.model_dir, self.MODEL_FILE_PREFIX, self._pp_option
            )
        else:
            return HPInfer(
                self.model_dir,
                self.MODEL_FILE_PREFIX,
                self._hpi_config,
            )

    def apply(self, input: Any, **kwargs) -> Iterator[Any]:
        """
        Predicts with the input data and yields prediction results.

        Args:
            input (Any): The input data to be predicted.
            **kwargs (Any): Additional keyword arguments forwarded to `process`.

        Yields:
            Iterator[Any]: An iterator yielding prediction results.
        """
        if INFER_BENCHMARK:
            if not isinstance(input, list):
                raise TypeError("In benchmark mode, `input` must be a list")
            batches = list(self.batch_sampler(input))
            if len(batches) != 1 or len(batches[0]) != len(input):
                raise ValueError("Unexpected number of instances")
        else:
            batches = self.batch_sampler(input)
        for batch_data in batches:
            prediction = self.process(batch_data, **kwargs)
            prediction = PredictionWrap(prediction, len(batch_data))
            for idx in range(len(batch_data)):
                yield self.result_class(prediction.get_by_idx(idx))

    @abstractmethod
    def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
        """Processes the batch data sampled from the BatchSampler and returns the prediction result.

        Args:
            batch_data (List[Any]): The batch data sampled from the BatchSampler.

        Returns:
            Dict[str, List[Any]]: The prediction result.
        """
        raise NotImplementedError
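
    # `process` returns column-oriented results: one key per output field, with
    # each value list aligned with `batch_data` (as `PredictionWrap` asserts).
    # A sketch for a batch of two samples (field names are illustrative):
    #
    #     {"label": ["cat", "dog"], "score": [0.9, 0.8]}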

    def close(self) -> None:
        """Closes the GenAI client if one was created."""
        if hasattr(self, "_genai_client"):
            self._genai_client.close()

    @classmethod
    def get_config_path(cls, model_dir) -> Path:
        """Get the path to the configuration file for the given model directory.

        Args:
            model_dir (Path): The directory where the static model files are stored.

        Returns:
            Path: The path to the configuration file.
        """
        return model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"

    @classmethod
    def load_config(cls, model_dir) -> Dict:
        """Load the configuration from the specified model directory.

        Args:
            model_dir (Path): The directory where the static model files are stored.

        Returns:
            dict: The loaded configuration dictionary.
        """
        yaml_reader = YAMLReader()
        return yaml_reader.read(cls.get_config_path(model_dir))
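
    # The config file name is derived from `MODEL_FILE_PREFIX`; e.g., assuming
    # the prefix is "inference", the path resolves as (illustrative path):
    #
    #     >>> BasePredictor.get_config_path(Path("/path/to/model"))
    #     PosixPath('/path/to/model/inference.yml')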

    @abstractmethod
    def _build_batch_sampler(self) -> BaseBatchSampler:
        """Build batch sampler.

        Returns:
            BaseBatchSampler: batch sampler object.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_result_class(self) -> type:
        """Get the result class.

        Returns:
            type: The result class.
        """
        raise NotImplementedError

    def _prepare_pp_option(
        self,
        pp_option: Optional[PaddlePredictorOption],
        device: Optional[str],
    ) -> PaddlePredictorOption:
        if pp_option is None or device is not None:
            device_info = self._get_device_info(device)
        else:
            device_info = None
        if pp_option is None:
            pp_option = PaddlePredictorOption()
        if device_info:
            pp_option.device_type = device_info[0]
            pp_option.device_id = device_info[1]
        pp_option.setdefault_by_model_name(model_name=self.model_name)
        hpi_info = self.get_hpi_info()
        if hpi_info is not None:
            hpi_info = hpi_info.model_dump(exclude_unset=True)
            if pp_option.trt_dynamic_shapes is None:
                trt_dynamic_shapes = (
                    hpi_info.get("backend_configs", {})
                    .get("paddle_infer", {})
                    .get("trt_dynamic_shapes", None)
                )
                if trt_dynamic_shapes is not None:
                    logging.debug(
                        "TensorRT dynamic shapes set to %s", trt_dynamic_shapes
                    )
                    pp_option.trt_dynamic_shapes = trt_dynamic_shapes
            if pp_option.trt_dynamic_shape_input_data is None:
                trt_dynamic_shape_input_data = (
                    hpi_info.get("backend_configs", {})
                    .get("paddle_infer", {})
                    .get("trt_dynamic_shape_input_data", None)
                )
                if trt_dynamic_shape_input_data is not None:
                    logging.debug(
                        "TensorRT dynamic shape input data set to %s",
                        trt_dynamic_shape_input_data,
                    )
                    pp_option.trt_dynamic_shape_input_data = (
                        trt_dynamic_shape_input_data
                    )
        return pp_option
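
    # The TRT settings above are read from a nested HPI structure of the form
    # below (keys taken from the lookups above; values are illustrative):
    #
    #     {"backend_configs": {"paddle_infer": {"trt_dynamic_shapes": {...},
    #                                           "trt_dynamic_shape_input_data": {...}}}}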

    def _prepare_hpi_config(
        self,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]],
        device: Optional[str],
    ) -> HPIConfig:
        if hpi_config is None:
            hpi_config = {}
        elif isinstance(hpi_config, HPIConfig):
            hpi_config = hpi_config.model_dump(exclude_unset=True)
        else:
            hpi_config = deepcopy(hpi_config)
        if "model_name" not in hpi_config:
            hpi_config["model_name"] = self.model_name
        if device is not None or "device_type" not in hpi_config:
            device_type, device_id = self._get_device_info(device)
            hpi_config["device_type"] = device_type
            if device is not None or "device_id" not in hpi_config:
                hpi_config["device_id"] = device_id
        if "hpi_info" not in hpi_config:
            hpi_info = self.get_hpi_info()
            if hpi_info is not None:
                hpi_config["hpi_info"] = hpi_info
        hpi_config = HPIConfig.model_validate(hpi_config)
        return hpi_config

    # Should this be static?
    def _get_device_info(self, device):
        """Resolves `device` to a (device_type, device_id) tuple, using the default device when none is given."""
        if device is None:
            device = get_default_device()
        device_type, device_ids = parse_device(device)
        if device_ids is not None:
            device_id = device_ids[0]
        else:
            device_id = None
        if device_ids and len(device_ids) > 1:
            logging.debug("Got multiple device IDs. Using the first one: %d", device_id)
        return device_type, device_id
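

# A minimal sketch of a concrete subclass (illustrative only: `MySampler`,
# `MyResult`, and `self._infer` are hypothetical stand-ins, and registration
# details handled by `AutoRegisterABCMetaClass` are omitted):
#
#     class MyPredictor(BasePredictor):
#         def _build_batch_sampler(self):
#             return MySampler()
#
#         def _get_result_class(self):
#             return MyResult
#
#         def process(self, batch_data, **kwargs):
#             # One list per output field, each aligned with `batch_data`.
#             return {"pred": [self._infer(x) for x in batch_data]}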