static_infer.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, List, Any
import os
import threading
from pathlib import Path

import lazy_paddle as paddle
import numpy as np

from ....utils.flags import DEBUG, FLAGS_json_format_model, USE_PIR_TRT
from ....utils import logging
from ...utils.pp_option import PaddlePredictorOption


def collect_trt_shapes(
    model_file, model_params, gpu_id, shape_range_info_path, trt_dynamic_shapes
):
    """Run warm-up inference over the min/opt/max shapes so that Paddle
    Inference records a TensorRT dynamic-shape range file."""
    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in trt_dynamic_shapes.items():
        min_shape, opt_shape, max_shape = candidate_shapes
        min_arrs[name] = np.ones(min_shape, dtype=np.float32)
        opt_arrs[name] = np.ones(opt_shape, dtype=np.float32)
        max_arrs[name] = np.ones(max_shape, dtype=np.float32)
    config.collect_shape_range_info(shape_range_info_path)
    predictor = paddle.inference.create_predictor(config)
    # opt_arrs is fed twice to simulate the most common situation
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            input_handler = predictor.get_input_handle(name)
            input_handler.reshape(arr.shape)
            input_handler.copy_from_cpu(arr)
        predictor.run()
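
# A minimal usage sketch (the input name "x", the file names, and the shapes
# below are hypothetical; each entry maps an input tensor name to its
# [min, opt, max] candidate shapes, matching the unpacking above):
#
#     collect_trt_shapes(
#         "inference.json",
#         "inference.pdiparams",
#         gpu_id=0,
#         shape_range_info_path="shape_range_info.pbtxt",
#         trt_dynamic_shapes={
#             "x": [[1, 3, 224, 224], [1, 3, 640, 640], [8, 3, 1024, 1024]]
#         },
#     )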


def convert_trt(mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
    """Convert a Paddle inference model to TensorRT via the PIR-TRT export
    API, using the given precision mode and dynamic-shape ranges."""
    from lazy_paddle.tensorrt.export import (
        Input,
        TensorRTConfig,
        convert,
        PrecisionMode,
    )

    precision_map = {
        "trt_int8": PrecisionMode.INT8,
        "trt_fp32": PrecisionMode.FP32,
        "trt_fp16": PrecisionMode.FP16,
    }
    trt_inputs = []
    for candidate_shapes in trt_dynamic_shapes.values():
        min_shape, opt_shape, max_shape = candidate_shapes
        trt_input = Input(
            min_input_shape=min_shape,
            optim_input_shape=opt_shape,
            max_input_shape=max_shape,
        )
        trt_inputs.append(trt_input)
    # Create TensorRTConfig
    trt_config = TensorRTConfig(inputs=trt_inputs)
    trt_config.precision_mode = precision_map[mode]
    trt_config.save_model_dir = trt_save_path
    convert(pp_model_path, trt_config)
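
# For example (hypothetical paths; `mode` must be one of the keys of
# `precision_map` above, and `pp_model_path` is the model path without its
# file suffix):
#
#     convert_trt(
#         "trt_fp16",
#         "model_dir/inference",
#         "model_dir/trt/inference",
#         {"x": [[1, 3, 224, 224], [1, 3, 640, 640], [8, 3, 1024, 1024]]},
#     )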


class Copy2GPU:
    """Copy a batch of host (CPU) arrays into the predictor's input handles."""

    def __init__(self, input_handlers):
        super().__init__()
        self.input_handlers = input_handlers

    def __call__(self, x):
        for idx in range(len(x)):
            self.input_handlers[idx].reshape(x[idx].shape)
            self.input_handlers[idx].copy_from_cpu(x[idx])


class Copy2CPU:
    """Copy the predictor's output tensors back to host (CPU) arrays."""

    def __init__(self, output_handlers):
        super().__init__()
        self.output_handlers = output_handlers

    def __call__(self):
        output = []
        for out_tensor in self.output_handlers:
            batch = out_tensor.copy_to_cpu()
            output.append(batch)
        return output


class Infer:
    """Run one forward pass of the underlying predictor."""

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self):
        self.predictor.run()
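
# The three helpers above mirror the stages of one inference step; composed,
# they behave roughly like this sketch (hypothetical handler objects):
#
#     copy2gpu(inputs)      # host -> device
#     infer()               # forward pass
#     outputs = copy2cpu()  # device -> host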


class StaticInfer:
    """Predictor based on Paddle Inference."""

    def __init__(
        self, model_dir: Path, model_prefix: str, option: PaddlePredictorOption
    ) -> None:
        super().__init__()
        self.model_dir = model_dir
        self.model_prefix = model_prefix
        self.option = option
        # Mark the option as changed so that the predictor is built lazily on
        # the first call.
        self.option.changed = True
        self._lock = threading.Lock()

    def _reset(self) -> None:
        with self._lock:
            self.option.changed = False
            logging.debug(f"Env: {self.option}")
            (
                predictor,
                input_handlers,
                output_handlers,
            ) = self._create()
            self.copy2gpu = Copy2GPU(input_handlers)
            self.copy2cpu = Copy2CPU(output_handlers)
            self.infer = Infer(predictor)

    def _create(
        self,
    ) -> Tuple[
        "paddle.base.libpaddle.PaddleInferPredictor",
        "paddle.base.libpaddle.PaddleInferTensor",
        "paddle.base.libpaddle.PaddleInferTensor",
    ]:
        """Create the predictor and its input/output handlers."""
        from lazy_paddle.inference import Config, create_predictor

        if FLAGS_json_format_model:
            model_file = (self.model_dir / f"{self.model_prefix}.json").as_posix()
        else:
            # When FLAGS_json_format_model is not set, use `inference.json` if
            # it exists; otherwise fall back to `inference.pdmodel`.
            model_file = self.model_dir / f"{self.model_prefix}.json"
            if model_file.exists():
                model_file = model_file.as_posix()
            else:
                # Default to the `.pdmodel` suffix.
                model_file = (
                    self.model_dir / f"{self.model_prefix}.pdmodel"
                ).as_posix()
        params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()

        # For TensorRT
        if self.option.run_mode.startswith("trt"):
            assert self.option.device == "gpu"
            if not USE_PIR_TRT:
                if self.option.shape_info_filename is None:
                    shape_range_info_path = (
                        self.model_dir / "shape_range_info.pbtxt"
                    ).as_posix()
                else:
                    shape_range_info_path = self.option.shape_info_filename
                if not os.path.exists(shape_range_info_path):
                    logging.info(
                        f"Dynamic shape info will be collected into: {shape_range_info_path}"
                    )
                    collect_trt_shapes(
                        model_file,
                        params_file,
                        self.option.device_id,
                        shape_range_info_path,
                        self.option.trt_dynamic_shapes,
                    )
                else:
                    logging.info(
                        f"A dynamic shape info file ({shape_range_info_path}) already exists. No need to collect again."
                    )
                self.option.shape_info_filename = shape_range_info_path
            else:
                trt_save_path = (
                    Path(self.model_dir) / "trt" / self.model_prefix
                ).as_posix()
                pp_model_path = (Path(self.model_dir) / self.model_prefix).as_posix()
                convert_trt(
                    self.option.run_mode,
                    pp_model_path,
                    trt_save_path,
                    self.option.trt_dynamic_shapes,
                )
                model_file = trt_save_path + ".json"
                params_file = trt_save_path + ".pdiparams"

        config = Config(model_file, params_file)
        if self.option.device == "gpu":
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self.option.device_id)
            if not self.option.run_mode.startswith("trt"):
                if hasattr(config, "enable_new_ir"):
                    config.enable_new_ir(self.option.enable_new_ir)
                if hasattr(config, "enable_new_executor"):
                    config.enable_new_executor()
                config.set_optimization_level(3)
            else:
                # NOTE: The pptrt settings are not aligned with those of FD.
                if not USE_PIR_TRT:
                    precision_map = {
                        "trt_int8": Config.Precision.Int8,
                        "trt_fp32": Config.Precision.Float32,
                        "trt_fp16": Config.Precision.Half,
                    }
                    config.enable_tensorrt_engine(
                        workspace_size=(1 << 30) * self.option.batch_size,
                        max_batch_size=self.option.batch_size,
                        min_subgraph_size=self.option.min_subgraph_size,
                        precision_mode=precision_map[self.option.run_mode],
                        use_static=self.option.trt_use_static,
                        use_calib_mode=self.option.trt_calib_mode,
                    )
                    config.enable_tuned_tensorrt_dynamic_shape(
                        self.option.shape_info_filename, True
                    )
        elif self.option.device == "npu":
            config.enable_custom_device("npu")
        elif self.option.device == "xpu":
            pass
        elif self.option.device == "mlu":
            config.enable_custom_device("mlu")
        elif self.option.device == "dcu":
            config.enable_use_gpu(100, self.option.device_id)
            # XXX: Is is_compiled_with_rocm() necessarily True on the DCU platform?
            if paddle.is_compiled_with_rocm():
                # Delete passes unsupported on DCU
                config.delete_pass("conv2d_add_act_fuse_pass")
                config.delete_pass("conv2d_add_fuse_pass")
        else:
            assert self.option.device == "cpu"
            config.disable_gpu()
            if "mkldnn" in self.option.run_mode:
                try:
                    config.enable_mkldnn()
                    if "bf16" in self.option.run_mode:
                        config.enable_mkldnn_bfloat16()
                except Exception:
                    logging.warning(
                        "MKL-DNN is not available. We will disable MKL-DNN."
                    )
                config.set_mkldnn_cache_capacity(-1)
            else:
                if hasattr(config, "disable_mkldnn"):
                    config.disable_mkldnn()
            config.set_cpu_math_library_num_threads(self.option.cpu_threads)
            if hasattr(config, "enable_new_ir"):
                config.enable_new_ir(self.option.enable_new_ir)
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
            config.set_optimization_level(3)

        config.enable_memory_optim()
        for del_p in self.option.delete_pass:
            config.delete_pass(del_p)

        # Disable Paddle Inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = create_predictor(config)

        # Get input and output handlers
        input_names = predictor.get_input_names()
        input_names.sort()
        input_handlers = []
        output_handlers = []
        for input_name in input_names:
            input_handler = predictor.get_input_handle(input_name)
            input_handlers.append(input_handler)
        output_names = predictor.get_output_names()
        for output_name in output_names:
            output_handler = predictor.get_output_handle(output_name)
            output_handlers.append(output_handler)
        return predictor, input_handlers, output_handlers

    def __call__(self, x) -> List[Any]:
        # Rebuild the predictor if the option has been changed since the last
        # run.
        if self.option.changed:
            self._reset()
        self.copy2gpu(x)
        self.infer()
        pred = self.copy2cpu()
        return pred

    @property
    def benchmark(self):
        # Expose the three stages individually, presumably so that callers
        # can time each stage separately.
        return {
            "Copy2GPU": self.copy2gpu,
            "Infer": self.infer,
            "Copy2CPU": self.copy2cpu,
        }
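

# A minimal usage sketch (hypothetical path and option setup; within PaddleX
# this class is normally constructed by higher-level predictors, and
# PaddlePredictorOption may require different arguments):
#
#     option = PaddlePredictorOption()
#     predictor = StaticInfer(Path("/path/to/model_dir"), "inference", option)
#     outputs = predictor([np.ones((1, 3, 224, 224), dtype=np.float32)])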