Browse Source

support trt_dynamic_shape_input_data for pir trt

gaotingquan 8 months ago
parent
commit
bf87804450
1 changed file with 47 additions and 23 deletions
  1. 47 23
      paddlex/inference/models/common/static_infer.py

+ 47 - 23
paddlex/inference/models/common/static_infer.py

@@ -133,15 +133,9 @@ def _convert_trt(
     pp_model_file,
     pp_params_file,
     trt_save_path,
-    trt_dynamic_shapes,
+    dynamic_shapes,
+    dynamic_shape_input_data,
 ):
-    def _set_trt_config():
-        if settings := TRT_CFG.get(model_name):
-            for attr_name in settings:
-                if not hasattr(trt_config, attr_name):
-                    logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
-                setattr(trt_config, attr_name, settings[attr_name])
-
     from lazy_paddle.tensorrt.export import (
         Input,
         TensorRTConfig,
@@ -149,25 +143,37 @@ def _convert_trt(
         PrecisionMode,
     )
 
-    def _get_input_names(model_file, params_file):
+    def _set_trt_config():
+        if settings := TRT_CFG.get(model_name):
+            for attr_name in settings:
+                if not hasattr(trt_config, attr_name):
+                    logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
+                setattr(trt_config, attr_name, settings[attr_name])
+
+    def _get_predictor(model_file, params_file):
         # HACK
         config = lazy_paddle.inference.Config(str(model_file), str(params_file))
         # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
         config.disable_mkldnn()
         config.disable_glog_info()
-        predictor = lazy_paddle.inference.create_predictor(config)
-        return predictor.get_input_names()
+        return lazy_paddle.inference.create_predictor(config)
 
-    input_names = _get_input_names(pp_model_file, pp_params_file)
-    for name in trt_dynamic_shapes:
+    dynamic_shape_input_data = dynamic_shape_input_data or {}
+
+    predictor = _get_predictor(pp_model_file, pp_params_file)
+    input_names = predictor.get_input_names()
+    for name in dynamic_shapes:
         if name not in input_names:
             raise ValueError(
-                f"Invalid input name {repr(name)} found in `trt_dynamic_shapes`"
+                f"Invalid input name {repr(name)} found in `dynamic_shapes`"
             )
     for name in input_names:
-        if name not in trt_dynamic_shapes:
+        if name not in dynamic_shapes:
+            raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
+    for name in dynamic_shape_input_data:
+        if name not in input_names:
             raise ValueError(
-                f"Input name {repr(name)} not found in `trt_dynamic_shapes`"
+                f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
             )
 
     precision_map = {
@@ -176,13 +182,30 @@ def _convert_trt(
         "trt_fp16": PrecisionMode.FP16,
     }
     trt_inputs = []
-    for name in input_names:
-        min_shape, opt_shape, max_shape = trt_dynamic_shapes[name]
-        trt_input = Input(
-            min_input_shape=min_shape,
-            optim_input_shape=opt_shape,
-            max_input_shape=max_shape,
-        )
+    # Iterate `dynamic_shapes` (not `input_names`); the checks above ensure both hold the same names.
+    for name, candidate_shapes in dynamic_shapes.items():
+        # XXX: Currently we have no way to get the data type of the tensor
+        # without creating an input handle.
+        handle = predictor.get_input_handle(name)
+        dtype = _pd_dtype_to_np_dtype(handle.type())
+        min_shape, opt_shape, max_shape = candidate_shapes
+        if name in dynamic_shape_input_data:
+            min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
+                min_shape
+            )
+            opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
+                opt_shape
+            )
+            max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
+                max_shape
+            )
+        else:
+            min_arr = np.ones(min_shape, dtype=dtype)
+            opt_arr = np.ones(opt_shape, dtype=dtype)
+            max_arr = np.ones(max_shape, dtype=dtype)
+
+        # refer to: https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
+        trt_input = Input((min_arr, opt_arr, max_arr))
         trt_inputs.append(trt_input)
 
     # Create TensorRTConfig
@@ -428,6 +451,7 @@ class StaticInfer(object):
                 params_file,
                 trt_save_path,
                 self._option.trt_dynamic_shapes,
+                self._option.trt_dynamic_shape_input_data,
             )
             model_file = trt_save_path.with_suffix(".json")
             params_file = trt_save_path.with_suffix(".pdiparams")