Browse Source

When using PIR TRT, GPU use must be enabled in _convert_trt

gaotingquan 8 months ago
parent
commit
976eeb533f
1 changed file with 3 additions and 1 deletion
  1. 3 1
      paddlex/inference/models/common/static_infer.py

+ 3 - 1
paddlex/inference/models/common/static_infer.py

@@ -142,6 +142,7 @@ def _convert_trt(
     pp_model_file,
     pp_params_file,
     trt_save_path,
+    device_id,
     dynamic_shapes,
     dynamic_shape_input_data,
 ):
@@ -161,6 +162,7 @@ def _convert_trt(
     def _get_predictor(model_file, params_file):
         # HACK
         config = lazy_paddle.inference.Config(str(model_file), str(params_file))
+        config.enable_use_gpu(100, device_id)
         # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
         config.disable_mkldnn()
         config.disable_glog_info()
@@ -473,6 +475,7 @@ class StaticInfer(object):
                 model_file,
                 params_file,
                 trt_save_path,
+                self._option.device_id,
                 self._option.trt_dynamic_shapes,
                 self._option.trt_dynamic_shape_input_data,
             )
@@ -483,7 +486,6 @@ class StaticInfer(object):
             config = lazy_paddle.inference.Config(str(model_file), str(params_file))
 
             config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
-            config.enable_use_gpu(100, self._option.device_id)
             for func_name in self._option.trt_cfg_setting:
                 assert hasattr(
                     config, func_name