|
|
@@ -133,15 +133,9 @@ def _convert_trt(
|
|
|
pp_model_file,
|
|
|
pp_params_file,
|
|
|
trt_save_path,
|
|
|
- trt_dynamic_shapes,
|
|
|
+ dynamic_shapes,
|
|
|
+ dynamic_shape_input_data,
|
|
|
):
|
|
|
- def _set_trt_config():
|
|
|
- if settings := TRT_CFG.get(model_name):
|
|
|
- for attr_name in settings:
|
|
|
- if not hasattr(trt_config, attr_name):
|
|
|
- logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
|
|
|
- setattr(trt_config, attr_name, settings[attr_name])
|
|
|
-
|
|
|
from lazy_paddle.tensorrt.export import (
|
|
|
Input,
|
|
|
TensorRTConfig,
|
|
|
@@ -149,25 +143,37 @@ def _convert_trt(
|
|
|
PrecisionMode,
|
|
|
)
|
|
|
|
|
|
- def _get_input_names(model_file, params_file):
|
|
|
def _set_trt_config():
    """Apply the per-model entries of ``TRT_CFG`` to ``trt_config``.

    Looks up ``model_name`` in ``TRT_CFG``; when the lookup yields nothing
    (or a falsy value) the function returns without touching ``trt_config``.
    Otherwise every entry is written onto ``trt_config`` with ``setattr``,
    and a warning is logged for each attribute name that ``trt_config``
    does not already expose.
    """
    overrides = TRT_CFG.get(model_name)
    if not overrides:
        # No model-specific settings registered: leave the config untouched.
        return

    for attr_name in overrides:
        already_present = hasattr(trt_config, attr_name)
        if not already_present:
            logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
        # NOTE(review): the value is assigned even when the attribute was
        # missing (the warning is informational only) -- confirm this matches
        # the intended nesting of the original statement.
        setattr(trt_config, attr_name, overrides[attr_name])
|
|
|
+
|
|
|
def _get_predictor(model_file, params_file):
    """Create a Paddle Inference predictor for the given model files.

    Args:
        model_file: Location of the model file; passed through ``str()``.
        params_file: Location of the parameters file; passed through ``str()``.

    Returns:
        The predictor object produced by
        ``lazy_paddle.inference.create_predictor``.
    """
    # HACK: a full predictor is instantiated so that the caller can query
    # the model's inputs from it.
    model_path = str(model_file)
    params_path = str(params_file)
    infer_config = lazy_paddle.inference.Config(model_path, params_path)

    # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
    infer_config.disable_mkldnn()
    infer_config.disable_glog_info()

    predictor = lazy_paddle.inference.create_predictor(infer_config)
    return predictor
|
|
|
|
|
|
- input_names = _get_input_names(pp_model_file, pp_params_file)
|
|
|
- for name in trt_dynamic_shapes:
|
|
|
+ dynamic_shape_input_data = dynamic_shape_input_data or {}
|
|
|
+
|
|
|
+ predictor = _get_predictor(pp_model_file, pp_params_file)
|
|
|
+ input_names = predictor.get_input_names()
|
|
|
+ for name in dynamic_shapes:
|
|
|
if name not in input_names:
|
|
|
raise ValueError(
|
|
|
- f"Invalid input name {repr(name)} found in `trt_dynamic_shapes`"
|
|
|
+ f"Invalid input name {repr(name)} found in `dynamic_shapes`"
|
|
|
)
|
|
|
for name in input_names:
|
|
|
- if name not in trt_dynamic_shapes:
|
|
|
+ if name not in dynamic_shapes:
|
|
|
+ raise ValueError(f"Input name {repr(name)} not found in `dynamic_shapes`")
|
|
|
+ for name in dynamic_shape_input_data:
|
|
|
+ if name not in input_names:
|
|
|
raise ValueError(
|
|
|
- f"Input name {repr(name)} not found in `trt_dynamic_shapes`"
|
|
|
+ f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
|
|
|
)
|
|
|
|
|
|
precision_map = {
|
|
|
@@ -176,13 +182,30 @@ def _convert_trt(
|
|
|
"trt_fp16": PrecisionMode.FP16,
|
|
|
}
|
|
|
trt_inputs = []
|
|
|
- for name in input_names:
|
|
|
- min_shape, opt_shape, max_shape = trt_dynamic_shapes[name]
|
|
|
- trt_input = Input(
|
|
|
- min_input_shape=min_shape,
|
|
|
- optim_input_shape=opt_shape,
|
|
|
- max_input_shape=max_shape,
|
|
|
- )
|
|
|
+ # for name in input_names:
|
|
|
+ for name, candidate_shapes in dynamic_shapes.items():
|
|
|
+ # XXX: Currently we have no way to get the data type of the tensor
|
|
|
+ # without creating an input handle.
|
|
|
+ handle = predictor.get_input_handle(name)
|
|
|
+ dtype = _pd_dtype_to_np_dtype(handle.type())
|
|
|
+ min_shape, opt_shape, max_shape = candidate_shapes
|
|
|
+ if name in dynamic_shape_input_data:
|
|
|
+ min_arr = np.array(dynamic_shape_input_data[name][0], dtype=dtype).reshape(
|
|
|
+ min_shape
|
|
|
+ )
|
|
|
+ opt_arr = np.array(dynamic_shape_input_data[name][1], dtype=dtype).reshape(
|
|
|
+ opt_shape
|
|
|
+ )
|
|
|
+ max_arr = np.array(dynamic_shape_input_data[name][2], dtype=dtype).reshape(
|
|
|
+ max_shape
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ min_arr = np.ones(min_shape, dtype=dtype)
|
|
|
+ opt_arr = np.ones(opt_shape, dtype=dtype)
|
|
|
+ max_arr = np.ones(max_shape, dtype=dtype)
|
|
|
+
|
|
|
+ # refer to: https://github.com/PolaKuma/Paddle/blob/3347f225bc09f2ec09802a2090432dd5cb5b6739/test/tensorrt/test_converter_model_resnet50.py
|
|
|
+ trt_input = Input((min_arr, opt_arr, max_arr))
|
|
|
trt_inputs.append(trt_input)
|
|
|
|
|
|
# Create TensorRTConfig
|
|
|
@@ -428,6 +451,7 @@ class StaticInfer(object):
|
|
|
params_file,
|
|
|
trt_save_path,
|
|
|
self._option.trt_dynamic_shapes,
|
|
|
+ self._option.trt_dynamic_shape_input_data,
|
|
|
)
|
|
|
model_file = trt_save_path.with_suffix(".json")
|
|
|
params_file = trt_save_path.with_suffix(".pdiparams")
|