# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
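
"""Static-graph inference engines.

This module defines the `StaticInfer` interface and two implementations:
`PaddleInfer`, which runs models via Paddle Inference (optionally with
TensorRT), and `HPInfer`, which selects a high-performance backend (Paddle
Inference, OpenVINO, ONNX Runtime, TensorRT, or OM) on top of the
`ultra_infer` runtime.
"""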

import abc
import subprocess
from pathlib import Path
from typing import List, Sequence

import numpy as np

from ....utils import logging
from ....utils.deps import (
    class_requires_deps,
    function_requires_deps,
    is_paddle2onnx_plugin_available,
)
from ....utils.device import constr_device
from ....utils.flags import DEBUG, INFER_BENCHMARK_USE_NEW_INFER_API, USE_PIR_TRT
from ...utils.benchmark import benchmark, set_inference_operations
from ...utils.hpi import (
    HPIConfig,
    OMConfig,
    ONNXRuntimeConfig,
    OpenVINOConfig,
    TensorRTConfig,
    get_model_paths,
    suggest_inference_backend_and_config,
)
from ...utils.pp_option import PaddlePredictorOption
from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG

CACHE_DIR = ".cache"

INFERENCE_OPERATIONS = [
    "PaddleCopyToDevice",
    "PaddleCopyToHost",
    "PaddleModelInfer",
    "PaddleInferChainLegacy",
    "MultiBackendInfer",
]

set_inference_operations(INFERENCE_OPERATIONS)


# XXX: Better use Paddle Inference API to do this
@function_requires_deps("paddlepaddle")
def _pd_dtype_to_np_dtype(pd_dtype):
    import paddle

    if pd_dtype == paddle.inference.DataType.FLOAT64:
        return np.float64
    elif pd_dtype == paddle.inference.DataType.FLOAT32:
        return np.float32
    elif pd_dtype == paddle.inference.DataType.INT64:
        return np.int64
    elif pd_dtype == paddle.inference.DataType.INT32:
        return np.int32
    elif pd_dtype == paddle.inference.DataType.UINT8:
        return np.uint8
    elif pd_dtype == paddle.inference.DataType.INT8:
        return np.int8
    else:
        raise TypeError(f"Unsupported data type: {pd_dtype}")


# old trt
@function_requires_deps("paddlepaddle")
def _collect_trt_shape_range_info(
    model_file,
    model_params,
    gpu_id,
    shape_range_info_path,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    import paddle.inference

    dynamic_shape_input_data = dynamic_shape_input_data or {}

    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    config.collect_shape_range_info(shape_range_info_path)
    # TODO: Add other needed options
    config.disable_glog_info()
    predictor = paddle.inference.create_predictor(config)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )
    # It would be better to check if the shapes are valid.

    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arrs[name] = np.array(
                dynamic_shape_input_data[name][0], dtype=dtype
            ).reshape(min_shape)
            opt_arrs[name] = np.array(
                dynamic_shape_input_data[name][1], dtype=dtype
            ).reshape(opt_shape)
            max_arrs[name] = np.array(
                dynamic_shape_input_data[name][2], dtype=dtype
            ).reshape(max_shape)
        else:
            min_arrs[name] = np.ones(min_shape, dtype=dtype)
            opt_arrs[name] = np.ones(opt_shape, dtype=dtype)
            max_arrs[name] = np.ones(max_shape, dtype=dtype)

    # `opt_arrs` is used twice to ensure it is the most frequently used.
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            handle = predictor.get_input_handle(name)
            handle.reshape(arr.shape)
            handle.copy_from_cpu(arr)
        predictor.run()

    # HACK: The shape range info will be written to the file only when
    # `predictor` is garbage collected. It works in CPython, but it is
    # definitely a bad idea to count on the implementation-dependent behavior
    # of a garbage collector. Is there a more explicit and deterministic way
    # to handle this?
    # HACK: Manually delete the predictor to trigger its destructor, ensuring
    # that the shape range info file is saved.
    del predictor


# pir trt
@function_requires_deps("paddlepaddle")
def _convert_trt(
    trt_cfg_setting,
    pp_model_file,
    pp_params_file,
    trt_save_path,
    device_id,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    import paddle.inference
    from paddle.tensorrt.export import Input, TensorRTConfig, convert

    def _set_trt_config():
        for attr_name in trt_cfg_setting:
            assert hasattr(
                trt_config, attr_name
            ), f"`{type(trt_config)}` doesn't have the attribute `{attr_name}`!"
            setattr(trt_config, attr_name, trt_cfg_setting[attr_name])

    def _get_predictor(model_file, params_file):
        # HACK
        config = paddle.inference.Config(str(model_file), str(params_file))
        config.enable_use_gpu(100, device_id)
        # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
        config.disable_mkldnn()
        config.disable_glog_info()
        return paddle.inference.create_predictor(config)

    dynamic_shape_input_data = dynamic_shape_input_data or {}

    predictor = _get_predictor(pp_model_file, pp_params_file)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )

    trt_inputs = []
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
                min_shape
            )
            opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
                opt_shape
            )
            max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
                max_shape
            )
        else:
            min_arr = np.ones(min_shape, dtype=dtype)
            opt_arr = np.ones(opt_shape, dtype=dtype)
            max_arr = np.ones(max_shape, dtype=dtype)

        # refer to: https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
        trt_input = Input((min_arr, opt_arr, max_arr))
        trt_inputs.append(trt_input)

    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    _set_trt_config()
    trt_config.save_model_dir = str(trt_save_path)
    pp_model_path = str(pp_model_file.with_suffix(""))
    convert(pp_model_path, trt_config)


def _sort_inputs(inputs, names):
    # NOTE: Adjust input tensors to match the sorted sequence.
    indices = sorted(range(len(names)), key=names.__getitem__)
    inputs = [inputs[indices.index(i)] for i in range(len(inputs))]
    return inputs


def _concatenate(*callables):
    # NOTE: Despite the name, this composes the callables, applying them from
    # left to right.
    def _chain(x):
        for c in callables:
            x = c(x)
        return x

    return _chain


@benchmark.timeit
@class_requires_deps("paddlepaddle")
class PaddleCopyToDevice:
    def __init__(self, device_type, device_id):
        self.device_type = device_type
        self.device_id = device_id

    def __call__(self, arrs):
        import paddle

        device_id = [self.device_id] if self.device_id is not None else self.device_id
        device = constr_device(self.device_type, device_id)
        paddle_tensors = [paddle.to_tensor(i, place=device) for i in arrs]
        return paddle_tensors


@benchmark.timeit
@class_requires_deps("paddlepaddle")
class PaddleCopyToHost:
    def __call__(self, paddle_tensors):
        arrs = [i.numpy() for i in paddle_tensors]
        return arrs


@benchmark.timeit
@class_requires_deps("paddlepaddle")
class PaddleModelInfer:
    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self, x):
        return self.predictor.run(x)


# FIXME: Name might be misleading
@benchmark.timeit
@class_requires_deps("paddlepaddle")
class PaddleInferChainLegacy:
    def __init__(self, predictor):
        self.predictor = predictor
        input_names = self.predictor.get_input_names()
        self.input_handles = []
        self.output_handles = []
        for input_name in input_names:
            input_handle = self.predictor.get_input_handle(input_name)
            self.input_handles.append(input_handle)
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_handle = self.predictor.get_output_handle(output_name)
            self.output_handles.append(output_handle)

    def __call__(self, x):
        for input_, input_handle in zip(x, self.input_handles):
            input_handle.reshape(input_.shape)
            input_handle.copy_from_cpu(input_)
        self.predictor.run()
        outputs = [o.copy_to_cpu() for o in self.output_handles]
        return outputs


class StaticInfer(metaclass=abc.ABCMeta):
    """Abstract base class for static-graph inference engines."""

    @abc.abstractmethod
    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        raise NotImplementedError


@class_requires_deps("paddlepaddle")
class PaddleInfer(StaticInfer):
    def __init__(
        self,
        model_dir: str,
        model_file_prefix: str,
        option: PaddlePredictorOption,
    ) -> None:
        super().__init__()
        self.model_dir = model_dir
        self.model_file_prefix = model_file_prefix
        self._option = option
        self.predictor = self._create()
        if INFER_BENCHMARK_USE_NEW_INFER_API:
            device_type = self._option.device_type
            device_type = "gpu" if device_type == "dcu" else device_type
            copy_to_device = PaddleCopyToDevice(device_type, self._option.device_id)
            copy_to_host = PaddleCopyToHost()
            model_infer = PaddleModelInfer(self.predictor)
            self.infer = _concatenate(copy_to_device, model_infer, copy_to_host)
        else:
            self.infer = PaddleInferChainLegacy(self.predictor)

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        names = self.predictor.get_input_names()
        if len(names) != len(x):
            raise ValueError(
                f"The number of inputs does not match the model: {len(names)} vs {len(x)}"
            )
        # TODO: Ensure that input tensors follow the model's input sequence
        # without sorting.
        x = _sort_inputs(x, names)
        x = list(map(np.ascontiguousarray, x))
        pred = self.infer(x)
        return pred

    def _create(self):
        """Create and configure the underlying Paddle Inference predictor."""
        import paddle
        import paddle.inference

        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
        if "paddle" not in model_paths:
            raise RuntimeError("No valid PaddlePaddle model found")
        model_file, params_file = model_paths["paddle"]

        if (
            self._option.model_name == "LaTeX_OCR_rec"
            and self._option.device_type == "cpu"
        ):
            import cpuinfo

            if (
                "GenuineIntel" in cpuinfo.get_cpu_info().get("vendor_id_raw", "")
                and self._option.run_mode != "mkldnn"
            ):
                logging.warning(
                    "Currently, the `LaTeX_OCR_rec` model only supports `mkldnn` mode when running on Intel CPU devices, so `mkldnn` will be used instead."
                )
                self._option.run_mode = "mkldnn"
                logging.debug("`run_mode` updated to 'mkldnn'")

        if self._option.device_type == "cpu" and self._option.device_id is not None:
            self._option.device_id = None
            logging.debug("`device_id` has been set to None")
        if (
            self._option.device_type in ("gpu", "dcu")
            and self._option.device_id is None
        ):
            self._option.device_id = 0
            logging.debug("`device_id` has been set to 0")

        # for TRT
        if self._option.run_mode.startswith("trt"):
            assert self._option.device_type == "gpu"
            cache_dir = self.model_dir / CACHE_DIR / "paddle"
            config = self._configure_trt(
                model_file,
                params_file,
                cache_dir,
            )
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self._option.device_id)
        # for Native Paddle and MKLDNN
        else:
            config = paddle.inference.Config(str(model_file), str(params_file))
            if self._option.device_type == "gpu":
                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
                from paddle.inference import PrecisionType

                precision = (
                    PrecisionType.Half
                    if self._option.run_mode == "paddle_fp16"
                    else PrecisionType.Float32
                )
                config.enable_use_gpu(100, self._option.device_id, precision)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)
            elif self._option.device_type == "npu":
                config.enable_custom_device("npu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "xpu":
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "mlu":
                config.enable_custom_device("mlu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "gcu":
                from paddle_custom_device.gcu import passes as gcu_passes

                gcu_passes.setUp()
                config.enable_custom_device("gcu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_ir()
                    config.enable_new_executor()
                else:
                    pass_builder = config.pass_builder()
                    name = "PaddleX_" + self._option.model_name
                    gcu_passes.append_passes_for_legacy_ir(pass_builder, name)
            elif self._option.device_type == "dcu":
                config.enable_use_gpu(100, self._option.device_id)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                # XXX: is_compiled_with_rocm() must be True on dcu platform ?
                if paddle.is_compiled_with_rocm():
                    # Delete unsupported passes in dcu
                    config.delete_pass("conv2d_add_act_fuse_pass")
                    config.delete_pass("conv2d_add_fuse_pass")
            else:
                assert self._option.device_type == "cpu"
                config.disable_gpu()
                if "mkldnn" in self._option.run_mode:
                    try:
                        config.enable_mkldnn()
                        if "bf16" in self._option.run_mode:
                            config.enable_mkldnn_bfloat16()
                    except Exception:
                        logging.warning(
                            "MKL-DNN is not available. We will disable MKL-DNN."
                        )
                    config.set_mkldnn_cache_capacity(-1)
                else:
                    if hasattr(config, "disable_mkldnn"):
                        config.disable_mkldnn()
                config.set_cpu_math_library_num_threads(self._option.cpu_threads)

                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self._option.delete_pass:
            config.delete_pass(del_p)

        # Disable paddle inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = paddle.inference.create_predictor(config)

        return predictor

    def _configure_trt(self, model_file, params_file, cache_dir):
        # TODO: Support calibration
        import paddle.inference

        if USE_PIR_TRT:
            trt_save_path = cache_dir / "trt" / self.model_file_prefix
            _convert_trt(
                self._option.trt_cfg_setting,
                model_file,
                params_file,
                trt_save_path,
                self._option.device_id,
                self._option.trt_dynamic_shapes,
                self._option.trt_dynamic_shape_input_data,
            )
            model_file = trt_save_path.with_suffix(".json")
            params_file = trt_save_path.with_suffix(".pdiparams")
            config = paddle.inference.Config(str(model_file), str(params_file))
        else:
            config = paddle.inference.Config(str(model_file), str(params_file))
            config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
            # call enable_use_gpu() first to use TensorRT engine
            config.enable_use_gpu(100, self._option.device_id)
            for func_name in self._option.trt_cfg_setting:
                assert hasattr(
                    config, func_name
                ), f"`{type(config)}` doesn't have the function `{func_name}`!"
                args = self._option.trt_cfg_setting[func_name]
                if isinstance(args, list):
                    getattr(config, func_name)(*args)
                else:
                    getattr(config, func_name)(**args)
            if self._option.trt_use_dynamic_shapes:
                if self._option.trt_collect_shape_range_info:
                    # NOTE: We always use a shape range info file.
                    if self._option.trt_shape_range_info_path is not None:
                        trt_shape_range_info_path = Path(
                            self._option.trt_shape_range_info_path
                        )
                    else:
                        trt_shape_range_info_path = cache_dir / "shape_range_info.pbtxt"
                    should_collect_shape_range_info = True
                    if not trt_shape_range_info_path.exists():
                        trt_shape_range_info_path.parent.mkdir(
                            parents=True, exist_ok=True
                        )
                        logging.info(
                            f"Shape range info will be collected into {trt_shape_range_info_path}"
                        )
                    elif self._option.trt_discard_cached_shape_range_info:
                        trt_shape_range_info_path.unlink()
                        logging.info(
                            f"The shape range info file ({trt_shape_range_info_path}) has been removed, and the shape range info will be re-collected."
                        )
                    else:
                        logging.info(
                            f"A shape range info file ({trt_shape_range_info_path}) already exists. There is no need to collect the info again."
                        )
                        should_collect_shape_range_info = False
                    if should_collect_shape_range_info:
                        _collect_trt_shape_range_info(
                            str(model_file),
                            str(params_file),
                            self._option.device_id,
                            str(trt_shape_range_info_path),
                            self._option.trt_dynamic_shapes,
                            self._option.trt_dynamic_shape_input_data,
                        )
                    if (
                        self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG
                        and self._option.run_mode == "trt_fp16"
                    ):
                        paddle.inference.InternalUtils.disable_tensorrt_half_ops(
                            config, DISABLE_TRT_HALF_OPS_CONFIG[self._option.model_name]
                        )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        str(trt_shape_range_info_path),
                        self._option.trt_allow_rebuild_at_runtime,
                    )
                else:
                    if self._option.trt_dynamic_shapes is not None:
                        min_shapes, opt_shapes, max_shapes = {}, {}, {}
                        for (
                            key,
                            shapes,
                        ) in self._option.trt_dynamic_shapes.items():
                            min_shapes[key] = shapes[0]
                            opt_shapes[key] = shapes[1]
                            max_shapes[key] = shapes[2]
                        config.set_trt_dynamic_shape_info(
                            min_shapes, max_shapes, opt_shapes
                        )
                    else:
                        raise RuntimeError("No dynamic shape information provided")

        return config


# FIXME: Name might be misleading
@benchmark.timeit
@class_requires_deps("ultra-infer")
class MultiBackendInfer(object):
    def __init__(self, ui_runtime):
        super().__init__()
        self.ui_runtime = ui_runtime

    # The time consumed by the wrapper code will also be taken into account.
    def __call__(self, x):
        outputs = self.ui_runtime.infer(x)
        return outputs


# TODO: It would be better to refactor the code to make `HPInfer` a higher-level
# class that uses `PaddleInfer`.
@class_requires_deps("ultra-infer", "paddlepaddle")
class HPInfer(StaticInfer):
    def __init__(
        self,
        model_dir: str,
        model_file_prefix: str,
        config: HPIConfig,
    ) -> None:
        super().__init__()
        self._model_dir = model_dir
        self._model_file_prefix = model_file_prefix
        self._config = config
        backend, backend_config = self._determine_backend_and_config()
        if backend == "paddle":
            self._use_paddle = True
            self._paddle_infer = self._build_paddle_infer(backend_config)
        else:
            self._use_paddle = False
            ui_runtime = self._build_ui_runtime(backend, backend_config)
            self._multi_backend_infer = MultiBackendInfer(ui_runtime)
            num_inputs = ui_runtime.num_inputs()
            self._input_names = [
                ui_runtime.get_input_info(i).name for i in range(num_inputs)
            ]

    @property
    def model_dir(self) -> str:
        return self._model_dir

    @property
    def model_file_prefix(self) -> str:
        return self._model_file_prefix

    @property
    def config(self) -> HPIConfig:
        return self._config

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        if self._use_paddle:
            return self._call_paddle_infer(x)
        else:
            return self._call_multi_backend_infer(x)

    def _call_paddle_infer(self, x):
        return self._paddle_infer(x)

    def _call_multi_backend_infer(self, x):
        num_inputs = len(self._input_names)
        if len(x) != num_inputs:
            raise ValueError(f"Expected {num_inputs} inputs but got {len(x)} instead")
        x = _sort_inputs(x, self._input_names)
        inputs = {}
        for name, input_ in zip(self._input_names, x):
            inputs[name] = np.ascontiguousarray(input_)
        return self._multi_backend_infer(inputs)

    def _determine_backend_and_config(self):
        from ultra_infer import (
            is_built_with_om,
            is_built_with_openvino,
            is_built_with_ort,
            is_built_with_trt,
        )

        model_paths = get_model_paths(self._model_dir, self._model_file_prefix)
        is_onnx_model_available = "onnx" in model_paths
        # TODO: Give a warning if the Paddle2ONNX plugin is not available but
        # can be used to select a better backend.
        if self._config.auto_paddle2onnx:
            if is_paddle2onnx_plugin_available():
                is_onnx_model_available = (
                    is_onnx_model_available or "paddle" in model_paths
                )
            else:
                logging.debug(
                    "The Paddle2ONNX plugin is not available. Automatic model conversion will not be performed."
                )

        available_backends = []
        if "paddle" in model_paths:
            available_backends.append("paddle")
        if is_built_with_openvino() and is_onnx_model_available:
            available_backends.append("openvino")
        if is_built_with_ort() and is_onnx_model_available:
            available_backends.append("onnxruntime")
        if is_built_with_trt() and is_onnx_model_available:
            available_backends.append("tensorrt")
        if is_built_with_om() and "om" in model_paths:
            available_backends.append("om")

        if not available_backends:
            raise RuntimeError("No inference backend is available")

        if (
            self._config.backend is not None
            and self._config.backend not in available_backends
        ):
            raise RuntimeError(
                f"Inference backend {repr(self._config.backend)} is unavailable"
            )

        if self._config.auto_config:
            # Should we use the strategy pattern here to allow extensible
            # strategies?
            ret = suggest_inference_backend_and_config(
                self._config, available_backends=available_backends
            )
            if ret[0] is None:
                # Should I use a custom exception?
                raise RuntimeError(
                    f"No inference backend and configuration could be suggested. Reason: {ret[1]}"
                )
            backend, backend_config = ret
        else:
            backend = self._config.backend
            if backend is None:
                raise RuntimeError(
                    "When automatic configuration is not used, the inference backend must be specified manually."
                )
            backend_config = self._config.backend_config or {}

        if backend == "paddle" and not backend_config:
            logging.warning(
                "The Paddle Inference backend is selected with the default configuration. This may not provide optimal performance."
            )

        return backend, backend_config

    def _build_paddle_infer(self, backend_config):
        kwargs = {
            "device_type": self._config.device_type,
            "device_id": self._config.device_id,
            **backend_config,
        }
        # TODO: This is probably redundant. Can we reuse the code in the
        # predictor class?
        paddle_info = self._config.hpi_info.backend_configs.paddle_infer
        if paddle_info is not None:
            if (
                kwargs.get("trt_dynamic_shapes") is None
                and paddle_info.trt_dynamic_shapes is not None
            ):
                trt_dynamic_shapes = paddle_info.trt_dynamic_shapes
                logging.debug("TensorRT dynamic shapes set to %s", trt_dynamic_shapes)
                kwargs["trt_dynamic_shapes"] = trt_dynamic_shapes
            if (
                kwargs.get("trt_dynamic_shape_input_data") is None
                and paddle_info.trt_dynamic_shape_input_data is not None
            ):
                trt_dynamic_shape_input_data = paddle_info.trt_dynamic_shape_input_data
                logging.debug(
                    "TensorRT dynamic shape input data set to %s",
                    trt_dynamic_shape_input_data,
                )
                kwargs["trt_dynamic_shape_input_data"] = trt_dynamic_shape_input_data

        pp_option = PaddlePredictorOption(self._config.pdx_model_name, **kwargs)
        logging.info("Using Paddle Inference backend")
        logging.info("Paddle predictor option: %s", pp_option)
        return PaddleInfer(self._model_dir, self._model_file_prefix, option=pp_option)

    def _build_ui_runtime(self, backend, backend_config, ui_option=None):
        from ultra_infer import ModelFormat, Runtime, RuntimeOption

        if ui_option is None:
            ui_option = RuntimeOption()

        if self._config.device_type == "cpu":
            pass
        elif self._config.device_type == "gpu":
            ui_option.use_gpu(self._config.device_id or 0)
        elif self._config.device_type == "npu":
            ui_option.use_ascend(self._config.device_id or 0)
        else:
            raise RuntimeError(
                f"Unsupported device type {repr(self._config.device_type)}"
            )

        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)

        if backend in ("openvino", "onnxruntime", "tensorrt"):
            # XXX: This introduces side effects.
            if "onnx" not in model_paths:
                if self._config.auto_paddle2onnx:
                    if "paddle" not in model_paths:
                        raise RuntimeError("PaddlePaddle model required")
                    # The CLI is used here since there is currently no API.
                    logging.info(
                        "Automatically converting PaddlePaddle model to ONNX format"
                    )
                    subprocess.check_call(
                        [
                            "paddlex",
                            "--paddle2onnx",
                            "--paddle_model_dir",
                            self._model_dir,
                            "--onnx_model_dir",
                            self._model_dir,
                        ]
                    )
                    model_paths = get_model_paths(
                        self.model_dir, self.model_file_prefix
                    )
                    assert "onnx" in model_paths
                else:
                    raise RuntimeError("ONNX model required")
            ui_option.set_model_path(str(model_paths["onnx"]), "", ModelFormat.ONNX)
        elif backend == "om":
            if "om" not in model_paths:
                raise RuntimeError("OM model required")
            ui_option.set_model_path(str(model_paths["om"]), "", ModelFormat.OM)
        else:
            raise ValueError(f"Unsupported inference backend {repr(backend)}")

        if backend == "openvino":
            backend_config = OpenVINOConfig.model_validate(backend_config)
            ui_option.use_openvino_backend()
            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
        elif backend == "onnxruntime":
            backend_config = ONNXRuntimeConfig.model_validate(backend_config)
            ui_option.use_ort_backend()
            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
        elif backend == "tensorrt":
            if (
                backend_config.get("use_dynamic_shapes", True)
                and backend_config.get("dynamic_shapes") is None
            ):
                trt_info = self._config.hpi_info.backend_configs.tensorrt
                if trt_info is not None and trt_info.dynamic_shapes is not None:
                    trt_dynamic_shapes = trt_info.dynamic_shapes
                    logging.debug(
                        "TensorRT dynamic shapes set to %s", trt_dynamic_shapes
                    )
                    backend_config = {
                        **backend_config,
                        "dynamic_shapes": trt_dynamic_shapes,
                    }
            backend_config = TensorRTConfig.model_validate(backend_config)
            ui_option.use_trt_backend()
            cache_dir = self.model_dir / CACHE_DIR / "tensorrt"
            cache_dir.mkdir(parents=True, exist_ok=True)
            ui_option.trt_option.serialize_file = str(cache_dir / "trt_serialized.trt")
            if backend_config.precision == "FP16":
                ui_option.trt_option.enable_fp16 = True
            if not backend_config.use_dynamic_shapes:
                raise RuntimeError(
                    "TensorRT static shape inference is currently not supported"
                )
            if backend_config.dynamic_shapes is not None:
                if not Path(ui_option.trt_option.serialize_file).exists():
                    for name, shapes in backend_config.dynamic_shapes.items():
                        ui_option.trt_option.set_shape(name, *shapes)
                else:
                    logging.warning(
                        "TensorRT dynamic shapes will be loaded from the file."
                    )
        elif backend == "om":
            backend_config = OMConfig.model_validate(backend_config)
            ui_option.use_om_backend()
        else:
            raise ValueError(f"Unsupported inference backend {repr(backend)}")

        logging.info("Inference backend: %s", backend)
        logging.info("Inference backend config: %s", backend_config)

        ui_runtime = Runtime(ui_option)

        return ui_runtime
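

# Usage sketch (illustrative only; names such as `model_name`, `model_dir`,
# `model_file_prefix`, and the input shape are placeholders, not values
# defined in this module):
#
#     option = PaddlePredictorOption(model_name, device_type="gpu", device_id=0)
#     infer = PaddleInfer(model_dir, model_file_prefix, option=option)
#     outputs = infer([np.zeros((1, 3, 224, 224), dtype="float32")])
#
# Both `PaddleInfer` and `HPInfer` are callables mapping a sequence of NumPy
# arrays (one per model input) to a list of NumPy arrays (one per output).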