@@ -23,8 +23,10 @@ import numpy as np
 from ....utils.flags import DEBUG, FLAGS_json_format_model, USE_PIR_TRT
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
+from ...utils.trt_config import TRT_CFG


+# old trt
 def collect_trt_shapes(
     model_file, model_params, gpu_id, shape_range_info_path, trt_dynamic_shapes
 ):
@@ -48,7 +50,15 @@ def collect_trt_shapes(
         predictor.run()


-def convert_trt(mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
+# pir trt
+def convert_trt(model_name, mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
+    def _set_trt_config():
+        if settings := TRT_CFG.get(model_name):
+            for attr_name in settings:
+                if not hasattr(trt_config, attr_name):
+                    logging.warning(f"The TensorRTConfig doesn't have the attribute `{attr_name}`!")
+                setattr(trt_config, attr_name, settings[attr_name])
+
     from lazy_paddle.tensorrt.export import (
         Input,
         TensorRTConfig,
@@ -73,6 +83,7 @@ def convert_trt(mode, pp_model_path, trt_save_path, trt_dynamic_shapes):

     # Create TensorRTConfig
     trt_config = TensorRTConfig(inputs=trt_inputs)
+    _set_trt_config()
     trt_config.precision_mode = precision_map[mode]
     trt_config.save_model_dir = trt_save_path
     convert(pp_model_path, trt_config)
@@ -197,6 +208,7 @@ class StaticInfer:
             ).as_posix()
             pp_model_path = (Path(self.model_dir) / self.model_prefix).as_posix()
             convert_trt(
+                self.option.model_name,
                 self.option.run_mode,
                 pp_model_path,
                 trt_save_path,
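
Below is a minimal, self-contained sketch (not part of the patch) of the per-model override mechanism that _set_trt_config implements: TRT_CFG maps a model name to attribute overrides that are copied onto the TensorRTConfig via setattr before conversion. The TRT_CFG entry, the model name, and the TensorRTConfig stub used here are illustrative assumptions; the real definitions live in ...utils/trt_config.py and lazy_paddle.tensorrt.export.

from dataclasses import dataclass, field


@dataclass
class TensorRTConfig:
    """Stand-in for lazy_paddle.tensorrt.export.TensorRTConfig (stub for illustration)."""

    inputs: list = field(default_factory=list)
    precision_mode: str = "fp32"
    save_model_dir: str = ""
    workspace_size: int = 1 << 30


# Hypothetical per-model overrides keyed by model name (assumed structure).
TRT_CFG = {
    "SomeModel": {"workspace_size": 1 << 32},
}


def apply_trt_overrides(trt_config, model_name):
    """Mirror _set_trt_config: copy any per-model settings onto trt_config."""
    if settings := TRT_CFG.get(model_name):
        for attr_name, value in settings.items():
            if not hasattr(trt_config, attr_name):
                # The patch warns but still sets the attribute; this sketch keeps that behavior.
                print(f"TensorRTConfig has no attribute `{attr_name}`; setting it anyway")
            setattr(trt_config, attr_name, value)


trt_config = TensorRTConfig()
apply_trt_overrides(trt_config, "SomeModel")
print(trt_config.workspace_size)  # 4294967296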