static_infer.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
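"""Static-graph inference wrappers.

This module defines the ``StaticInfer`` interface, which maps a sequence of
NumPy arrays to a list of NumPy arrays, together with two implementations:
``PaddleInfer``, which runs models with the Paddle Inference engine, and
``HPInfer``, which can additionally delegate to the ``ultra-infer``
multi-backend runtime (OpenVINO, ONNX Runtime, TensorRT, OM).
"""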

import abc
import subprocess
from os import PathLike
from pathlib import Path
from typing import List, Sequence, Union

import numpy as np

from ....utils import logging
from ....utils.deps import class_requires_deps
from ....utils.device import check_supported_device_type
from ....utils.flags import (
    DEBUG,
    DISABLE_MKLDNN_MODEL_BL,
    DISABLE_TRT_MODEL_BL,
    USE_PIR_TRT,
)
from ...utils.benchmark import benchmark, set_inference_operations
from ...utils.hpi import (
    HPIConfig,
    OMConfig,
    ONNXRuntimeConfig,
    OpenVINOConfig,
    TensorRTConfig,
    suggest_inference_backend_and_config,
)
from ...utils.mkldnn_blocklist import MKLDNN_BLOCKLIST
from ...utils.model_paths import get_model_paths
from ...utils.pp_option import PaddlePredictorOption, get_default_run_mode
from ...utils.trt_blocklist import TRT_BLOCKLIST
from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG

CACHE_DIR = ".cache"

INFERENCE_OPERATIONS = [
    "PaddleInferChainLegacy",
    "MultiBackendInfer",
]
set_inference_operations(INFERENCE_OPERATIONS)


# XXX: Better use Paddle Inference API to do this
def _pd_dtype_to_np_dtype(pd_dtype):
    import paddle

    if pd_dtype == paddle.inference.DataType.FLOAT64:
        return np.float64
    elif pd_dtype == paddle.inference.DataType.FLOAT32:
        return np.float32
    elif pd_dtype == paddle.inference.DataType.INT64:
        return np.int64
    elif pd_dtype == paddle.inference.DataType.INT32:
        return np.int32
    elif pd_dtype == paddle.inference.DataType.UINT8:
        return np.uint8
    elif pd_dtype == paddle.inference.DataType.INT8:
        return np.int8
    else:
        raise TypeError(f"Unsupported data type: {pd_dtype}")


# old trt
def _collect_trt_shape_range_info(
    model_file,
    model_params,
    gpu_id,
    shape_range_info_path,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    import paddle.inference

    dynamic_shape_input_data = dynamic_shape_input_data or {}

    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    config.collect_shape_range_info(shape_range_info_path)
    # TODO: Add other needed options
    config.disable_glog_info()
    predictor = paddle.inference.create_predictor(config)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )
    # It would be better to check if the shapes are valid.

    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arrs[name] = np.array(
                dynamic_shape_input_data[name][0], dtype=dtype
            ).reshape(min_shape)
            opt_arrs[name] = np.array(
                dynamic_shape_input_data[name][1], dtype=dtype
            ).reshape(opt_shape)
            max_arrs[name] = np.array(
                dynamic_shape_input_data[name][2], dtype=dtype
            ).reshape(max_shape)
        else:
            min_arrs[name] = np.ones(min_shape, dtype=dtype)
            opt_arrs[name] = np.ones(opt_shape, dtype=dtype)
            max_arrs[name] = np.ones(max_shape, dtype=dtype)

    # `opt_arrs` is used twice to ensure it is the most frequently used.
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            handle = predictor.get_input_handle(name)
            handle.reshape(arr.shape)
            handle.copy_from_cpu(arr)
        predictor.run()

    # HACK: The shape range info will be written to the file only when
    # `predictor` is garbage collected. It works in CPython, but it is
    # definitely a bad idea to count on the implementation-dependent behavior of
    # a garbage collector. Is there a more explicit and deterministic way to
    # handle this?
    # HACK: Manually delete the predictor to trigger its destructor, ensuring
    # that the shape range info file gets saved.
    del predictor


# pir trt
def _convert_trt(
    trt_cfg_setting,
    pp_model_file,
    pp_params_file,
    trt_save_path,
    device_id,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    import paddle.inference
    from paddle.tensorrt.export import Input, TensorRTConfig, convert

    def _set_trt_config():
        for attr_name in trt_cfg_setting:
            assert hasattr(
                trt_config, attr_name
            ), f"`{type(trt_config)}` does not have the attribute `{attr_name}`!"
            setattr(trt_config, attr_name, trt_cfg_setting[attr_name])

    def _get_predictor(model_file, params_file):
        # HACK
        config = paddle.inference.Config(str(model_file), str(params_file))
        config.enable_use_gpu(100, device_id)
        # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
        config.disable_mkldnn()
        config.disable_glog_info()
        return paddle.inference.create_predictor(config)

    dynamic_shape_input_data = dynamic_shape_input_data or {}

    predictor = _get_predictor(pp_model_file, pp_params_file)
    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )

    trt_inputs = []
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
                min_shape
            )
            opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
                opt_shape
            )
            max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
                max_shape
            )
        else:
            min_arr = np.ones(min_shape, dtype=dtype)
            opt_arr = np.ones(opt_shape, dtype=dtype)
            max_arr = np.ones(max_shape, dtype=dtype)

        # refer to: https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
        trt_input = Input((min_arr, opt_arr, max_arr))
        trt_inputs.append(trt_input)

    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    _set_trt_config()
    trt_config.save_model_dir = str(trt_save_path)
    pp_model_path = str(pp_model_file.with_suffix(""))
    convert(pp_model_path, trt_config)
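

# NOTE: `inputs` are expected to be ordered by the lexicographically sorted
# input names; this helper rearranges them to follow the runtime's native
# input order given by `names`.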
def _sort_inputs(inputs, names):
    # NOTE: Adjust input tensors to match the sorted sequence.
    indices = sorted(range(len(names)), key=names.__getitem__)
    inputs = [inputs[indices.index(i)] for i in range(len(inputs))]
    return inputs
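

# Thin wrapper that copies NumPy inputs into a Paddle Inference predictor's
# input handles, runs the predictor, and copies the outputs back to CPU.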
# FIXME: Name might be misleading
@benchmark.timeit
class PaddleInferChainLegacy:
    def __init__(self, predictor):
        self.predictor = predictor
        input_names = self.predictor.get_input_names()
        self.input_handles = []
        self.output_handles = []
        for input_name in input_names:
            input_handle = self.predictor.get_input_handle(input_name)
            self.input_handles.append(input_handle)
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_handle = self.predictor.get_output_handle(output_name)
            self.output_handles.append(output_handle)

    def __call__(self, x):
        for input_, input_handle in zip(x, self.input_handles):
            input_handle.reshape(input_.shape)
            input_handle.copy_from_cpu(input_)
        self.predictor.run()
        outputs = [o.copy_to_cpu() for o in self.output_handles]
        return outputs
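

# Abstract interface: implementations take a sequence of NumPy arrays (one per
# model input) and return a list of NumPy arrays (one per model output).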
class StaticInfer(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        raise NotImplementedError
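

# `PaddleInfer` wraps a Paddle Inference predictor configured from a
# `PaddlePredictorOption`. A minimal usage sketch follows; the model name,
# directory, file prefix, and input shape are hypothetical placeholders:
#
#     option = PaddlePredictorOption(device_type="cpu")
#     infer = PaddleInfer("SomeModel", "/path/to/model_dir", "inference", option=option)
#     outputs = infer([np.zeros((1, 3, 224, 224), dtype=np.float32)])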
class PaddleInfer(StaticInfer):
    def __init__(
        self,
        model_name: str,
        model_dir: Union[str, PathLike],
        model_file_prefix: str,
        option: PaddlePredictorOption,
    ) -> None:
        super().__init__()
        self._model_name = model_name
        self.model_dir = Path(model_dir)
        self.model_file_prefix = model_file_prefix
        self._option = option
        self.predictor = self._create()
        self.infer = PaddleInferChainLegacy(self.predictor)

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        names = self.predictor.get_input_names()
        if len(names) != len(x):
            raise ValueError(
                f"The number of inputs does not match the model: {len(names)} vs {len(x)}"
            )
        # TODO:
        # Ensure that input tensors follow the model's input sequence without sorting.
        x = _sort_inputs(x, names)
        x = list(map(np.ascontiguousarray, x))
        pred = self.infer(x)
        return pred
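
    # Falls back to the `paddle` run mode (or forces `mkldnn`) for models that
    # are known not to work with the requested run mode on the current device.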
    def _check_run_mode(self):
        # TODO: Check if TRT is available
        # check availability for TRT
        if (
            not DISABLE_TRT_MODEL_BL
            and self._option.run_mode.startswith("trt")
            and self._model_name in TRT_BLOCKLIST
            and self._option.device_type == "gpu"
        ):
            logging.warning(
                f"The model ({self._model_name}) is not supported in TRT mode. Using `paddle` instead."
            )
            self._option.run_mode = "paddle"
        # check availability for MKL-DNN
        elif (
            not DISABLE_MKLDNN_MODEL_BL
            and self._option.run_mode.startswith("mkldnn")
            and self._model_name in MKLDNN_BLOCKLIST
            and self._option.device_type == "cpu"
        ):
            logging.warning(
                f"The model ({self._model_name}) is not supported in MKL-DNN mode. Using `paddle` instead."
            )
            self._option.run_mode = "paddle"
            return "paddle"
        # check model-specific constraints
        if self._model_name == "LaTeX_OCR_rec" and self._option.device_type == "cpu":
            import cpuinfo

            if (
                "GenuineIntel" in cpuinfo.get_cpu_info().get("vendor_id_raw", "")
                and self._option.run_mode != "mkldnn"
            ):
                logging.warning(
                    "The `LaTeX_OCR_rec` model only supports the `mkldnn` run mode on Intel CPUs. Using `mkldnn` instead."
                )
                self._option.run_mode = "mkldnn"

    def _create(
        self,
    ):
        """Create and configure the Paddle Inference predictor."""
        import paddle
        import paddle.inference

        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
        if "paddle" not in model_paths:
            raise RuntimeError("No valid PaddlePaddle model found")
        check_supported_device_type(self._option.device_type, self._model_name)
        self._check_run_mode()
        model_file, params_file = model_paths["paddle"]

        if self._option.device_type == "cpu" and self._option.device_id is not None:
            self._option.device_id = None
            logging.debug("`device_id` has been set to None")

        if (
            self._option.device_type in ("gpu", "dcu", "npu", "mlu", "gcu", "xpu")
            and self._option.device_id is None
        ):
            self._option.device_id = 0
            logging.debug("`device_id` has been set to 0")

        # for TRT
        if self._option.run_mode.startswith("trt"):
            assert self._option.device_type.lower() == "gpu", (
                f"`{self._option.run_mode}` is only available on GPU devices, "
                f"but got device_type='{self._option.device_type}'."
            )
            cache_dir = self.model_dir / CACHE_DIR / "paddle"
            config = self._configure_trt(
                model_file,
                params_file,
                cache_dir,
            )
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self._option.device_id)
        # for Native Paddle and MKLDNN
        else:
            config = paddle.inference.Config(str(model_file), str(params_file))
            if self._option.device_type == "gpu":
                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
                from paddle.inference import PrecisionType

                precision = (
                    PrecisionType.Half
                    if self._option.run_mode == "paddle_fp16"
                    else PrecisionType.Float32
                )
                config.disable_mkldnn()
                config.enable_use_gpu(100, self._option.device_id, precision)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                    if self._option.enable_new_ir and self._option.enable_cinn:
                        config.enable_cinn()
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)
            elif self._option.device_type == "npu":
                config.enable_custom_device("npu", self._option.device_id)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "xpu":
                config.enable_xpu()
                config.set_xpu_device_id(self._option.device_id)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.delete_pass("conv2d_bn_xpu_fuse_pass")
                config.delete_pass("transfer_layout_pass")
            elif self._option.device_type == "mlu":
                config.enable_custom_device("mlu", self._option.device_id)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "gcu":
                from paddle_custom_device.gcu import passes as gcu_passes

                gcu_passes.setUp()
                config.enable_custom_device("gcu", self._option.device_id)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir()
                    if hasattr(config, "enable_new_executor"):
                        config.enable_new_executor()
                else:
                    pass_builder = config.pass_builder()
                    name = "PaddleX_" + self._option.model_name
                    gcu_passes.append_passes_for_legacy_ir(pass_builder, name)
            elif self._option.device_type == "dcu":
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                    if self._option.enable_new_ir and self._option.enable_cinn:
                        config.enable_cinn()
                config.enable_use_gpu(100, self._option.device_id)
                config.disable_mkldnn()
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                # XXX: is_compiled_with_rocm() must be True on dcu platform?
                if paddle.is_compiled_with_rocm():
                    # Delete unsupported passes in dcu
                    config.delete_pass("conv2d_add_act_fuse_pass")
                    config.delete_pass("conv2d_add_fuse_pass")
            elif self._option.device_type == "iluvatar_gpu":
                config.enable_custom_device("iluvatar_gpu", int(self._option.device_id))
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            else:
                assert self._option.device_type == "cpu"
                config.disable_gpu()
                if "mkldnn" in self._option.run_mode:
                    config.enable_mkldnn()
                    if "bf16" in self._option.run_mode:
                        config.enable_mkldnn_bfloat16()
                    config.set_mkldnn_cache_capacity(self._option.mkldnn_cache_capacity)
                else:
                    if hasattr(config, "disable_mkldnn"):
                        config.disable_mkldnn()
                config.set_cpu_math_library_num_threads(self._option.cpu_threads)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self._option.delete_pass:
            config.delete_pass(del_p)

        # Disable paddle inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = paddle.inference.create_predictor(config)

        return predictor
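
    # Builds a TensorRT-enabled Paddle Inference config: with `USE_PIR_TRT` the
    # model is converted ahead of time via `_convert_trt`; otherwise the legacy
    # subgraph path is used, optionally collecting a shape range info file first.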
    def _configure_trt(self, model_file, params_file, cache_dir):
        # TODO: Support calibration
        import paddle.inference

        if USE_PIR_TRT:
            if self._option.trt_dynamic_shapes is None:
                raise RuntimeError("No dynamic shape information provided")
            trt_save_path = cache_dir / "trt" / self.model_file_prefix
            trt_model_file = trt_save_path.with_suffix(".json")
            trt_params_file = trt_save_path.with_suffix(".pdiparams")
            if not trt_model_file.exists() or not trt_params_file.exists():
                _convert_trt(
                    self._option.trt_cfg_setting,
                    model_file,
                    params_file,
                    trt_save_path,
                    self._option.device_id,
                    self._option.trt_dynamic_shapes,
                    self._option.trt_dynamic_shape_input_data,
                )
            else:
                logging.debug(
                    f"Use TRT cache files (`{trt_model_file}` and `{trt_params_file}`)."
                )
            config = paddle.inference.Config(str(trt_model_file), str(trt_params_file))
        else:
            config = paddle.inference.Config(str(model_file), str(params_file))
            config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
            # call enable_use_gpu() first to use TensorRT engine
            config.enable_use_gpu(100, self._option.device_id)
            for func_name in self._option.trt_cfg_setting:
                assert hasattr(
                    config, func_name
                ), f"`{type(config)}` does not have the method `{func_name}`!"
                args = self._option.trt_cfg_setting[func_name]
                if isinstance(args, list):
                    getattr(config, func_name)(*args)
                else:
                    getattr(config, func_name)(**args)
            if self._option.trt_use_dynamic_shapes:
                if self._option.trt_dynamic_shapes is None:
                    raise RuntimeError("No dynamic shape information provided")
                if self._option.trt_collect_shape_range_info:
                    # NOTE: We always use a shape range info file.
                    if self._option.trt_shape_range_info_path is not None:
                        trt_shape_range_info_path = Path(
                            self._option.trt_shape_range_info_path
                        )
                    else:
                        trt_shape_range_info_path = cache_dir / "shape_range_info.pbtxt"
                    should_collect_shape_range_info = True
                    if not trt_shape_range_info_path.exists():
                        trt_shape_range_info_path.parent.mkdir(
                            parents=True, exist_ok=True
                        )
                        logging.info(
                            f"Shape range info will be collected into {trt_shape_range_info_path}"
                        )
                    elif self._option.trt_discard_cached_shape_range_info:
                        trt_shape_range_info_path.unlink()
                        logging.info(
                            f"The shape range info file ({trt_shape_range_info_path}) has been removed, and the shape range info will be re-collected."
                        )
                    else:
                        logging.info(
                            f"A shape range info file ({trt_shape_range_info_path}) already exists. There is no need to collect the info again."
                        )
                        should_collect_shape_range_info = False
                    if should_collect_shape_range_info:
                        _collect_trt_shape_range_info(
                            str(model_file),
                            str(params_file),
                            self._option.device_id,
                            str(trt_shape_range_info_path),
                            self._option.trt_dynamic_shapes,
                            self._option.trt_dynamic_shape_input_data,
                        )
                    if (
                        self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG
                        and self._option.run_mode == "trt_fp16"
                    ):
                        paddle.inference.InternalUtils.disable_tensorrt_half_ops(
                            config, DISABLE_TRT_HALF_OPS_CONFIG[self._option.model_name]
                        )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        str(trt_shape_range_info_path),
                        self._option.trt_allow_rebuild_at_runtime,
                    )
                else:
                    min_shapes, opt_shapes, max_shapes = {}, {}, {}
                    for (
                        key,
                        shapes,
                    ) in self._option.trt_dynamic_shapes.items():
                        min_shapes[key] = shapes[0]
                        opt_shapes[key] = shapes[1]
                        max_shapes[key] = shapes[2]
                    config.set_trt_dynamic_shape_info(
                        min_shapes, max_shapes, opt_shapes
                    )

        return config
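

# `MultiBackendInfer` is a thin callable wrapper around an `ultra-infer`
# Runtime; it is timed by `benchmark.timeit`, so the wrapper overhead is
# included in measurements.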
# FIXME: Name might be misleading
@benchmark.timeit
@class_requires_deps("ultra-infer")
class MultiBackendInfer(object):
    def __init__(self, ui_runtime):
        super().__init__()
        self.ui_runtime = ui_runtime

    # The time consumed by the wrapper code will also be taken into account.
    def __call__(self, x):
        outputs = self.ui_runtime.infer(x)
        return outputs
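

# `HPInfer` selects an inference backend, either automatically via
# `suggest_inference_backend_and_config` or from the user-supplied `HPIConfig`,
# and then dispatches calls to Paddle Inference or to an `ultra-infer` runtime.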
# TODO: It would be better to refactor the code to make `HPInfer` a higher-level
# class that uses `PaddleInfer`.
@class_requires_deps("ultra-infer")
class HPInfer(StaticInfer):
    def __init__(
        self,
        model_dir: Union[str, PathLike],
        model_file_prefix: str,
        config: HPIConfig,
    ) -> None:
        super().__init__()
        self._model_dir = Path(model_dir)
        self._model_file_prefix = model_file_prefix
        self._config = config
        backend, backend_config = self._determine_backend_and_config()
        if backend == "paddle":
            self._use_paddle = True
            self._paddle_infer = self._build_paddle_infer(backend_config)
        else:
            self._use_paddle = False
            ui_runtime = self._build_ui_runtime(backend, backend_config)
            self._multi_backend_infer = MultiBackendInfer(ui_runtime)
            num_inputs = ui_runtime.num_inputs()
            self._input_names = [
                ui_runtime.get_input_info(i).name for i in range(num_inputs)
            ]

    @property
    def model_dir(self) -> Path:
        return self._model_dir

    @property
    def model_file_prefix(self) -> str:
        return self._model_file_prefix

    @property
    def config(self) -> HPIConfig:
        return self._config

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        if self._use_paddle:
            return self._call_paddle_infer(x)
        else:
            return self._call_multi_backend_infer(x)

    def _call_paddle_infer(self, x):
        return self._paddle_infer(x)

    def _call_multi_backend_infer(self, x):
        num_inputs = len(self._input_names)
        if len(x) != num_inputs:
            raise ValueError(f"Expected {num_inputs} inputs but got {len(x)} instead")
        x = _sort_inputs(x, self._input_names)
        inputs = {}
        for name, input_ in zip(self._input_names, x):
            inputs[name] = np.ascontiguousarray(input_)
        return self._multi_backend_infer(inputs)
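
    # Resolves the (backend, backend_config) pair, either via auto-configuration
    # or from the user-supplied settings, and warns when the Paddle backend is
    # used with an effectively default configuration.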
    def _determine_backend_and_config(self):
        if self._config.auto_config:
            # Should we use the strategy pattern here to allow extensible
            # strategies?
            model_paths = get_model_paths(self._model_dir, self._model_file_prefix)
            ret = suggest_inference_backend_and_config(
                self._config,
                model_paths,
            )
            if ret[0] is None:
                # Should I use a custom exception?
                raise RuntimeError(
                    f"No inference backend and configuration could be suggested. Reason: {ret[1]}"
                )
            backend, backend_config = ret
        else:
            backend = self._config.backend
            if backend is None:
                raise RuntimeError(
                    "When automatic configuration is not used, the inference backend must be specified manually."
                )
            backend_config = self._config.backend_config or {}

        if backend == "paddle":
            if not backend_config:
                is_default_config = True
            elif backend_config.keys() != {"run_mode"}:
                is_default_config = False
            else:
                is_default_config = backend_config["run_mode"] == get_default_run_mode(
                    self._config.pdx_model_name, self._config.device_type
                )
            if is_default_config:
                logging.warning(
                    "The Paddle Inference backend is selected with the default configuration. This may not provide optimal performance."
                )

        return backend, backend_config

    def _build_paddle_infer(self, backend_config):
        kwargs = {
            "device_type": self._config.device_type,
            "device_id": self._config.device_id,
            **backend_config,
        }
        # TODO: This is probably redundant. Can we reuse the code in the
        # predictor class?
        paddle_info = None
        if self._config.hpi_info:
            hpi_info = self._config.hpi_info
            if hpi_info.backend_configs:
                paddle_info = hpi_info.backend_configs.paddle_infer
        if paddle_info is not None:
            if (
                kwargs.get("trt_dynamic_shapes") is None
                and paddle_info.trt_dynamic_shapes is not None
            ):
                trt_dynamic_shapes = paddle_info.trt_dynamic_shapes
                logging.debug("TensorRT dynamic shapes set to %s", trt_dynamic_shapes)
                kwargs["trt_dynamic_shapes"] = trt_dynamic_shapes
            if (
                kwargs.get("trt_dynamic_shape_input_data") is None
                and paddle_info.trt_dynamic_shape_input_data is not None
            ):
                trt_dynamic_shape_input_data = paddle_info.trt_dynamic_shape_input_data
                logging.debug(
                    "TensorRT dynamic shape input data set to %s",
                    trt_dynamic_shape_input_data,
                )
                kwargs["trt_dynamic_shape_input_data"] = trt_dynamic_shape_input_data
        pp_option = PaddlePredictorOption(**kwargs)
        pp_option.setdefault_by_model_name(model_name=self._config.pdx_model_name)
        logging.info("Using Paddle Inference backend")
        logging.info("Paddle predictor option: %s", pp_option)
        return PaddleInfer(
            self._config.pdx_model_name,
            self._model_dir,
            self._model_file_prefix,
            option=pp_option,
        )
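
    # Builds an `ultra_infer.Runtime` for the selected backend; when an ONNX
    # model is required but absent, the PaddlePaddle model can be converted
    # on the fly via the `paddlex --paddle2onnx` CLI.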
    def _build_ui_runtime(self, backend, backend_config, ui_option=None):
        # TODO: Validate the compatibility of backends with device types
        from ultra_infer import ModelFormat, Runtime, RuntimeOption

        if ui_option is None:
            ui_option = RuntimeOption()
        if self._config.device_type == "cpu":
            pass
        elif self._config.device_type == "gpu":
            ui_option.use_gpu(self._config.device_id or 0)
        elif self._config.device_type == "npu":
            ui_option.use_ascend(self._config.device_id or 0)
        else:
            raise RuntimeError(
                f"Unsupported device type {repr(self._config.device_type)}"
            )

        model_paths = get_model_paths(self._model_dir, self.model_file_prefix)
        if backend in ("openvino", "onnxruntime", "tensorrt"):
            # XXX: This introduces side effects.
            if "onnx" not in model_paths:
                if self._config.auto_paddle2onnx:
                    if "paddle" not in model_paths:
                        raise RuntimeError("PaddlePaddle model required")
                    # The CLI is used here since there is currently no API.
                    logging.info(
                        "Automatically converting PaddlePaddle model to ONNX format"
                    )
                    try:
                        subprocess.run(
                            [
                                "paddlex",
                                "--paddle2onnx",
                                "--paddle_model_dir",
                                str(self._model_dir),
                                "--onnx_model_dir",
                                str(self._model_dir),
                            ],
                            capture_output=True,
                            check=True,
                            text=True,
                        )
                    except subprocess.CalledProcessError as e:
                        raise RuntimeError(
                            f"PaddlePaddle-to-ONNX conversion failed:\n{e.stderr}"
                        ) from e
                    model_paths = get_model_paths(
                        self._model_dir, self.model_file_prefix
                    )
                    assert "onnx" in model_paths
                else:
                    raise RuntimeError("ONNX model required")
            ui_option.set_model_path(str(model_paths["onnx"]), "", ModelFormat.ONNX)
        elif backend == "om":
            if "om" not in model_paths:
                raise RuntimeError("OM model required")
            ui_option.set_model_path(str(model_paths["om"]), "", ModelFormat.OM)
        else:
            raise ValueError(f"Unsupported inference backend {repr(backend)}")

        if backend == "openvino":
            backend_config = OpenVINOConfig.model_validate(backend_config)
            ui_option.use_openvino_backend()
            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
        elif backend == "onnxruntime":
            backend_config = ONNXRuntimeConfig.model_validate(backend_config)
            ui_option.use_ort_backend()
            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
        elif backend == "tensorrt":
            if (
                backend_config.get("use_dynamic_shapes", True)
                and backend_config.get("dynamic_shapes") is None
            ):
                trt_info = None
                if self._config.hpi_info:
                    hpi_info = self._config.hpi_info
                    if hpi_info.backend_configs:
                        trt_info = hpi_info.backend_configs.tensorrt
                if trt_info is not None and trt_info.dynamic_shapes is not None:
                    trt_dynamic_shapes = trt_info.dynamic_shapes
                    logging.debug(
                        "TensorRT dynamic shapes set to %s", trt_dynamic_shapes
                    )
                    backend_config = {
                        **backend_config,
                        "dynamic_shapes": trt_dynamic_shapes,
                    }
            backend_config = TensorRTConfig.model_validate(backend_config)
            ui_option.use_trt_backend()
            cache_dir = self._model_dir / CACHE_DIR / "tensorrt"
            cache_dir.mkdir(parents=True, exist_ok=True)
            ui_option.trt_option.serialize_file = str(cache_dir / "trt_serialized.trt")
            if backend_config.precision == "fp16":
                ui_option.trt_option.enable_fp16 = True
            if not backend_config.use_dynamic_shapes:
                raise RuntimeError(
                    "TensorRT static shape inference is currently not supported"
                )
            if backend_config.dynamic_shapes is not None:
                if not Path(ui_option.trt_option.serialize_file).exists():
                    for name, shapes in backend_config.dynamic_shapes.items():
                        ui_option.trt_option.set_shape(name, *shapes)
                else:
                    logging.info(
                        "TensorRT dynamic shapes will be loaded from the file."
                    )
        elif backend == "om":
            backend_config = OMConfig.model_validate(backend_config)
            ui_option.use_om_backend()
        else:
            raise ValueError(f"Unsupported inference backend {repr(backend)}")

        logging.info("Inference backend: %s", backend)
        logging.info("Inference backend config: %s", backend_config)

        ui_runtime = Runtime(ui_option)

        return ui_runtime