瀏覽代碼

support trt_dynamic_shape_input_data

copy from #2860
gaotingquan 9 月之前
父節點
當前提交
0ea7117782
共有 1 個文件被更改,包括 9 次插入0 次删除
  1. 9 0
      paddlex/inference/models/base/predictor/basic_predictor.py

+ 9 - 0
paddlex/inference/models/base/predictor/basic_predictor.py

@@ -65,6 +65,15 @@ class BasicPredictor(
         )
         if trt_dynamic_shapes:
             pp_option.trt_dynamic_shapes = trt_dynamic_shapes
+        trt_dynamic_shape_input_data = (
+            self.config.get("Hpi", {})
+            .get("backend_configs", {})
+            .get("paddle_infer", {})
+            .get("trt_dynamic_shape_input_data", None)
+        )
+        if trt_dynamic_shape_input_data:
+            pp_option.trt_dynamic_shape_input_data = trt_dynamic_shape_input_data
+
         self.pp_option = pp_option
         self.pp_option.batch_size = batch_size
         self.batch_sampler.batch_size = batch_size