static_infer.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import importlib.util
import subprocess
from typing import Sequence, List
from pathlib import Path

import lazy_paddle as paddle
import numpy as np

from ....utils import logging
from ....utils.device import constr_device
from ....utils.flags import (
    DEBUG,
    USE_PIR_TRT,
    INFER_BENCHMARK_USE_NEW_INFER_API,
)
from ...utils.benchmark import benchmark, set_inference_operations
from ...utils.hpi import (
    HPIConfig,
    ONNXRuntimeConfig,
    OpenVINOConfig,
    TensorRTConfig,
    OMConfig,
    get_model_paths,
    suggest_inference_backend_and_config,
)
from ...utils.pp_option import PaddlePredictorOption
from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG

CACHE_DIR = ".cache"

INFERENCE_OPERATIONS = [
    "PaddleCopyToDevice",
    "PaddleCopyToHost",
    "PaddleModelInfer",
    "PaddleInferChainLegacy",
    "MultiBackendInfer",
]
set_inference_operations(INFERENCE_OPERATIONS)


# XXX: Better to use the Paddle Inference API to do this
def _pd_dtype_to_np_dtype(pd_dtype):
    if pd_dtype == paddle.inference.DataType.FLOAT64:
        return np.float64
    elif pd_dtype == paddle.inference.DataType.FLOAT32:
        return np.float32
    elif pd_dtype == paddle.inference.DataType.INT64:
        return np.int64
    elif pd_dtype == paddle.inference.DataType.INT32:
        return np.int32
    elif pd_dtype == paddle.inference.DataType.UINT8:
        return np.uint8
    elif pd_dtype == paddle.inference.DataType.INT8:
        return np.int8
    else:
        raise TypeError(f"Unsupported data type: {pd_dtype}")


# old TRT
def _collect_trt_shape_range_info(
    model_file,
    model_params,
    gpu_id,
    shape_range_info_path,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    dynamic_shape_input_data = dynamic_shape_input_data or {}

    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    config.collect_shape_range_info(shape_range_info_path)
    # TODO: Add other needed options
    config.disable_glog_info()
    predictor = paddle.inference.create_predictor(config)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )
    # It would be better to check if the shapes are valid.

    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arrs[name] = np.array(
                dynamic_shape_input_data[name][0], dtype=dtype
            ).reshape(min_shape)
            opt_arrs[name] = np.array(
                dynamic_shape_input_data[name][1], dtype=dtype
            ).reshape(opt_shape)
            max_arrs[name] = np.array(
                dynamic_shape_input_data[name][2], dtype=dtype
            ).reshape(max_shape)
        else:
            min_arrs[name] = np.ones(min_shape, dtype=dtype)
            opt_arrs[name] = np.ones(opt_shape, dtype=dtype)
            max_arrs[name] = np.ones(max_shape, dtype=dtype)

    # `opt_arrs` is used twice to ensure that it is the most frequently used.
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            handle = predictor.get_input_handle(name)
            handle.reshape(arr.shape)
            handle.copy_from_cpu(arr)
        predictor.run()

    # HACK: The shape range info will be written to the file only when
    # `predictor` is garbage collected. It works in CPython, but it is
    # definitely a bad idea to count on the implementation-dependent behavior of
    # a garbage collector. Is there a more explicit and deterministic way to
    # handle this?
    # HACK: Manually delete the predictor to trigger its destructor, ensuring
    # that the shape range info file is saved.
    del predictor
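
# Illustrative note (added commentary, not original code): both TRT helpers in this
# module expect `dynamic_shapes` to map each input name to
# [min_shape, opt_shape, max_shape], and the optional `dynamic_shape_input_data` to
# map the same names to three flat value lists. A minimal sketch with a hypothetical
# input named "x":
#
#     dynamic_shapes = {"x": [[1, 3, 224, 224], [1, 3, 224, 224], [8, 3, 224, 224]]}
#     dynamic_shape_input_data = None  # np.ones() is then used for each shape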


# PIR TRT
def _convert_trt(
    trt_cfg_setting,
    pp_model_file,
    pp_params_file,
    trt_save_path,
    device_id,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    from paddle.tensorrt.export import (
        Input,
        TensorRTConfig,
        convert,
    )

    def _set_trt_config():
        for attr_name in trt_cfg_setting:
            assert hasattr(
                trt_config, attr_name
            ), f"The `{type(trt_config)}` does not have the attribute `{attr_name}`!"
            setattr(trt_config, attr_name, trt_cfg_setting[attr_name])

    def _get_predictor(model_file, params_file):
        # HACK
        config = paddle.inference.Config(str(model_file), str(params_file))
        config.enable_use_gpu(100, device_id)
        # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
        config.disable_mkldnn()
        config.disable_glog_info()
        return paddle.inference.create_predictor(config)

    dynamic_shape_input_data = dynamic_shape_input_data or {}

    predictor = _get_predictor(pp_model_file, pp_params_file)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )

    trt_inputs = []
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
                min_shape
            )
            opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
                opt_shape
            )
            max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
                max_shape
            )
        else:
            min_arr = np.ones(min_shape, dtype=dtype)
            opt_arr = np.ones(opt_shape, dtype=dtype)
            max_arr = np.ones(max_shape, dtype=dtype)

        # refer to: https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
        trt_input = Input((min_arr, opt_arr, max_arr))
        trt_inputs.append(trt_input)

    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    _set_trt_config()
    trt_config.save_model_dir = str(trt_save_path)
    pp_model_path = str(pp_model_file.with_suffix(""))
    convert(pp_model_path, trt_config)


def _sort_inputs(inputs, names):
    # NOTE: Adjust input tensors to match the sorted sequence.
    indices = sorted(range(len(names)), key=names.__getitem__)
    inputs = [inputs[indices.index(i)] for i in range(len(inputs))]
    return inputs
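
# Illustrative example (added commentary, not original code): the callers pass
# `inputs` ordered by lexicographically sorted input names, and `_sort_inputs`
# reorders them to match `names`. E.g. with names == ["y", "x"] and
# inputs == [arr_x, arr_y] (name-sorted), the result is [arr_y, arr_x],
# aligned with `names`.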


def _concatenate(*callables):
    def _chain(x):
        for c in callables:
            x = c(x)
        return x

    return _chain
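
# Minimal illustration (added commentary, not original code): _concatenate(f, g)(x)
# is equivalent to g(f(x)). Below, it chains copy-to-device, model inference, and
# copy-to-host into a single callable.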


@benchmark.timeit
class PaddleCopyToDevice:
    def __init__(self, device_type, device_id):
        self.device_type = device_type
        self.device_id = device_id

    def __call__(self, arrs):
        device_id = [self.device_id] if self.device_id is not None else self.device_id
        device = constr_device(self.device_type, device_id)
        paddle_tensors = [paddle.to_tensor(i, place=device) for i in arrs]
        return paddle_tensors


@benchmark.timeit
class PaddleCopyToHost:
    def __call__(self, paddle_tensors):
        arrs = [i.numpy() for i in paddle_tensors]
        return arrs


@benchmark.timeit
class PaddleModelInfer:
    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self, x):
        return self.predictor.run(x)


# FIXME: Name might be misleading
@benchmark.timeit
class PaddleInferChainLegacy:
    def __init__(self, predictor):
        self.predictor = predictor
        input_names = self.predictor.get_input_names()
        self.input_handles = []
        self.output_handles = []
        for input_name in input_names:
            input_handle = self.predictor.get_input_handle(input_name)
            self.input_handles.append(input_handle)
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_handle = self.predictor.get_output_handle(output_name)
            self.output_handles.append(output_handle)

    def __call__(self, x):
        for input_, input_handle in zip(x, self.input_handles):
            input_handle.reshape(input_.shape)
            input_handle.copy_from_cpu(input_)
        self.predictor.run()
        outputs = [o.copy_to_cpu() for o in self.output_handles]
        return outputs


class StaticInfer(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        raise NotImplementedError
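
# Added note (editorial, not original code): both implementations below follow this
# contract. They take one NumPy array per model input and return a list of NumPy
# output arrays; internally they reorder the inputs via `_sort_inputs` to match the
# underlying runtime's input names.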


class PaddleInfer(StaticInfer):
    def __init__(
        self,
        model_dir: str,
        model_file_prefix: str,
        option: PaddlePredictorOption,
    ) -> None:
        super().__init__()
        self.model_dir = model_dir
        self.model_file_prefix = model_file_prefix
        self._option = option
        self.predictor = self._create()
        if INFER_BENCHMARK_USE_NEW_INFER_API:
            device_type = self._option.device_type
            device_type = "gpu" if device_type == "dcu" else device_type
            copy_to_device = PaddleCopyToDevice(device_type, self._option.device_id)
            copy_to_host = PaddleCopyToHost()
            model_infer = PaddleModelInfer(self.predictor)
            self.infer = _concatenate(copy_to_device, model_infer, copy_to_host)
        else:
            self.infer = PaddleInferChainLegacy(self.predictor)

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        names = self.predictor.get_input_names()
        if len(names) != len(x):
            raise ValueError(
                f"The number of inputs does not match the model: {len(names)} vs {len(x)}"
            )
        # TODO:
        # Ensure that input tensors follow the model's input sequence without sorting.
        x = _sort_inputs(x, names)
        x = list(map(np.ascontiguousarray, x))
        pred = self.infer(x)
        return pred

    def _create(
        self,
    ):
        """Create the Paddle Inference predictor according to the predictor option."""
        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
        if "paddle" not in model_paths:
            raise RuntimeError("No valid PaddlePaddle model found")
        model_file, params_file = model_paths["paddle"]

        if (
            self._option.model_name == "LaTeX_OCR_rec"
            and self._option.device_type == "cpu"
        ):
            import cpuinfo

            if (
                "GenuineIntel" in cpuinfo.get_cpu_info().get("vendor_id_raw", "")
                and self._option.run_mode != "mkldnn"
            ):
                logging.warning(
                    "Currently, the `LaTeX_OCR_rec` model only supports `mkldnn` mode "
                    "when running on Intel CPUs, so `mkldnn` will be used instead."
                )
                self._option.run_mode = "mkldnn"
                logging.debug("`run_mode` updated to 'mkldnn'")

        if self._option.device_type == "cpu" and self._option.device_id is not None:
            self._option.device_id = None
            logging.debug("`device_id` has been set to None")
        if (
            self._option.device_type in ("gpu", "dcu")
            and self._option.device_id is None
        ):
            self._option.device_id = 0
            logging.debug("`device_id` has been set to 0")

        # for TRT
        if self._option.run_mode.startswith("trt"):
            assert self._option.device_type == "gpu"
            cache_dir = self.model_dir / CACHE_DIR / "paddle"
            config = self._configure_trt(
                model_file,
                params_file,
                cache_dir,
            )
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self._option.device_id)
        # for native Paddle and MKL-DNN
        else:
            config = paddle.inference.Config(str(model_file), str(params_file))
            if self._option.device_type == "gpu":
                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
                from paddle.inference import PrecisionType

                precision = (
                    PrecisionType.Half
                    if self._option.run_mode == "paddle_fp16"
                    else PrecisionType.Float32
                )
                config.enable_use_gpu(100, self._option.device_id, precision)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)
            elif self._option.device_type == "npu":
                config.enable_custom_device("npu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "xpu":
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "mlu":
                config.enable_custom_device("mlu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "gcu":
                from paddle_custom_device.gcu import passes as gcu_passes

                gcu_passes.setUp()
                config.enable_custom_device("gcu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_ir()
                    config.enable_new_executor()
                else:
                    pass_builder = config.pass_builder()
                    name = "PaddleX_" + self._option.model_name
                    gcu_passes.append_passes_for_legacy_ir(pass_builder, name)
            elif self._option.device_type == "dcu":
                config.enable_use_gpu(100, self._option.device_id)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                # XXX: Must is_compiled_with_rocm() be True on the DCU platform?
                if paddle.is_compiled_with_rocm():
                    # Delete passes unsupported on DCU
                    config.delete_pass("conv2d_add_act_fuse_pass")
                    config.delete_pass("conv2d_add_fuse_pass")
            else:
                assert self._option.device_type == "cpu"
                config.disable_gpu()
                if "mkldnn" in self._option.run_mode:
                    try:
                        config.enable_mkldnn()
                        if "bf16" in self._option.run_mode:
                            config.enable_mkldnn_bfloat16()
                    except Exception:
                        logging.warning(
                            "MKL-DNN is not available. MKL-DNN will be disabled."
                        )
                    config.set_mkldnn_cache_capacity(-1)
                else:
                    if hasattr(config, "disable_mkldnn"):
                        config.disable_mkldnn()
                config.set_cpu_math_library_num_threads(self._option.cpu_threads)

                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self._option.delete_pass:
            config.delete_pass(del_p)

        # Disable Paddle Inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = paddle.inference.create_predictor(config)

        return predictor

    def _configure_trt(self, model_file, params_file, cache_dir):
        # TODO: Support calibration
        if USE_PIR_TRT:
            trt_save_path = cache_dir / "trt" / self.model_file_prefix
            _convert_trt(
                self._option.trt_cfg_setting,
                model_file,
                params_file,
                trt_save_path,
                self._option.device_id,
                self._option.trt_dynamic_shapes,
                self._option.trt_dynamic_shape_input_data,
            )
            model_file = trt_save_path.with_suffix(".json")
            params_file = trt_save_path.with_suffix(".pdiparams")
            config = paddle.inference.Config(str(model_file), str(params_file))
        else:
            config = paddle.inference.Config(str(model_file), str(params_file))
            config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
            # Call `enable_use_gpu()` first so that the TensorRT engine can be used.
            config.enable_use_gpu(100, self._option.device_id)
            for func_name in self._option.trt_cfg_setting:
                assert hasattr(
                    config, func_name
                ), f"The `{type(config)}` does not have the function `{func_name}`!"
                args = self._option.trt_cfg_setting[func_name]
                if isinstance(args, list):
                    getattr(config, func_name)(*args)
                else:
                    getattr(config, func_name)(**args)
            if self._option.trt_use_dynamic_shapes:
                if self._option.trt_collect_shape_range_info:
                    # NOTE: We always use a shape range info file.
                    if self._option.trt_shape_range_info_path is not None:
                        trt_shape_range_info_path = Path(
                            self._option.trt_shape_range_info_path
                        )
                    else:
                        trt_shape_range_info_path = cache_dir / "shape_range_info.pbtxt"
                    should_collect_shape_range_info = True
                    if not trt_shape_range_info_path.exists():
                        trt_shape_range_info_path.parent.mkdir(
                            parents=True, exist_ok=True
                        )
                        logging.info(
                            f"Shape range info will be collected into {trt_shape_range_info_path}"
                        )
                    elif self._option.trt_discard_cached_shape_range_info:
                        trt_shape_range_info_path.unlink()
                        logging.info(
                            f"The shape range info file ({trt_shape_range_info_path}) has been removed, and the shape range info will be re-collected."
                        )
                    else:
                        logging.info(
                            f"A shape range info file ({trt_shape_range_info_path}) already exists. There is no need to collect the info again."
                        )
                        should_collect_shape_range_info = False
                    if should_collect_shape_range_info:
                        _collect_trt_shape_range_info(
                            str(model_file),
                            str(params_file),
                            self._option.device_id,
                            str(trt_shape_range_info_path),
                            self._option.trt_dynamic_shapes,
                            self._option.trt_dynamic_shape_input_data,
                        )
                    if (
                        self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG
                        and self._option.run_mode == "trt_fp16"
                    ):
                        paddle.inference.InternalUtils.disable_tensorrt_half_ops(
                            config, DISABLE_TRT_HALF_OPS_CONFIG[self._option.model_name]
                        )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        str(trt_shape_range_info_path),
                        self._option.trt_allow_rebuild_at_runtime,
                    )
                else:
                    if self._option.trt_dynamic_shapes is not None:
                        min_shapes, opt_shapes, max_shapes = {}, {}, {}
                        for (
                            key,
                            shapes,
                        ) in self._option.trt_dynamic_shapes.items():
                            min_shapes[key] = shapes[0]
                            opt_shapes[key] = shapes[1]
                            max_shapes[key] = shapes[2]
                        config.set_trt_dynamic_shape_info(
                            min_shapes, max_shapes, opt_shapes
                        )
                    else:
                        raise RuntimeError("No dynamic shape information provided")

        return config


# FIXME: Name might be misleading
@benchmark.timeit
class MultiBackendInfer(object):
    def __init__(self, ui_runtime):
        super().__init__()
        self.ui_runtime = ui_runtime

    # The time consumed by the wrapper code will also be taken into account.
    def __call__(self, x):
        outputs = self.ui_runtime.infer(x)
        return outputs


# TODO: It would be better to refactor the code to make `HPInfer` a higher-level
# class that uses `PaddleInfer`.
class HPInfer(StaticInfer):
    def __init__(
        self,
        model_dir: str,
        model_file_prefix: str,
        config: HPIConfig,
    ) -> None:
        super().__init__()
        self._model_dir = model_dir
        self._model_file_prefix = model_file_prefix
        self._config = config
        backend, backend_config = self._determine_backend_and_config()
        if backend == "paddle":
            self._use_paddle = True
            self._paddle_infer = self._build_paddle_infer(backend_config)
        else:
            self._use_paddle = False
            ui_runtime = self._build_ui_runtime(backend, backend_config)
            self._multi_backend_infer = MultiBackendInfer(ui_runtime)
            num_inputs = ui_runtime.num_inputs()
            self._input_names = [
                ui_runtime.get_input_info(i).name for i in range(num_inputs)
            ]

    @property
    def model_dir(self) -> str:
        return self._model_dir

    @property
    def model_file_prefix(self) -> str:
        return self._model_file_prefix

    @property
    def config(self) -> HPIConfig:
        return self._config

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        if self._use_paddle:
            return self._call_paddle_infer(x)
        else:
            return self._call_multi_backend_infer(x)

    def _call_paddle_infer(self, x):
        return self._paddle_infer(x)

    def _call_multi_backend_infer(self, x):
        num_inputs = len(self._input_names)
        if len(x) != num_inputs:
            raise ValueError(f"Expected {num_inputs} inputs but got {len(x)} instead")
        x = _sort_inputs(x, self._input_names)
        inputs = {}
        for name, input_ in zip(self._input_names, x):
            inputs[name] = np.ascontiguousarray(input_)
        return self._multi_backend_infer(inputs)

    def _determine_backend_and_config(self):
        from ultra_infer import (
            is_built_with_om,
            is_built_with_openvino,
            is_built_with_ort,
            is_built_with_trt,
        )

        model_paths = get_model_paths(self._model_dir, self._model_file_prefix)
        is_onnx_model_available = "onnx" in model_paths
        # TODO: Give a warning if Paddle2ONNX is not available but can be used
        # to select a better backend.
        if self._config.auto_paddle2onnx:
            if self._check_paddle2onnx():
                is_onnx_model_available = (
                    is_onnx_model_available or "paddle" in model_paths
                )
            else:
                logging.debug(
                    "Paddle2ONNX is not available. Automatic model conversion will not be performed."
                )
        available_backends = []
        if "paddle" in model_paths:
            available_backends.append("paddle")
        if is_built_with_openvino() and is_onnx_model_available:
            available_backends.append("openvino")
        if is_built_with_ort() and is_onnx_model_available:
            available_backends.append("onnxruntime")
        if is_built_with_trt() and is_onnx_model_available:
            available_backends.append("tensorrt")
        if is_built_with_om() and "om" in model_paths:
            available_backends.append("om")

        if not available_backends:
            raise RuntimeError("No inference backend is available")

        if (
            self._config.backend is not None
            and self._config.backend not in available_backends
        ):
            raise RuntimeError(
                f"Inference backend {repr(self._config.backend)} is unavailable"
            )

        if self._config.auto_config:
            # Should we use the strategy pattern here to allow extensible
            # strategies?
            ret = suggest_inference_backend_and_config(
                self._config, available_backends=available_backends
            )
            if ret[0] is None:
                # Should we use a custom exception?
                raise RuntimeError(
                    f"No inference backend and configuration could be suggested. Reason: {ret[1]}"
                )
            backend, backend_config = ret
        else:
            backend = self._config.backend
            if backend is None:
                raise RuntimeError(
                    "When automatic configuration is not used, the inference backend must be specified manually."
                )
            backend_config = self._config.backend_config or {}

        if backend == "paddle" and not backend_config:
            logging.warning(
                "The Paddle Inference backend is selected with the default configuration. This may not provide optimal performance."
            )

        return backend, backend_config

    def _build_paddle_infer(self, backend_config):
        kwargs = {
            "device_type": self._config.device_type,
            "device_id": self._config.device_id,
            **backend_config,
        }
        # TODO: This is probably redundant. Can we reuse the code in the
        # predictor class?
        paddle_info = self._config.hpi_info.backend_configs.paddle_infer
        if paddle_info is not None:
            if (
                kwargs.get("trt_dynamic_shapes") is None
                and paddle_info.trt_dynamic_shapes is not None
            ):
                trt_dynamic_shapes = paddle_info.trt_dynamic_shapes
                logging.debug("TensorRT dynamic shapes set to %s", trt_dynamic_shapes)
                kwargs["trt_dynamic_shapes"] = trt_dynamic_shapes
            if (
                kwargs.get("trt_dynamic_shape_input_data") is None
                and paddle_info.trt_dynamic_shape_input_data is not None
            ):
                trt_dynamic_shape_input_data = paddle_info.trt_dynamic_shape_input_data
                logging.debug(
                    "TensorRT dynamic shape input data set to %s",
                    trt_dynamic_shape_input_data,
                )
                kwargs["trt_dynamic_shape_input_data"] = trt_dynamic_shape_input_data

        pp_option = PaddlePredictorOption(self._config.pdx_model_name, **kwargs)
        logging.info("Using Paddle Inference backend")
        logging.info("Paddle predictor option: %s", pp_option)
        return PaddleInfer(self._model_dir, self._model_file_prefix, option=pp_option)

    def _build_ui_runtime(self, backend, backend_config, ui_option=None):
        from ultra_infer import ModelFormat, Runtime, RuntimeOption

        if ui_option is None:
            ui_option = RuntimeOption()

        if self._config.device_type == "cpu":
            pass
        elif self._config.device_type == "gpu":
            ui_option.use_gpu(self._config.device_id or 0)
        elif self._config.device_type == "npu":
            ui_option.use_ascend()
        else:
            raise RuntimeError(
                f"Unsupported device type {repr(self._config.device_type)}"
            )

        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
        if backend in ("openvino", "onnxruntime", "tensorrt"):
            # XXX: This introduces side effects.
            if "onnx" not in model_paths:
                if self._config.auto_paddle2onnx:
                    if "paddle" not in model_paths:
                        raise RuntimeError("PaddlePaddle model required")
                    # The CLI is used here since there is currently no API.
                    logging.info(
                        "Automatically converting PaddlePaddle model to ONNX format"
                    )
                    subprocess.check_call(
                        [
                            "paddlex",
                            "--paddle2onnx",
                            "--paddle_model_dir",
                            self._model_dir,
                            "--onnx_model_dir",
                            self._model_dir,
                        ]
                    )
                    model_paths = get_model_paths(
                        self.model_dir, self.model_file_prefix
                    )
                    assert "onnx" in model_paths
                else:
                    raise RuntimeError("ONNX model required")
            ui_option.set_model_path(str(model_paths["onnx"]), "", ModelFormat.ONNX)
        elif backend == "om":
            if "om" not in model_paths:
                raise RuntimeError("OM model required")
            ui_option.set_model_path(str(model_paths["om"]), "", ModelFormat.OM)
        else:
            raise ValueError(f"Unsupported inference backend {repr(backend)}")

        if backend == "openvino":
            backend_config = OpenVINOConfig.model_validate(backend_config)
            ui_option.use_openvino_backend()
            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
        elif backend == "onnxruntime":
            backend_config = ONNXRuntimeConfig.model_validate(backend_config)
            ui_option.use_ort_backend()
            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
        elif backend == "tensorrt":
            if (
                backend_config.get("use_dynamic_shapes", True)
                and backend_config.get("dynamic_shapes") is None
            ):
                trt_info = self._config.hpi_info.backend_configs.tensorrt
                if trt_info is not None and trt_info.dynamic_shapes is not None:
                    trt_dynamic_shapes = trt_info.dynamic_shapes
                    logging.debug(
                        "TensorRT dynamic shapes set to %s", trt_dynamic_shapes
                    )
                    backend_config = {
                        **backend_config,
                        "dynamic_shapes": trt_dynamic_shapes,
                    }
            backend_config = TensorRTConfig.model_validate(backend_config)
            ui_option.use_trt_backend()
            cache_dir = self.model_dir / CACHE_DIR / "tensorrt"
            cache_dir.mkdir(parents=True, exist_ok=True)
            ui_option.trt_option.serialize_file = str(cache_dir / "trt_serialized.trt")
            if backend_config.precision == "FP16":
                ui_option.trt_option.enable_fp16 = True
            if not backend_config.use_dynamic_shapes:
                raise RuntimeError(
                    "TensorRT static shape inference is currently not supported"
                )
            if backend_config.dynamic_shapes is not None:
                if not Path(ui_option.trt_option.serialize_file).exists():
                    for name, shapes in backend_config.dynamic_shapes.items():
                        ui_option.trt_option.set_shape(name, *shapes)
                else:
                    logging.warning(
                        "TensorRT dynamic shapes will be loaded from the file."
                    )
        elif backend == "om":
            backend_config = OMConfig.model_validate(backend_config)
            ui_option.use_om_backend()
        else:
            raise ValueError(f"Unsupported inference backend {repr(backend)}")

        logging.info("Inference backend: %s", backend)
        logging.info("Inference backend config: %s", backend_config)

        ui_runtime = Runtime(ui_option)

        return ui_runtime

    def _check_paddle2onnx(self):
        # HACK
        return importlib.util.find_spec("paddle2onnx") is not None
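

# Usage sketch (illustrative only, not part of the original module). Assuming a
# hypothetical model directory and a single input array `arr`, the classes above
# appear to be wired together roughly as follows:
#
#     option = PaddlePredictorOption("some_model_name", device_type="cpu")
#     infer = PaddleInfer(Path("/path/to/model"), "inference", option=option)
#     (output,) = infer([arr])  # inputs ordered by sorted input name
#
# `HPInfer` is constructed with the same directory and prefix plus an `HPIConfig`,
# and selects either Paddle Inference or an `ultra_infer` runtime automatically.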