static_infer.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Tuple, List, Any
import os
import shutil
from pathlib import Path

import lazy_paddle as paddle
import numpy as np

from ....utils.flags import DEBUG, FLAGS_json_format_model
from ....utils import logging
from ...utils.pp_option import PaddlePredictorOption


def collect_trt_shapes(
    model_file, model_params, gpu_id, shape_range_info_path, trt_dynamic_shapes
):
    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    min_arrs, opt_arrs, max_arrs = {}, {}, {}
    for name, candidate_shapes in trt_dynamic_shapes.items():
        min_shape, opt_shape, max_shape = candidate_shapes
        min_arrs[name] = np.ones(min_shape, dtype=np.float32)
        opt_arrs[name] = np.ones(opt_shape, dtype=np.float32)
        max_arrs[name] = np.ones(max_shape, dtype=np.float32)

    config.collect_shape_range_info(shape_range_info_path)
    predictor = paddle.inference.create_predictor(config)
    # opt_arrs is fed twice to simulate the most common situations
    for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
        for name, arr in arrs.items():
            input_handler = predictor.get_input_handle(name)
            input_handler.reshape(arr.shape)
            input_handler.copy_from_cpu(arr)
        predictor.run()
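
# `trt_dynamic_shapes` (used above and in `convert_trt` below) maps each model
# input name to three candidate shapes [min_shape, opt_shape, max_shape]. An
# illustrative value (the input name "x" and the shapes are assumptions, not
# taken from a real model):
#
#     {"x": [[1, 3, 224, 224], [1, 3, 320, 320], [8, 3, 640, 640]]}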


def convert_trt(mode, pp_model_path, trt_dynamic_shapes):
    from lazy_paddle.tensorrt.export import (
        Input,
        TensorRTConfig,
        convert,
        PrecisionMode,
    )

    trt_save_dir = str(Path(pp_model_path) / "trt" / "inference")
    precision_map = {
        "trt_int8": PrecisionMode.INT8,
        "trt_fp32": PrecisionMode.FP32,
        "trt_fp16": PrecisionMode.FP16,
    }
    trt_inputs = []
    for name, candidate_shapes in trt_dynamic_shapes.items():
        min_shape, opt_shape, max_shape = candidate_shapes
        trt_input = Input(
            min_input_shape=min_shape,
            optim_input_shape=opt_shape,
            max_input_shape=max_shape,
        )
        trt_inputs.append(trt_input)

    # Create the TensorRT config and convert the model
    trt_config = TensorRTConfig(inputs=trt_inputs)
    trt_config.precision_mode = precision_map[mode]
    trt_config.save_model_dir = trt_save_dir
    convert(str(Path(pp_model_path) / "inference"), trt_config)

    # Copy inference.yml alongside the new model files
    shutil.copy(str(Path(pp_model_path) / "inference.yml"), trt_save_dir + ".yml")
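
# Illustrative call (the path and shapes are assumptions, not from a real model):
#
#     convert_trt(
#         "trt_fp16",
#         "path/to/model_dir",
#         {"x": [[1, 3, 224, 224], [1, 3, 320, 320], [8, 3, 640, 640]]},
#     )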


class Copy2GPU:

    def __init__(self, input_handlers):
        super().__init__()
        self.input_handlers = input_handlers

    def __call__(self, x):
        for idx in range(len(x)):
            self.input_handlers[idx].reshape(x[idx].shape)
            self.input_handlers[idx].copy_from_cpu(x[idx])


class Copy2CPU:

    def __init__(self, output_handlers):
        super().__init__()
        self.output_handlers = output_handlers

    def __call__(self):
        output = []
        for out_tensor in self.output_handlers:
            batch = out_tensor.copy_to_cpu()
            output.append(batch)
        return output


class Infer:

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def __call__(self):
        self.predictor.run()
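
# These three thin wrappers keep the host-to-device copy, the graph execution,
# and the device-to-host copy as separate callables, which is what lets the
# `StaticInfer.benchmark` property below expose each stage individually.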


class StaticInfer:
    """Predictor based on Paddle Inference"""

    def __init__(
        self,
        model_dir: Union[str, Path],
        model_prefix: str,
        option: PaddlePredictorOption,
    ) -> None:
        super().__init__()
        # Coerce to Path, since the model files are located via `/` below
        self.model_dir = Path(model_dir)
        self.model_prefix = model_prefix
        self._update_option(option)

    def _update_option(self, option: PaddlePredictorOption) -> None:
        if self.option and option == self.option:
            return
        self._option = option
        self._reset()

    @property
    def option(self) -> PaddlePredictorOption:
        return self._option if hasattr(self, "_option") else None

    @option.setter
    def option(self, option: Union[None, PaddlePredictorOption]) -> None:
        if option:
            self._update_option(option)

    def _reset(self) -> None:
        logging.debug(f"Env: {self.option}")
        (
            predictor,
            input_handlers,
            output_handlers,
        ) = self._create()
        self.copy2gpu = Copy2GPU(input_handlers)
        self.copy2cpu = Copy2CPU(output_handlers)
        self.infer = Infer(predictor)
        self.option.changed = False

    def _create(
        self,
    ) -> Tuple[
        "paddle.base.libpaddle.PaddleInferPredictor",
        List["paddle.base.libpaddle.PaddleInferTensor"],
        List["paddle.base.libpaddle.PaddleInferTensor"],
    ]:
        """Create the predictor and its input/output handlers."""
        from lazy_paddle.inference import Config, create_predictor

        if FLAGS_json_format_model:
            model_file = (self.model_dir / f"{self.model_prefix}.json").as_posix()
        else:
            # When FLAGS_json_format_model is not set, use the `.json` model
            # file if it exists; otherwise fall back to the `.pdmodel` file.
            model_file = self.model_dir / f"{self.model_prefix}.json"
            if model_file.exists():
                model_file = model_file.as_posix()
            else:
                model_file = (
                    self.model_dir / f"{self.model_prefix}.pdmodel"
                ).as_posix()
        params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
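        # (Editorial note: with e.g. `model_prefix="inference"`, the directory
        # is thus expected to contain `inference.json` or `inference.pdmodel`
        # together with `inference.pdiparams`.)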
        config = Config(model_file, params_file)

        if self.option.device == "gpu":
            config.exp_disable_mixed_precision_ops({"feed", "fetch"})
            config.enable_use_gpu(100, self.option.device_id)
            if hasattr(config, "enable_new_ir"):
                config.enable_new_ir(self.option.enable_new_ir)
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
            config.set_optimization_level(3)
            # NOTE: The pptrt settings are not aligned with those of FD.
            precision_map = {
                "trt_int8": Config.Precision.Int8,
                "trt_fp32": Config.Precision.Float32,
                "trt_fp16": Config.Precision.Half,
            }
            if self.option.run_mode in precision_map.keys():
                config.enable_tensorrt_engine(
                    workspace_size=(1 << 25) * self.option.batch_size,
                    max_batch_size=self.option.batch_size,
                    min_subgraph_size=self.option.min_subgraph_size,
                    precision_mode=precision_map[self.option.run_mode],
                    use_static=self.option.trt_use_static,
                    use_calib_mode=self.option.trt_calib_mode,
                )

                if not os.path.exists(self.option.shape_info_filename):
                    logging.info(
                        f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
                    )
                    collect_trt_shapes(
                        model_file,
                        params_file,
                        self.option.device_id,
                        self.option.shape_info_filename,
                        self.option.trt_dynamic_shapes,
                    )
                else:
                    logging.info(
                        f"A dynamic shape info file ({self.option.shape_info_filename}) already exists. No need to collect again."
                    )
                config.enable_tuned_tensorrt_dynamic_shape(
                    self.option.shape_info_filename, True
                )
        elif self.option.device == "npu":
            config.enable_custom_device("npu")
        elif self.option.device == "xpu":
            pass
        elif self.option.device == "mlu":
            config.enable_custom_device("mlu")
        elif self.option.device == "dcu":
            if paddle.is_compiled_with_rocm():
                # Delete passes unsupported on DCU
                config.delete_pass("conv2d_add_act_fuse_pass")
                config.delete_pass("conv2d_add_fuse_pass")
        else:
            assert self.option.device == "cpu"
            config.disable_gpu()
            if "mkldnn" in self.option.run_mode:
                try:
                    config.enable_mkldnn()
                    if "bf16" in self.option.run_mode:
                        config.enable_mkldnn_bfloat16()
                except Exception:
                    logging.warning(
                        "MKL-DNN is not available. We will disable MKL-DNN."
                    )
                config.set_mkldnn_cache_capacity(-1)
            else:
                if hasattr(config, "disable_mkldnn"):
                    config.disable_mkldnn()
            config.set_cpu_math_library_num_threads(self.option.cpu_threads)

            if hasattr(config, "enable_new_ir"):
                config.enable_new_ir(self.option.enable_new_ir)
            if hasattr(config, "enable_new_executor"):
                config.enable_new_executor()
            config.set_optimization_level(3)
        config.enable_memory_optim()
        for del_p in self.option.delete_pass:
            config.delete_pass(del_p)

        # Disable Paddle Inference logging
        if not DEBUG:
            config.disable_glog_info()

        predictor = create_predictor(config)

        # Get input and output handlers
        input_names = predictor.get_input_names()
        input_names.sort()
        input_handlers = []
        output_handlers = []
        for input_name in input_names:
            input_handler = predictor.get_input_handle(input_name)
            input_handlers.append(input_handler)
        output_names = predictor.get_output_names()
        for output_name in output_names:
            output_handler = predictor.get_output_handle(output_name)
            output_handlers.append(output_handler)
        return predictor, input_handlers, output_handlers
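
    # Note: the input list passed to `__call__` is matched to the model's
    # input names by position, in alphabetically sorted name order (see
    # `input_names.sort()` in `_create` above).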
    def __call__(self, x) -> List[Any]:
        if self.option.changed:
            self._reset()
        self.copy2gpu(x)
        self.infer()
        pred = self.copy2cpu()
        return pred

    @property
    def benchmark(self):
        return {
            "Copy2GPU": self.copy2gpu,
            "Infer": self.infer,
            "Copy2CPU": self.copy2cpu,
        }
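

# Minimal usage sketch (illustrative only: the path and option fields are
# assumptions, and the relative imports above mean this module is not meant
# to be run directly):
#
#     option = PaddlePredictorOption()
#     option.device = "cpu"
#     predictor = StaticInfer("path/to/model_dir", "inference", option)
#     outputs = predictor([np.ones((1, 3, 224, 224), dtype=np.float32)])
#     # -> list with one numpy array per model output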