# static_infer.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Tuple, List, Dict, Any, Iterator
import os
import shutil
import threading
from pathlib import Path

import lazy_paddle as paddle
import numpy as np

from ....utils.flags import DEBUG, FLAGS_json_format_model, USE_PIR_TRT
from ....utils import logging
from ...utils.pp_option import PaddlePredictorOption
from ...utils.trt_config import TRT_CFG


# old trt
def collect_trt_shapes(
    model_file, model_params, gpu_id, shape_range_info_path, trt_dynamic_shapes
):
    """Run warm-up passes over the min/opt/max shapes to record a TRT shape-range file."""
    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in trt_dynamic_shapes.items():
        min_shape, opt_shape, max_shape = candidate_shapes
        min_arrs[name] = np.ones(min_shape, dtype=np.float32)
        opt_arrs[name] = np.ones(opt_shape, dtype=np.float32)
        max_arrs[name] = np.ones(max_shape, dtype=np.float32)

    config.collect_shape_range_info(shape_range_info_path)
    predictor = paddle.inference.create_predictor(config)
    # opt_arrs is used twice to simulate the most common situation
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            input_handler = predictor.get_input_handle(name)
            input_handler.reshape(arr.shape)
            input_handler.copy_from_cpu(arr)
        predictor.run()
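

# A minimal usage sketch for collect_trt_shapes (hypothetical paths and shapes,
# for illustration only; real values come from the exported model and its
# configured dynamic shapes). Each trt_dynamic_shapes entry maps an input name
# to its [min, opt, max] candidate shapes:
#
#     collect_trt_shapes(
#         "inference.pdmodel",
#         "inference.pdiparams",
#         gpu_id=0,
#         shape_range_info_path="shape_range_info.pbtxt",
#         trt_dynamic_shapes={
#             "x": [[1, 3, 224, 224], [1, 3, 320, 320], [8, 3, 640, 640]],
#         },
#     )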


# pir trt
def convert_trt(model_name, mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
    """Convert a Paddle model to a TensorRT engine via the PIR-TRT export API."""

    def _set_trt_config():
        if settings := TRT_CFG.get(model_name):
            for attr_name in settings:
                if not hasattr(trt_config, attr_name):
                    logging.warning(
                        f"The TensorRTConfig doesn't have the attribute `{attr_name}`!"
                    )
                setattr(trt_config, attr_name, settings[attr_name])

    from lazy_paddle.tensorrt.export import (
        Input,
        TensorRTConfig,
        convert,
        PrecisionMode,
    )

    precision_map = {
        "trt_int8": PrecisionMode.INT8,
        "trt_fp32": PrecisionMode.FP32,
        "trt_fp16": PrecisionMode.FP16,
    }
    trt_inputs = []
    for name, candidate_shapes in trt_dynamic_shapes.items():
        min_shape, opt_shape, max_shape = candidate_shapes
        trt_input = Input(
            min_input_shape=min_shape,
            optim_input_shape=opt_shape,
            max_input_shape=max_shape,
        )
        trt_inputs.append(trt_input)

    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    _set_trt_config()
    trt_config.precision_mode = precision_map[mode]
    trt_config.save_model_dir = trt_save_path
    convert(pp_model_path, trt_config)
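

# Example call (hypothetical values; if `model_name` has an entry in TRT_CFG,
# _set_trt_config applies those settings on top of the defaults, otherwise the
# defaults are used as-is):
#
#     convert_trt(
#         model_name="SomeModel",  # hypothetical; looked up in TRT_CFG
#         mode="trt_fp16",
#         pp_model_path="/path/to/inference",
#         trt_save_path="/path/to/trt/inference",
#         trt_dynamic_shapes={
#             "x": [[1, 3, 32, 32], [1, 3, 640, 640], [8, 3, 1280, 1280]],
#         },
#     )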


class Copy2GPU:
    """Copies host (CPU) inputs into the predictor's input handles."""

    def __init__(self, input_handlers):
        super().__init__()
        self.input_handlers = input_handlers

    def __call__(self, x):
        for idx in range(len(x)):
            self.input_handlers[idx].reshape(x[idx].shape)
            self.input_handlers[idx].copy_from_cpu(x[idx])


class Copy2CPU:
    """Fetches the predictor's outputs back into host (CPU) memory."""

    def __init__(self, output_handlers):
        super().__init__()
        self.output_handlers = output_handlers

    def __call__(self):
        output = []
        for out_tensor in self.output_handlers:
            batch = out_tensor.copy_to_cpu()
            output.append(batch)
        return output


class Infer:
    """Runs one forward pass of the underlying predictor."""

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self):
        self.predictor.run()
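

# The three callables above are the stages of a single prediction step;
# StaticInfer.__call__ chains them as copy2gpu(x) -> infer() -> copy2cpu().
# Copy2GPU indexes its handlers positionally, so the inputs `x` must be ordered
# to match the predictor's sorted input names (see input_names.sort() in
# StaticInfer._create below). A rough sketch of one step:
#
#     copy2gpu(x)        # stage the numpy inputs on the device
#     infer()            # run the forward pass
#     pred = copy2cpu()  # fetch the outputs back as numpy arrays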


class StaticInfer:
    """Predictor based on Paddle Inference"""

    def __init__(
        self, model_dir: str, model_prefix: str, option: PaddlePredictorOption
    ) -> None:
        super().__init__()
        self.model_dir = model_dir
        self.model_prefix = model_prefix
        self.option = option
        self.option.changed = True
        self._lock = threading.Lock()

    def _reset(self) -> None:
        with self._lock:
            self.option.changed = False
            logging.debug(f"Env: {self.option}")
            (
                predictor,
                input_handlers,
                output_handlers,
            ) = self._create()
            self.copy2gpu = Copy2GPU(input_handlers)
            self.copy2cpu = Copy2CPU(output_handlers)
            self.infer = Infer(predictor)

    def _create(
        self,
    ) -> Tuple[
        "paddle.base.libpaddle.PaddleInferPredictor",
        "paddle.base.libpaddle.PaddleInferTensor",
        "paddle.base.libpaddle.PaddleInferTensor",
    ]:
        """Build the inference Config, create the predictor, and collect its I/O handles."""
        from lazy_paddle.inference import Config, create_predictor

        if FLAGS_json_format_model:
            model_file = (self.model_dir / f"{self.model_prefix}.json").as_posix()
        # when FLAGS_json_format_model is not set, use inference.json if it exists, otherwise inference.pdmodel
        else:
            model_file = self.model_dir / f"{self.model_prefix}.json"
            if model_file.exists():
                model_file = model_file.as_posix()
            # default to the `pdmodel` suffix
            else:
                model_file = (
                    self.model_dir / f"{self.model_prefix}.pdmodel"
                ).as_posix()
        params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
        # for TRT
        if self.option.run_mode.startswith("trt"):
            assert self.option.device == "gpu"
            if not USE_PIR_TRT:
                if self.option.shape_info_filename is None:
                    shape_range_info_path = (
                        self.model_dir / "shape_range_info.pbtxt"
                    ).as_posix()
                else:
                    shape_range_info_path = self.option.shape_info_filename
                if not os.path.exists(shape_range_info_path):
                    logging.info(
                        f"Dynamic shape info is collected into: {shape_range_info_path}"
                    )
                    collect_trt_shapes(
                        model_file,
                        params_file,
                        self.option.device_id,
                        shape_range_info_path,
                        self.option.trt_dynamic_shapes,
                    )
                else:
                    logging.info(
                        f"A dynamic shape info file ({shape_range_info_path}) already exists. No need to collect again."
                    )
                self.option.shape_info_filename = shape_range_info_path
            else:
                trt_save_path = (
                    Path(self.model_dir) / "trt" / self.model_prefix
                ).as_posix()
                pp_model_path = (Path(self.model_dir) / self.model_prefix).as_posix()
                convert_trt(
                    self.option.model_name,
                    self.option.run_mode,
                    pp_model_path,
                    trt_save_path,
                    self.option.trt_dynamic_shapes,
                )
                model_file = trt_save_path + ".json"
                params_file = trt_save_path + ".pdiparams"
        config = Config(model_file, params_file)
        if self.option.device == "gpu":
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self.option.device_id)
            if not self.option.run_mode.startswith("trt"):
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self.option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)
            # NOTE: The pptrt settings are not aligned with those of FD.
            else:
                if not USE_PIR_TRT:
                    precision_map = {
                        "trt_int8": Config.Precision.Int8,
                        "trt_fp32": Config.Precision.Float32,
                        "trt_fp16": Config.Precision.Half,
                    }
                    config.enable_tensorrt_engine(
                        workspace_size=(1 << 30) * self.option.batch_size,
                        max_batch_size=self.option.batch_size,
                        min_subgraph_size=self.option.min_subgraph_size,
                        precision_mode=precision_map[self.option.run_mode],
                        use_static=self.option.trt_use_static,
                        use_calib_mode=self.option.trt_calib_mode,
                    )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        self.option.shape_info_filename, True
                    )
        elif self.option.device == "npu":
            config.enable_custom_device("npu")
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
        elif self.option.device == "xpu":
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
        elif self.option.device == "mlu":
            config.enable_custom_device("mlu")
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
        elif self.option.device == "dcu":
            config.enable_use_gpu(100, self.option.device_id)
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
            # XXX: must is_compiled_with_rocm() be True on the DCU platform?
            if paddle.is_compiled_with_rocm():
                # Delete passes unsupported on DCU
                config.delete_pass("conv2d_add_act_fuse_pass")
                config.delete_pass("conv2d_add_fuse_pass")
        else:
            assert self.option.device == "cpu"
            config.disable_gpu()
            if "mkldnn" in self.option.run_mode:
                try:
                    config.enable_mkldnn()
                    if "bf16" in self.option.run_mode:
                        config.enable_mkldnn_bfloat16()
                except Exception:
                    logging.warning(
                        "MKL-DNN is not available. We will disable MKL-DNN."
                    )
                config.set_mkldnn_cache_capacity(-1)
            else:
                if hasattr(config, "disable_mkldnn"):
                    config.disable_mkldnn()
            config.set_cpu_math_library_num_threads(self.option.cpu_threads)

            if hasattr(config, "enable_new_ir"):
                config.enable_new_ir(self.option.enable_new_ir)
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
            config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self.option.delete_pass:
            config.delete_pass(del_p)
        # Disable paddle inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = create_predictor(config)

        # Get input and output handlers
        input_names = predictor.get_input_names()
        input_names.sort()
        input_handlers = []
        output_handlers = []
        for input_name in input_names:
            input_handler = predictor.get_input_handle(input_name)
            input_handlers.append(input_handler)
        output_names = predictor.get_output_names()
        for output_name in output_names:
            output_handler = predictor.get_output_handle(output_name)
            output_handlers.append(output_handler)
        return predictor, input_handlers, output_handlers

    def __call__(self, x) -> List[Any]:
        if self.option.changed:
            self._reset()
        self.copy2gpu(x)
        self.infer()
        pred = self.copy2cpu()
        return pred

    @property
    def benchmark(self):
        return {
            "Copy2GPU": self.copy2gpu,
            "Infer": self.infer,
            "Copy2CPU": self.copy2cpu,
        }
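

# End-to-end usage sketch (hypothetical; exact PaddlePredictorOption
# construction depends on ...utils.pp_option, and note that model_dir, though
# annotated as `str`, is used as a pathlib.Path by _create):
#
#     option = PaddlePredictorOption()  # assumed default-constructible
#     option.device = "cpu"
#     infer = StaticInfer(Path("/path/to/model_dir"), "inference", option)
#     outputs = infer([np.ones((1, 3, 224, 224), dtype=np.float32)])
#
# The first call triggers _reset() (option.changed starts True), which builds
# the predictor and the Copy2GPU/Infer/Copy2CPU stages; subsequent calls reuse
# them until the option is changed again.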