|
|
@@ -14,12 +14,12 @@
|
|
|
|
|
|
from typing import Union, Tuple, List, Dict, Any, Iterator
|
|
|
import os
|
|
|
-import inspect
|
|
|
-from abc import abstractmethod
|
|
|
+import shutil
|
|
|
+from pathlib import Path
|
|
|
import lazy_paddle as paddle
|
|
|
import numpy as np
|
|
|
|
|
|
-from ....utils.flags import FLAGS_json_format_model
|
|
|
+from ....utils.flags import DEBUG, FLAGS_json_format_model
|
|
|
from ....utils import logging
|
|
|
from ...utils.pp_option import PaddlePredictorOption
|
|
|
|
|
|
@@ -47,6 +47,40 @@ def collect_trt_shapes(
|
|
|
predictor.run()
|
|
|
|
|
|
|
|
|
+def convert_trt(mode, pp_model_path, trt_dynamic_shapes):
|
|
|
+ from lazy_paddle.tensorrt.export import (
|
|
|
+ Input,
|
|
|
+ TensorRTConfig,
|
|
|
+ convert,
|
|
|
+ PrecisionMode,
|
|
|
+ )
|
|
|
+
|
|
|
+ trt_save_dir = str(Path(pp_model_path) / "trt" / "inference")
|
|
|
+
|
|
|
+ precision_map = {
|
|
|
+ "trt_int8": PrecisionMode.INT8,
|
|
|
+ "trt_fp32": PrecisionMode.FP32,
|
|
|
+ "trt_fp16": PrecisionMode.FP16,
|
|
|
+ }
|
|
|
+ trt_inputs = []
|
|
|
+ for name, candidate_shapes in trt_dynamic_shapes.items():
|
|
|
+ min_shape, opt_shape, max_shape = candidate_shapes
|
|
|
+ trt_input = Input(
|
|
|
+ min_input_shape=min_shape,
|
|
|
+ optim_input_shape=opt_shape,
|
|
|
+ max_input_shape=max_shape,
|
|
|
+ )
|
|
|
+ trt_inputs.append(trt_input)
|
|
|
+
|
|
|
+ # Create TensorRTConfig
|
|
|
+ trt_config = TensorRTConfig(inputs=trt_inputs)
|
|
|
+ trt_config.precision_mode = precision_map[mode]
|
|
|
+ trt_config.save_model_dir = trt_save_dir
|
|
|
+ convert(str(Path(pp_model_path) / "inference"), trt_config)
|
|
|
+ # copy inference.yaml to new model dir
|
|
|
+ shutil.copy(str(Path(pp_model_path) / "inference.yml"), trt_save_dir + ".yml")
|
|
|
+
|
|
|
+
|
|
|
class Copy2GPU:
|
|
|
|
|
|
def __init__(self, input_handlers):
|
|
|
@@ -144,48 +178,53 @@ class StaticInfer:
|
|
|
self.model_dir / f"{self.model_prefix}.pdmodel"
|
|
|
).as_posix()
|
|
|
params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
|
|
|
- config = Config(model_file, params_file)
|
|
|
|
|
|
- config.enable_memory_optim()
|
|
|
- if self.option.device in ("gpu", "dcu"):
|
|
|
+ config = Config(model_file, params_file)
|
|
|
+ if self.option.device == "gpu":
|
|
|
if self.option.device == "gpu":
|
|
|
config.exp_disable_mixed_precision_ops({"feed", "fetch"})
|
|
|
config.enable_use_gpu(100, self.option.device_id)
|
|
|
- if self.option.device == "gpu":
|
|
|
- # NOTE: The pptrt settings are not aligned with those of FD.
|
|
|
- precision_map = {
|
|
|
- "trt_int8": Config.Precision.Int8,
|
|
|
- "trt_fp32": Config.Precision.Float32,
|
|
|
- "trt_fp16": Config.Precision.Half,
|
|
|
- }
|
|
|
- if self.option.run_mode in precision_map.keys():
|
|
|
- config.enable_tensorrt_engine(
|
|
|
- workspace_size=(1 << 25) * self.option.batch_size,
|
|
|
- max_batch_size=self.option.batch_size,
|
|
|
- min_subgraph_size=self.option.min_subgraph_size,
|
|
|
- precision_mode=precision_map[self.option.run_mode],
|
|
|
- use_static=self.option.trt_use_static,
|
|
|
- use_calib_mode=self.option.trt_calib_mode,
|
|
|
- )
|
|
|
|
|
|
- if not os.path.exists(self.option.shape_info_filename):
|
|
|
- logging.info(
|
|
|
- f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
|
|
|
- )
|
|
|
- collect_trt_shapes(
|
|
|
- model_file,
|
|
|
- params_file,
|
|
|
- self.option.device_id,
|
|
|
- self.option.shape_info_filename,
|
|
|
- self.option.trt_dynamic_shapes,
|
|
|
- )
|
|
|
- else:
|
|
|
- logging.info(
|
|
|
- f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
|
|
|
- )
|
|
|
- config.enable_tuned_tensorrt_dynamic_shape(
|
|
|
- self.option.shape_info_filename, True
|
|
|
+ if hasattr(config, "enable_new_ir"):
|
|
|
+ config.enable_new_ir(self.option.enable_new_ir)
|
|
|
+ if hasattr(config, "enable_new_executor"):
|
|
|
+ config.enable_new_executor()
|
|
|
+ config.set_optimization_level(3)
|
|
|
+
|
|
|
+ # NOTE: The pptrt settings are not aligned with those of FD.
|
|
|
+ precision_map = {
|
|
|
+ "trt_int8": Config.Precision.Int8,
|
|
|
+ "trt_fp32": Config.Precision.Float32,
|
|
|
+ "trt_fp16": Config.Precision.Half,
|
|
|
+ }
|
|
|
+ if self.option.run_mode in precision_map.keys():
|
|
|
+ config.enable_tensorrt_engine(
|
|
|
+ workspace_size=(1 << 25) * self.option.batch_size,
|
|
|
+ max_batch_size=self.option.batch_size,
|
|
|
+ min_subgraph_size=self.option.min_subgraph_size,
|
|
|
+ precision_mode=precision_map[self.option.run_mode],
|
|
|
+ use_static=self.option.trt_use_static,
|
|
|
+ use_calib_mode=self.option.trt_calib_mode,
|
|
|
+ )
|
|
|
+
|
|
|
+ if not os.path.exists(self.option.shape_info_filename):
|
|
|
+ logging.info(
|
|
|
+ f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
|
|
|
+ )
|
|
|
+ collect_trt_shapes(
|
|
|
+ model_file,
|
|
|
+ params_file,
|
|
|
+ self.option.device_id,
|
|
|
+ self.option.shape_info_filename,
|
|
|
+ self.option.trt_dynamic_shapes,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ logging.info(
|
|
|
+ f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
|
|
|
)
|
|
|
+ config.enable_tuned_tensorrt_dynamic_shape(
|
|
|
+ self.option.shape_info_filename, True
|
|
|
+ )
|
|
|
|
|
|
elif self.option.device == "npu":
|
|
|
config.enable_custom_device("npu")
|
|
|
@@ -193,6 +232,11 @@ class StaticInfer:
|
|
|
pass
|
|
|
elif self.option.device == "mlu":
|
|
|
config.enable_custom_device("mlu")
|
|
|
+ elif self.option.device == "dcu":
|
|
|
+ if paddle.is_compiled_with_rocm():
|
|
|
+ # Delete unsupported passes in dcu
|
|
|
+ config.delete_pass("conv2d_add_act_fuse_pass")
|
|
|
+ config.delete_pass("conv2d_add_fuse_pass")
|
|
|
else:
|
|
|
assert self.option.device == "cpu"
|
|
|
config.disable_gpu()
|
|
|
@@ -209,30 +253,21 @@ class StaticInfer:
|
|
|
else:
|
|
|
if hasattr(config, "disable_mkldnn"):
|
|
|
config.disable_mkldnn()
|
|
|
+ config.set_cpu_math_library_num_threads(self.option.cpu_threads)
|
|
|
|
|
|
- # Disable paddle inference logging
|
|
|
- config.disable_glog_info()
|
|
|
-
|
|
|
- config.set_cpu_math_library_num_threads(self.option.cpu_threads)
|
|
|
-
|
|
|
- if self.option.device in ("cpu", "gpu"):
|
|
|
- if not (
|
|
|
- self.option.device == "gpu" and self.option.run_mode.startswith("trt")
|
|
|
- ):
|
|
|
- if hasattr(config, "enable_new_ir"):
|
|
|
- config.enable_new_ir(self.option.enable_new_ir)
|
|
|
- if hasattr(config, "enable_new_executor"):
|
|
|
- config.enable_new_executor()
|
|
|
- config.set_optimization_level(3)
|
|
|
+ if hasattr(config, "enable_new_ir"):
|
|
|
+ config.enable_new_ir(self.option.enable_new_ir)
|
|
|
+ if hasattr(config, "enable_new_executor"):
|
|
|
+ config.enable_new_executor()
|
|
|
+ config.set_optimization_level(3)
|
|
|
|
|
|
+ config.enable_memory_optim()
|
|
|
for del_p in self.option.delete_pass:
|
|
|
config.delete_pass(del_p)
|
|
|
|
|
|
- if self.option.device in ("gpu", "dcu"):
|
|
|
- if paddle.is_compiled_with_rocm():
|
|
|
- # Delete unsupported passes in dcu
|
|
|
- config.delete_pass("conv2d_add_act_fuse_pass")
|
|
|
- config.delete_pass("conv2d_add_fuse_pass")
|
|
|
+ # Disable paddle inference logging
|
|
|
+ if not DEBUG:
|
|
|
+ config.disable_glog_info()
|
|
|
|
|
|
predictor = create_predictor(config)
|
|
|
|