# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, List
from pathlib import Path

import lazy_paddle
import numpy as np

from ....utils import logging
from ....utils.device import constr_device
from ....utils.flags import DEBUG, USE_PIR_TRT
from ...utils.benchmark import benchmark, set_inference_operations
from ...utils.hpi import get_model_paths
from ...utils.pp_option import PaddlePredictorOption
from ...utils.trt_config import TRT_CFG

CACHE_DIR = ".cache"

INFERENCE_OPERATIONS = ["PaddleCopyToDevice", "PaddleCopyToHost", "PaddleModelInfer"]
set_inference_operations(INFERENCE_OPERATIONS)


# XXX: Better to use the Paddle Inference API to do this
def _pd_dtype_to_np_dtype(pd_dtype):
    if pd_dtype == lazy_paddle.inference.DataType.FLOAT64:
        return np.float64
    elif pd_dtype == lazy_paddle.inference.DataType.FLOAT32:
        return np.float32
    elif pd_dtype == lazy_paddle.inference.DataType.INT64:
        return np.int64
    elif pd_dtype == lazy_paddle.inference.DataType.INT32:
        return np.int32
    elif pd_dtype == lazy_paddle.inference.DataType.UINT8:
        return np.uint8
    elif pd_dtype == lazy_paddle.inference.DataType.INT8:
        return np.int8
    else:
        raise TypeError(f"Unsupported data type: {pd_dtype}")


# old trt
def _collect_trt_shape_range_info(
    model_file,
    model_params,
    gpu_id,
    shape_range_info_path,
    dynamic_shapes,
    dynamic_shape_input_data,
):
    dynamic_shape_input_data = dynamic_shape_input_data or {}

    config = lazy_paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    config.collect_shape_range_info(shape_range_info_path)
    # TODO: Add other needed options
    config.disable_glog_info()
    predictor = lazy_paddle.inference.create_predictor(config)

    input_names = predictor.get_input_names()
    for name in dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
            )
    for name in input_names:
        if name not in dynamic_shapes:
            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
    for name in dynamic_shape_input_data:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
            )
    # It would be better to check if the shapes are valid.

    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in dynamic_shapes.items():
        # XXX: Currently we have no way to get the data type of the tensor
        # without creating an input handle.
        handle = predictor.get_input_handle(name)
        dtype = _pd_dtype_to_np_dtype(handle.type())
        min_shape, opt_shape, max_shape = candidate_shapes
        if name in dynamic_shape_input_data:
            min_arrs[name] = np.array(
                dynamic_shape_input_data[name][0], dtype=dtype
            ).reshape(min_shape)
            opt_arrs[name] = np.array(
                dynamic_shape_input_data[name][1], dtype=dtype
            ).reshape(opt_shape)
            max_arrs[name] = np.array(
                dynamic_shape_input_data[name][2], dtype=dtype
            ).reshape(max_shape)
        else:
            min_arrs[name] = np.ones(min_shape, dtype=dtype)
            opt_arrs[name] = np.ones(opt_shape, dtype=dtype)
            max_arrs[name] = np.ones(max_shape, dtype=dtype)

    # `opt_arrs` is used twice to ensure it is the most frequently used.
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            handle = predictor.get_input_handle(name)
            handle.reshape(arr.shape)
            handle.copy_from_cpu(arr)
        predictor.run()

    # HACK: The shape range info will be written to the file only when
    # `predictor` is garbage collected. It works in CPython, but it is
    # definitely a bad idea to count on the implementation-dependent behavior of
    # a garbage collector. Is there a more explicit and deterministic way to
    # handle this?
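    # One possible direction (an untested assumption, not an established
    # Paddle Inference API guarantee): drop the local reference explicitly,
    # e.g. `del predictor` followed by `gc.collect()`, so the file is flushed
    # before this function returns. This would still rely on the predictor's
    # destructor performing the write.
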

# pir trt
def _convert_trt(
    model_name,
    mode,
    pp_model_file,
    pp_params_file,
    trt_save_path,
    trt_dynamic_shapes,
):
    def _set_trt_config():
        if settings := TRT_CFG.get(model_name):
            for attr_name in settings:
                if not hasattr(trt_config, attr_name):
                    logging.warning(
                        f"The TensorRTConfig does not have the attribute `{attr_name}`!"
                    )
                setattr(trt_config, attr_name, settings[attr_name])

    from lazy_paddle.tensorrt.export import (
        Input,
        TensorRTConfig,
        convert,
        PrecisionMode,
    )

    def _get_input_names(model_file, params_file):
        # HACK
        config = lazy_paddle.inference.Config(str(model_file), str(params_file))
        # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
        config.disable_mkldnn()
        config.disable_glog_info()
        predictor = lazy_paddle.inference.create_predictor(config)
        return predictor.get_input_names()

    input_names = _get_input_names(pp_model_file, pp_params_file)
    for name in trt_dynamic_shapes:
        if name not in input_names:
            raise ValueError(
                f"Invalid input name {repr(name)} found in `trt_dynamic_shapes`"
            )
    for name in input_names:
        if name not in trt_dynamic_shapes:
            raise ValueError(
                f"Input name {repr(name)} not found in `trt_dynamic_shapes`"
            )

    precision_map = {
        "trt_int8": PrecisionMode.INT8,
        "trt_fp32": PrecisionMode.FP32,
        "trt_fp16": PrecisionMode.FP16,
    }
    trt_inputs = []
    for name in input_names:
        min_shape, opt_shape, max_shape = trt_dynamic_shapes[name]
        trt_input = Input(
            min_input_shape=min_shape,
            optim_input_shape=opt_shape,
            max_input_shape=max_shape,
        )
        trt_inputs.append(trt_input)

    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    _set_trt_config()
    trt_config.precision_mode = precision_map[mode]
    trt_config.save_model_dir = str(trt_save_path)
    pp_model_path = str(pp_model_file.with_suffix(""))
    convert(pp_model_path, trt_config)


def _sort_inputs(inputs, names):
    # NOTE: Adjust input tensors to match the sorted sequence.
    indices = sorted(range(len(names)), key=names.__getitem__)
    inputs = [inputs[indices.index(i)] for i in range(len(inputs))]
    return inputs
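
# An illustrative trace of the permutation above (the names and arrays are
# made up for illustration): with names == ["c", "a", "b"] and inputs given in
# sorted-name order [t_a, t_b, t_c], the returned list is [t_c, t_a, t_b],
# i.e. the inputs rearranged to match the predictor's declared name order.
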

def _concatenate(*callables):
    def _chain(x):
        for c in callables:
            x = c(x)
        return x

    return _chain
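
# For illustration: _concatenate(f, g, h) returns a callable equivalent to
# lambda x: h(g(f(x))). Below it is used to chain copy-to-device, model
# inference, and copy-to-host into a single inference callable.
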

@benchmark.timeit
class PaddleCopyToDevice:
    def __init__(self, device_type, device_id):
        self.device_type = device_type
        self.device_id = device_id

    def __call__(self, arrs):
        device_id = [self.device_id] if self.device_id is not None else self.device_id
        device = constr_device(self.device_type, device_id)
        paddle_tensors = [lazy_paddle.to_tensor(i, place=device) for i in arrs]
        return paddle_tensors


@benchmark.timeit
class PaddleCopyToHost:
    def __call__(self, paddle_tensors):
        arrs = [i.numpy() for i in paddle_tensors]
        return arrs


@benchmark.timeit
class PaddleModelInfer:
    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self, x):
        return self.predictor.run(x)


# FIXME: Name might be misleading
@benchmark.timeit
class PaddleInferChainLegacy:
    def __init__(self, predictor):
        self.predictor = predictor
        input_names = self.predictor.get_input_names()
        self.input_handles = []
        self.output_handles = []
        for input_name in input_names:
            input_handle = self.predictor.get_input_handle(input_name)
            self.input_handles.append(input_handle)
        output_names = self.predictor.get_output_names()
        for output_name in output_names:
            output_handle = self.predictor.get_output_handle(output_name)
            self.output_handles.append(output_handle)

    def __call__(self, x):
        for input_, input_handle in zip(x, self.input_handles):
            input_handle.reshape(input_.shape)
            input_handle.copy_from_cpu(input_)
        self.predictor.run()
        outputs = [o.copy_to_cpu() for o in self.output_handles]
        return outputs
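
# For context: PaddleInferChainLegacy feeds NumPy arrays through input handles
# (copy_from_cpu / copy_to_cpu), whereas the non-legacy path below chains
# PaddleCopyToDevice -> PaddleModelInfer -> PaddleCopyToHost around
# predictor.run() on Paddle tensors.
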

class StaticInfer(object):
    def __init__(
        self,
        model_dir: str,
        model_prefix: str,
        option: PaddlePredictorOption,
    ) -> None:
        super().__init__()
        self.model_dir = model_dir
        self.model_file_prefix = model_prefix
        self._option = option
        self.predictor = self._create()
        if not self._use_legacy_api:
            device_type = self._option.device_type
            device_type = "gpu" if device_type == "dcu" else device_type
            copy_to_device = PaddleCopyToDevice(device_type, self._option.device_id)
            copy_to_host = PaddleCopyToHost()
            model_infer = PaddleModelInfer(self.predictor)
            self.infer = _concatenate(copy_to_device, model_infer, copy_to_host)
        else:
            self.infer = PaddleInferChainLegacy(self.predictor)

    @property
    def _use_legacy_api(self):
        return self._option.device_type not in ("cpu", "gpu", "dcu")

    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
        names = self.predictor.get_input_names()
        if len(names) != len(x):
            raise ValueError(
                f"The number of inputs does not match the model: {len(names)} vs {len(x)}"
            )
        # TODO:
        # Ensure that input tensors follow the model's input sequence without sorting.
        x = _sort_inputs(x, names)
        x = list(map(np.ascontiguousarray, x))
        pred = self.infer(x)
        return pred

    def _create(
        self,
    ):
        """Create and configure the underlying Paddle Inference predictor."""
        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
        if "paddle" not in model_paths:
            raise RuntimeError("No valid Paddle model found")
        model_file, params_file = model_paths["paddle"]

        if (
            self._option.model_name == "LaTeX_OCR_rec"
            and self._option.device_type == "cpu"
        ):
            import cpuinfo

            if (
                "GenuineIntel" in cpuinfo.get_cpu_info().get("vendor_id_raw", "")
                and self._option.run_mode != "mkldnn"
            ):
                logging.warning(
                    "Currently, the `LaTeX_OCR_rec` model only supports `mkldnn` mode when running on Intel CPUs. Using `mkldnn` instead."
                )
                self._option.run_mode = "mkldnn"
                logging.debug("`run_mode` updated to 'mkldnn'")

        if self._option.device_type == "cpu" and self._option.device_id is not None:
            self._option.device_id = None
            logging.debug("`device_id` has been set to None")

        if (
            self._option.device_type in ("gpu", "dcu")
            and self._option.device_id is None
        ):
            self._option.device_id = 0
            logging.debug("`device_id` has been set to 0")

        # for TRT
        if self._option.run_mode.startswith("trt"):
            assert self._option.device_type == "gpu"
            cache_dir = self.model_dir / CACHE_DIR / "paddle"
            config = self._configure_trt(
                model_file,
                params_file,
                cache_dir,
            )
        else:
            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
            if self._option.device_type == "gpu":
                config.exp_disable_mixed_precision_ops({"feed", "fetch"})
                config.enable_use_gpu(100, self._option.device_id)
                if not self._option.run_mode.startswith("trt"):
                    if hasattr(config, "enable_new_ir"):
                        config.enable_new_ir(self._option.enable_new_ir)
                    if hasattr(config, "enable_new_executor"):
                        config.enable_new_executor()
                    config.set_optimization_level(3)
            elif self._option.device_type == "npu":
                config.enable_custom_device("npu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "xpu":
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "mlu":
                config.enable_custom_device("mlu")
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
            elif self._option.device_type == "dcu":
                config.enable_use_gpu(100, self._option.device_id)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                # XXX: Must `is_compiled_with_rocm()` be True on the DCU platform?
                if lazy_paddle.is_compiled_with_rocm():
                    # Delete passes that are unsupported on DCU
                    config.delete_pass("conv2d_add_act_fuse_pass")
                    config.delete_pass("conv2d_add_fuse_pass")
            else:
                assert self._option.device_type == "cpu"
                config.disable_gpu()
                if "mkldnn" in self._option.run_mode:
                    try:
                        config.enable_mkldnn()
                        if "bf16" in self._option.run_mode:
                            config.enable_mkldnn_bfloat16()
                    except Exception as e:
                        logging.warning(
                            "MKL-DNN is not available. We will disable MKL-DNN."
                        )
                    config.set_mkldnn_cache_capacity(-1)
                else:
                    if hasattr(config, "disable_mkldnn"):
                        config.disable_mkldnn()
                config.set_cpu_math_library_num_threads(self._option.cpu_threads)

                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self._option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self._option.delete_pass:
            config.delete_pass(del_p)

        # Disable paddle inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = lazy_paddle.inference.create_predictor(config)

        return predictor

    def _configure_trt(self, model_file, params_file, cache_dir):
        # TODO: Support calibration
        if USE_PIR_TRT:
            trt_save_path = cache_dir / "trt" / self.model_file_prefix
            _convert_trt(
                self._option.model_name,
                self._option.run_mode,
                model_file,
                params_file,
                trt_save_path,
                self._option.trt_dynamic_shapes,
            )
            model_file = trt_save_path.with_suffix(".json")
            params_file = trt_save_path.with_suffix(".pdiparams")
            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
        else:
            PRECISION_MAP = {
                "trt_int8": lazy_paddle.inference.Config.Precision.Int8,
                "trt_fp32": lazy_paddle.inference.Config.Precision.Float32,
                "trt_fp16": lazy_paddle.inference.Config.Precision.Half,
            }
            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
            config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
            config.enable_use_gpu(100, self._option.device_id)
            config.enable_tensorrt_engine(
                workspace_size=self._option.trt_max_workspace_size,
                max_batch_size=self._option.trt_max_batch_size,
                min_subgraph_size=self._option.trt_min_subgraph_size,
                precision_mode=PRECISION_MAP[self._option.run_mode],
                use_static=self._option.trt_use_static,
                use_calib_mode=self._option.trt_use_calib_mode,
            )
            if self._option.trt_use_dynamic_shapes:
                if self._option.trt_collect_shape_range_info:
                    # NOTE: We always use a shape range info file.
                    if self._option.trt_shape_range_info_path is not None:
                        trt_shape_range_info_path = Path(
                            self._option.trt_shape_range_info_path
                        )
                    else:
                        trt_shape_range_info_path = cache_dir / "shape_range_info.pbtxt"
                    should_collect_shape_range_info = True
                    if not trt_shape_range_info_path.exists():
                        trt_shape_range_info_path.parent.mkdir(
                            parents=True, exist_ok=True
                        )
                        logging.info(
                            f"Shape range info will be collected into {trt_shape_range_info_path}"
                        )
                    elif self._option.trt_discard_cached_shape_range_info:
                        trt_shape_range_info_path.unlink()
                        logging.info(
                            f"The shape range info file ({trt_shape_range_info_path}) has been removed, and the shape range info will be re-collected."
                        )
                    else:
                        logging.info(
                            f"A shape range info file ({trt_shape_range_info_path}) already exists. There is no need to collect the info again."
                        )
                        should_collect_shape_range_info = False
                    if should_collect_shape_range_info:
                        _collect_trt_shape_range_info(
                            str(model_file),
                            str(params_file),
                            self._option.device_id,
                            str(trt_shape_range_info_path),
                            self._option.trt_dynamic_shapes,
                            self._option.trt_dynamic_shape_input_data,
                        )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        str(trt_shape_range_info_path),
                        self._option.trt_allow_rebuild_at_runtime,
                    )
                else:
                    if self._option.trt_dynamic_shapes is not None:
                        min_shapes, opt_shapes, max_shapes = {}, {}, {}
                        for (
                            key,
                            shapes,
                        ) in self._option.trt_dynamic_shapes.items():
                            min_shapes[key] = shapes[0]
                            opt_shapes[key] = shapes[1]
                            max_shapes[key] = shapes[2]
                        config.set_trt_dynamic_shape_info(
                            min_shapes, max_shapes, opt_shapes
                        )
                    else:
                        raise RuntimeError("No dynamic shape information provided")

        return config
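
# A minimal usage sketch (illustrative only; the model directory, file prefix,
# input shape, and how the PaddlePredictorOption is configured are assumptions,
# not part of this module):
#
#     option = ...  # a configured PaddlePredictorOption (device_type, run_mode, ...)
#     infer = StaticInfer(Path("/path/to/model_dir"), "inference", option)
#     outputs = infer([np.zeros((1, 3, 224, 224), dtype="float32")])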