# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, List
from pathlib import Path

import lazy_paddle
import numpy as np

from ....utils import logging
from ....utils.device import constr_device
from ....utils.flags import (
    DEBUG,
    USE_PIR_TRT,
    INFER_BENCHMARK_USE_NEW_INFER_API,
)
from ...utils.benchmark import benchmark, set_inference_operations
from ...utils.hpi import get_model_paths
from ...utils.pp_option import PaddlePredictorOption
from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG

CACHE_DIR = ".cache"

if INFER_BENCHMARK_USE_NEW_INFER_API:
    INFERENCE_OPERATIONS = [
        "PaddleCopyToDevice",
        "PaddleCopyToHost",
        "PaddleModelInfer",
    ]
else:
    INFERENCE_OPERATIONS = ["PaddleInferChainLegacy"]
set_inference_operations(INFERENCE_OPERATIONS)


# XXX: Better use Paddle Inference API to do this
def _pd_dtype_to_np_dtype(pd_dtype):
    if pd_dtype == lazy_paddle.inference.DataType.FLOAT64:
        return np.float64
    elif pd_dtype == lazy_paddle.inference.DataType.FLOAT32:
        return np.float32
    elif pd_dtype == lazy_paddle.inference.DataType.INT64:
        return np.int64
    elif pd_dtype == lazy_paddle.inference.DataType.INT32:
        return np.int32
    elif pd_dtype == lazy_paddle.inference.DataType.UINT8:
        return np.uint8
    elif pd_dtype == lazy_paddle.inference.DataType.INT8:
        return np.int8
    else:
        raise TypeError(f"Unsupported data type: {pd_dtype}")


# Old TRT path: collect shape range info by running the model.
def _collect_trt_shape_range_info(
    model_file,
    model_params,
    gpu_id,
    shape_range_info_path,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    dynamic_shape_input_data = dynamic_shape_input_data or {}

    config = lazy_paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    config.collect_shape_range_info(shape_range_info_path)
    # TODO: Add other needed options
    config.disable_glog_info()
    predictor = lazy_paddle.inference.create_predictor(config)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )
    # It would be better to check if the shapes are valid.

    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arrs[name] = np.array(
                dynamic_shape_input_data[name][0], dtype=dtype
            ).reshape(min_shape)
            opt_arrs[name] = np.array(
                dynamic_shape_input_data[name][1], dtype=dtype
            ).reshape(opt_shape)
            max_arrs[name] = np.array(
                dynamic_shape_input_data[name][2], dtype=dtype
            ).reshape(max_shape)
        else:
            min_arrs[name] = np.ones(min_shape, dtype=dtype)
            opt_arrs[name] = np.ones(opt_shape, dtype=dtype)
            max_arrs[name] = np.ones(max_shape, dtype=dtype)

    # `opt_arrs` is used twice to ensure it is the most frequently used.
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            handle = predictor.get_input_handle(name)
            handle.reshape(arr.shape)
            handle.copy_from_cpu(arr)
        predictor.run()

    # HACK: The shape range info is written to the file only when `predictor`
    # is destroyed, so delete it explicitly to trigger the destructor and
    # ensure the file gets saved. This works in CPython, but counting on the
    # implementation-dependent behavior of a garbage collector is fragile; a
    # more explicit and deterministic mechanism would be preferable.
    del predictor
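
# For reference, a minimal sketch of the argument layout consumed above (the
# input name "x" and all shapes/data are hypothetical, not taken from any
# real model):
#
#     dynamic_shapes = {
#         "x": [
#             [1, 3, 224, 224],    # min shape
#             [1, 3, 640, 640],    # opt shape
#             [4, 3, 1280, 1280],  # max shape
#         ]
#     }
#     # Optional flat data, reshaped to the matching min/opt/max shapes:
#     dynamic_shape_input_data = {"x": [min_data, opt_data, max_data]}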


# PIR TRT path: convert the model via `paddle.tensorrt.export`.
def _convert_trt(
    trt_cfg_setting,
    pp_model_file,
    pp_params_file,
    trt_save_path,
    device_id,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    from lazy_paddle.tensorrt.export import (
        Input,
        TensorRTConfig,
        convert,
    )

    def _set_trt_config():
        for attr_name in trt_cfg_setting:
            assert hasattr(
                trt_config, attr_name
            ), f"`{type(trt_config)}` has no attribute `{attr_name}`!"
            setattr(trt_config, attr_name, trt_cfg_setting[attr_name])

    def _get_predictor(model_file, params_file):
        # HACK
        config = lazy_paddle.inference.Config(str(model_file), str(params_file))
        config.enable_use_gpu(100, device_id)
        # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
        config.disable_mkldnn()
        config.disable_glog_info()
        return lazy_paddle.inference.create_predictor(config)

    dynamic_shape_input_data = dynamic_shape_input_data or {}

    predictor = _get_predictor(pp_model_file, pp_params_file)
    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )

    trt_inputs = []
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
                min_shape
            )
            opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
                opt_shape
            )
            max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
                max_shape
            )
        else:
            min_arr = np.ones(min_shape, dtype=dtype)
            opt_arr = np.ones(opt_shape, dtype=dtype)
            max_arr = np.ones(max_shape, dtype=dtype)
        # Refer to https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
        trt_input = Input((min_arr, opt_arr, max_arr))
        trt_inputs.append(trt_input)

    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    _set_trt_config()
    trt_config.save_model_dir = str(trt_save_path)
    pp_model_path = str(pp_model_file.with_suffix(""))
    convert(pp_model_path, trt_config)


def _sort_inputs(inputs, names):
    # NOTE: Adjust input tensors to match the sorted sequence.
    indices = sorted(range(len(names)), key=names.__getitem__)
    inputs = [inputs[indices.index(i)] for i in range(len(inputs))]
    return inputs
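
# Worked example (hypothetical input names): with names == ["image", "im_shape"]
# and `inputs` ordered by sorted name, i.e. [im_shape_arr, image_arr],
# `_sort_inputs` returns [image_arr, im_shape_arr], matching the predictor's
# declared input order. Callers are thus expected to pass tensors sorted by
# input name.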


def _concatenate(*callables):
    def _chain(x):
        for c in callables:
            x = c(x)
        return x

    return _chain
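
# For example, `_concatenate(f, g)` behaves like ``lambda x: g(f(x))``. It is
# used below to chain PaddleCopyToDevice -> PaddleModelInfer -> PaddleCopyToHost.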


@benchmark.timeit
class PaddleCopyToDevice:
    def __init__(self, device_type, device_id):
        self.device_type = device_type
        self.device_id = device_id

    def __call__(self, arrs):
        device_id = [self.device_id] if self.device_id is not None else self.device_id
        device = constr_device(self.device_type, device_id)
        paddle_tensors = [lazy_paddle.to_tensor(i, place=device) for i in arrs]
        return paddle_tensors


@benchmark.timeit
class PaddleCopyToHost:
    def __call__(self, paddle_tensors):
        arrs = [i.numpy() for i in paddle_tensors]
        return arrs


@benchmark.timeit
class PaddleModelInfer:
    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self, x):
        return self.predictor.run(x)


# FIXME: Name might be misleading
@benchmark.timeit
class PaddleInferChainLegacy:
    def __init__(self, predictor):
        self.predictor = predictor
        input_names = self.predictor.get_input_names()
        self.input_handles = []
        self.output_handles = []
        for input_name in input_names:
            input_handle = self.predictor.get_input_handle(input_name)
            self.input_handles.append(input_handle)
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_handle = self.predictor.get_output_handle(output_name)
            self.output_handles.append(output_handle)

    def __call__(self, x):
        for input_, input_handle in zip(x, self.input_handles):
            input_handle.reshape(input_.shape)
            input_handle.copy_from_cpu(input_)
        self.predictor.run()
        outputs = [o.copy_to_cpu() for o in self.output_handles]
        return outputs
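
# NOTE: `PaddleInferChainLegacy` feeds and fetches through input/output handles
# (`copy_from_cpu` / `copy_to_cpu`), whereas the new-API chain built in
# `StaticInfer` below passes Paddle tensors directly to `predictor.run(...)`.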


class StaticInfer(object):
    def __init__(
        self,
        model_dir: Path,
        model_prefix: str,
        option: PaddlePredictorOption,
    ) -> None:
        super().__init__()
        self.model_dir = model_dir
        self.model_file_prefix = model_prefix
        self._option = option
        self.predictor = self._create()
        if self._use_new_inference_api:
            device_type = self._option.device_type
            device_type = "gpu" if device_type == "dcu" else device_type
            copy_to_device = PaddleCopyToDevice(device_type, self._option.device_id)
            copy_to_host = PaddleCopyToHost()
            model_infer = PaddleModelInfer(self.predictor)
            self.infer = _concatenate(copy_to_device, model_infer, copy_to_host)
        else:
            self.infer = PaddleInferChainLegacy(self.predictor)

    @property
    def _use_new_inference_api(self):
        # HACK: Temporary fallback to the legacy API via an environment variable
        return INFER_BENCHMARK_USE_NEW_INFER_API
        # return self._option.device_type in ("cpu", "gpu", "dcu")

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        names = self.predictor.get_input_names()
        if len(names) != len(x):
            raise ValueError(
                f"The number of inputs does not match the model: {len(names)} vs {len(x)}"
            )
        # TODO: Ensure that input tensors follow the model's input sequence
        # without sorting.
        x = _sort_inputs(x, names)
        x = list(map(np.ascontiguousarray, x))
        pred = self.infer(x)
        return pred

    def _create(self):
        """Create a Paddle Inference predictor from `self.model_dir` and `self._option`."""
        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
        if "paddle" not in model_paths:
            raise RuntimeError("No valid Paddle model found")
        model_file, params_file = model_paths["paddle"]

        if (
            self._option.model_name == "LaTeX_OCR_rec"
            and self._option.device_type == "cpu"
        ):
            import cpuinfo

            if (
                "GenuineIntel" in cpuinfo.get_cpu_info().get("vendor_id_raw", "")
                and self._option.run_mode != "mkldnn"
            ):
                logging.warning(
                    "The `LaTeX_OCR_rec` model only supports `mkldnn` mode when "
                    "running on Intel CPUs; falling back to `mkldnn`."
                )
                self._option.run_mode = "mkldnn"
                logging.debug("`run_mode` updated to 'mkldnn'")

        if self._option.device_type == "cpu" and self._option.device_id is not None:
            self._option.device_id = None
            logging.debug("`device_id` has been set to None")
        if (
            self._option.device_type in ("gpu", "dcu")
            and self._option.device_id is None
        ):
            self._option.device_id = 0
            logging.debug("`device_id` has been set to 0")

        # for TensorRT
        if self._option.run_mode.startswith("trt"):
            assert self._option.device_type == "gpu"
            cache_dir = self.model_dir / CACHE_DIR / "paddle"
            config = self._configure_trt(
                model_file,
                params_file,
                cache_dir,
            )
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self._option.device_id)
        # for native Paddle and MKL-DNN
        else:
            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
            if self._option.device_type == "gpu":
                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
                from lazy_paddle.inference import PrecisionType

                precision = (
                    PrecisionType.Half
                    if self._option.run_mode == "paddle_fp16"
                    else PrecisionType.Float32
                )
                config.enable_use_gpu(100, self._option.device_id, precision)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)
            elif self._option.device_type == "npu":
                config.enable_custom_device("npu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "xpu":
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "mlu":
                config.enable_custom_device("mlu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "gcu":
                from paddle_custom_device.gcu import passes as gcu_passes

                gcu_passes.setUp()
                config.enable_custom_device("gcu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_ir()
                    config.enable_new_executor()
                else:
                    pass_builder = config.pass_builder()
                    name = "PaddleX_" + self._option.model_name
                    gcu_passes.append_passes_for_legacy_ir(pass_builder, name)
            elif self._option.device_type == "dcu":
                config.enable_use_gpu(100, self._option.device_id)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                # XXX: Must `is_compiled_with_rocm()` be True on the DCU platform?
                if lazy_paddle.is_compiled_with_rocm():
                    # Delete passes unsupported on DCU
                    config.delete_pass("conv2d_add_act_fuse_pass")
                    config.delete_pass("conv2d_add_fuse_pass")
            else:
                assert self._option.device_type == "cpu"
                config.disable_gpu()
                if "mkldnn" in self._option.run_mode:
                    try:
                        config.enable_mkldnn()
                        if "bf16" in self._option.run_mode:
                            config.enable_mkldnn_bfloat16()
                    except Exception:
                        logging.warning(
                            "MKL-DNN is not available. We will disable MKL-DNN."
                        )
                    config.set_mkldnn_cache_capacity(-1)
                else:
                    if hasattr(config, "disable_mkldnn"):
                        config.disable_mkldnn()
                config.set_cpu_math_library_num_threads(self._option.cpu_threads)
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self._option.delete_pass:
            config.delete_pass(del_p)

        # Disable Paddle Inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = lazy_paddle.inference.create_predictor(config)
        return predictor

    def _configure_trt(self, model_file, params_file, cache_dir):
        # TODO: Support calibration
        if USE_PIR_TRT:
            trt_save_path = cache_dir / "trt" / self.model_file_prefix
            _convert_trt(
                self._option.trt_cfg_setting,
                model_file,
                params_file,
                trt_save_path,
                self._option.device_id,
                self._option.trt_dynamic_shapes,
                self._option.trt_dynamic_shape_input_data,
            )
            model_file = trt_save_path.with_suffix(".json")
            params_file = trt_save_path.with_suffix(".pdiparams")
            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
        else:
            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
            config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
            # NOTE: call enable_use_gpu() first to use the TensorRT engine
            config.enable_use_gpu(100, self._option.device_id)
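            # The loop below applies `trt_cfg_setting` to `config`: each key
            # must name a `Config` method, and list values are passed
            # positionally while dict values are passed as keyword arguments.
            # A sketch of one possible entry (the method name and value here
            # are illustrative assumptions, not taken from this codebase):
            #
            #     {"enable_tensorrt_engine": {"workspace_size": 1 << 30}}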
            for func_name in self._option.trt_cfg_setting:
                assert hasattr(
                    config, func_name
                ), f"`{type(config)}` has no function `{func_name}`!"
                args = self._option.trt_cfg_setting[func_name]
                if isinstance(args, list):
                    getattr(config, func_name)(*args)
                else:
                    getattr(config, func_name)(**args)

            if self._option.trt_use_dynamic_shapes:
                if self._option.trt_collect_shape_range_info:
                    # NOTE: We always use a shape range info file.
                    if self._option.trt_shape_range_info_path is not None:
                        trt_shape_range_info_path = Path(
                            self._option.trt_shape_range_info_path
                        )
                    else:
                        trt_shape_range_info_path = cache_dir / "shape_range_info.pbtxt"
                    should_collect_shape_range_info = True
                    if not trt_shape_range_info_path.exists():
                        trt_shape_range_info_path.parent.mkdir(
                            parents=True, exist_ok=True
                        )
                        logging.info(
                            f"Shape range info will be collected into {trt_shape_range_info_path}"
                        )
                    elif self._option.trt_discard_cached_shape_range_info:
                        trt_shape_range_info_path.unlink()
                        logging.info(
                            f"The shape range info file ({trt_shape_range_info_path}) has been removed, and the shape range info will be re-collected."
                        )
                    else:
                        logging.info(
                            f"A shape range info file ({trt_shape_range_info_path}) already exists. There is no need to collect the info again."
                        )
                        should_collect_shape_range_info = False
                    if should_collect_shape_range_info:
                        _collect_trt_shape_range_info(
                            str(model_file),
                            str(params_file),
                            self._option.device_id,
                            str(trt_shape_range_info_path),
                            self._option.trt_dynamic_shapes,
                            self._option.trt_dynamic_shape_input_data,
                        )
                    if (
                        self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG
                        and self._option.run_mode == "trt_fp16"
                    ):
                        lazy_paddle.inference.InternalUtils.disable_tensorrt_half_ops(
                            config, DISABLE_TRT_HALF_OPS_CONFIG[self._option.model_name]
                        )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        str(trt_shape_range_info_path),
                        self._option.trt_allow_rebuild_at_runtime,
                    )
                else:
                    if self._option.trt_dynamic_shapes is not None:
                        min_shapes, opt_shapes, max_shapes = {}, {}, {}
                        for key, shapes in self._option.trt_dynamic_shapes.items():
                            min_shapes[key] = shapes[0]
                            opt_shapes[key] = shapes[1]
                            max_shapes[key] = shapes[2]
                        config.set_trt_dynamic_shape_info(
                            min_shapes, max_shapes, opt_shapes
                        )
                    else:
                        raise RuntimeError("No dynamic shape information provided")
        return config
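

# Usage sketch (paths and option values are hypothetical; the
# `PaddlePredictorOption` fields are defined elsewhere in PaddleX):
#
#     option = PaddlePredictorOption(...)
#     infer = StaticInfer(Path("path/to/model_dir"), "inference", option)
#     outputs = infer([np.zeros((1, 3, 224, 224), dtype="float32")])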