@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
+import importlib.util
+import subprocess
 from typing import Sequence, List
 from pathlib import Path
 
-import lazy_paddle
+import lazy_paddle as paddle
 import numpy as np
 
 from ....utils import logging
@@ -26,37 +29,44 @@ from ....utils.flags import (
     INFER_BENCHMARK_USE_NEW_INFER_API,
 )
 from ...utils.benchmark import benchmark, set_inference_operations
-from ...utils.hpi import get_model_paths
+from ...utils.hpi import (
+    HPIConfig,
+    ONNXRuntimeConfig,
+    OpenVINOConfig,
+    TensorRTConfig,
+    OMConfig,
+    get_model_paths,
+    suggest_inference_backend_and_config,
+)
 from ...utils.pp_option import PaddlePredictorOption
 from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG
 
 
 CACHE_DIR = ".cache"
 
-if INFER_BENCHMARK_USE_NEW_INFER_API:
-    INFERENCE_OPERATIONS = [
-        "PaddleCopyToDevice",
-        "PaddleCopyToHost",
-        "PaddleModelInfer",
-    ]
-else:
-    INFERENCE_OPERATIONS = ["PaddleInferChainLegacy"]
+INFERENCE_OPERATIONS = [
+    "PaddleCopyToDevice",
+    "PaddleCopyToHost",
+    "PaddleModelInfer",
+    "PaddleInferChainLegacy",
+    "MultiBackendInfer",
+]
 set_inference_operations(INFERENCE_OPERATIONS)
 
 
 # XXX: Better use Paddle Inference API to do this
 def _pd_dtype_to_np_dtype(pd_dtype):
-    if pd_dtype == lazy_paddle.inference.DataType.FLOAT64:
+    if pd_dtype == paddle.inference.DataType.FLOAT64:
         return np.float64
-    elif pd_dtype == lazy_paddle.inference.DataType.FLOAT32:
+    elif pd_dtype == paddle.inference.DataType.FLOAT32:
         return np.float32
-    elif pd_dtype == lazy_paddle.inference.DataType.INT64:
+    elif pd_dtype == paddle.inference.DataType.INT64:
         return np.int64
-    elif pd_dtype == lazy_paddle.inference.DataType.INT32:
+    elif pd_dtype == paddle.inference.DataType.INT32:
         return np.int32
-    elif pd_dtype == lazy_paddle.inference.DataType.UINT8:
+    elif pd_dtype == paddle.inference.DataType.UINT8:
         return np.uint8
-    elif pd_dtype == lazy_paddle.inference.DataType.INT8:
+    elif pd_dtype == paddle.inference.DataType.INT8:
         return np.int8
     else:
         raise TypeError(f"Unsupported data type: {pd_dtype}")
@@ -74,12 +84,12 @@ def _collect_trt_shape_range_info(
 
     dynamic_shape_input_data = dynamic_shape_input_data or {}
 
-    config = lazy_paddle.inference.Config(model_file, model_params)
+    config = paddle.inference.Config(model_file, model_params)
    config.enable_use_gpu(100, gpu_id)
    config.collect_shape_range_info(shape_range_info_path)
    # TODO: Add other needed options
    config.disable_glog_info()
-    predictor = lazy_paddle.inference.create_predictor(config)
+    predictor = paddle.inference.create_predictor(config)
 
     input_names = predictor.get_input_names()
     for name in dynamic_shapes:
@@ -147,7 +157,7 @@ def _convert_trt(
     dynamic_shapes,
     dynamic_shape_input_data,
 ):
-    from lazy_paddle.tensorrt.export import (
+    from paddle.tensorrt.export import (
         Input,
         TensorRTConfig,
         convert,
@@ -162,12 +172,12 @@ def _convert_trt(
 
     def _get_predictor(model_file, params_file):
         # HACK
-        config = lazy_paddle.inference.Config(str(model_file), str(params_file))
+        config = paddle.inference.Config(str(model_file), str(params_file))
         config.enable_use_gpu(100, device_id)
         # NOTE: Disable oneDNN to circumvent a bug in Paddle Inference
         config.disable_mkldnn()
         config.disable_glog_info()
-        return lazy_paddle.inference.create_predictor(config)
+        return paddle.inference.create_predictor(config)
 
     dynamic_shape_input_data = dynamic_shape_input_data or {}
 
@@ -246,7 +256,7 @@ class PaddleCopyToDevice:
     def __call__(self, arrs):
         device_id = [self.device_id] if self.device_id is not None else self.device_id
         device = constr_device(self.device_type, device_id)
-        paddle_tensors = [lazy_paddle.to_tensor(i, place=device) for i in arrs]
+        paddle_tensors = [paddle.to_tensor(i, place=device) for i in arrs]
         return paddle_tensors
 
 
@@ -292,19 +302,25 @@ class PaddleInferChainLegacy:
         return outputs
 
 
-class StaticInfer(object):
+class StaticInfer(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
+        raise NotImplementedError
+
+
+class PaddleInfer(StaticInfer):
     def __init__(
         self,
         model_dir: str,
-        model_prefix: str,
+        model_file_prefix: str,
         option: PaddlePredictorOption,
     ) -> None:
         super().__init__()
         self.model_dir = model_dir
-        self.model_file_prefix = model_prefix
+        self.model_file_prefix = model_file_prefix
         self._option = option
         self.predictor = self._create()
-        if self._use_new_inference_api:
+        if INFER_BENCHMARK_USE_NEW_INFER_API:
             device_type = self._option.device_type
             device_type = "gpu" if device_type == "dcu" else device_type
             copy_to_device = PaddleCopyToDevice(device_type, self._option.device_id)
@@ -314,13 +330,6 @@ class StaticInfer(object):
         else:
             self.infer = PaddleInferChainLegacy(self.predictor)
 
-    @property
-    def _use_new_inference_api(self):
-        # HACK: Temp fallback to legacy API via env var
-        return INFER_BENCHMARK_USE_NEW_INFER_API
-
-        # return self._option.device_type in ("cpu", "gpu", "dcu")
-
     def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
         names = self.predictor.get_input_names()
         if len(names) != len(x):
@@ -340,7 +349,7 @@ class StaticInfer(object):
         """_create"""
         model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
         if "paddle" not in model_paths:
-            raise RuntimeError("No valid Paddle model found")
+            raise RuntimeError("No valid PaddlePaddle model found")
         model_file, params_file = model_paths["paddle"]
 
         if (
@@ -383,10 +392,10 @@ class StaticInfer(object):
             config.enable_use_gpu(100, self._option.device_id)
         # for Native Paddle and MKLDNN
         else:
-            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
+            config = paddle.inference.Config(str(model_file), str(params_file))
             if self._option.device_type == "gpu":
                 config.exp_disable_mixed_precision_ops({"feed", "fetch"})
-                from lazy_paddle.inference import PrecisionType
+                from paddle.inference import PrecisionType
 
                 precision = (
                     PrecisionType.Half
@@ -427,7 +436,7 @@ class StaticInfer(object):
         if hasattr(config, "enable_new_executor"):
             config.enable_new_executor()
         # XXX: is_compiled_with_rocm() must be True on dcu platform ?
-        if lazy_paddle.is_compiled_with_rocm():
+        if paddle.is_compiled_with_rocm():
             # Delete unsupported passes in dcu
             config.delete_pass("conv2d_add_act_fuse_pass")
             config.delete_pass("conv2d_add_fuse_pass")
@@ -463,7 +472,7 @@ class StaticInfer(object):
         if not DEBUG:
             config.disable_glog_info()
 
-        predictor = lazy_paddle.inference.create_predictor(config)
+        predictor = paddle.inference.create_predictor(config)
 
         return predictor
 
@@ -482,9 +491,9 @@ class StaticInfer(object):
             )
             model_file = trt_save_path.with_suffix(".json")
             params_file = trt_save_path.with_suffix(".pdiparams")
-            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
+            config = paddle.inference.Config(str(model_file), str(params_file))
         else:
-            config = lazy_paddle.inference.Config(str(model_file), str(params_file))
+            config = paddle.inference.Config(str(model_file), str(params_file))
             config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
             # call enable_use_gpu() first to use TensorRT engine
             config.enable_use_gpu(100, self._option.device_id)
@@ -534,8 +543,11 @@ class StaticInfer(object):
                     self._option.trt_dynamic_shapes,
                     self._option.trt_dynamic_shape_input_data,
                 )
-            if self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG and self._option.run_mode == "trt_fp16":
-                lazy_paddle.inference.InternalUtils.disable_tensorrt_half_ops(
+            if (
+                self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG
+                and self._option.run_mode == "trt_fp16"
+            ):
+                paddle.inference.InternalUtils.disable_tensorrt_half_ops(
                    config, DISABLE_TRT_HALF_OPS_CONFIG[self._option.model_name]
                )
             config.enable_tuned_tensorrt_dynamic_shape(
@@ -559,3 +571,288 @@ class StaticInfer(object):
                 raise RuntimeError("No dynamic shape information provided")
 
         return config
+
+
+# FIXME: Name might be misleading
+@benchmark.timeit
+class MultiBackendInfer(object):
+    def __init__(self, ui_runtime):
+        super().__init__()
+        self.ui_runtime = ui_runtime
+
+    # The time consumed by the wrapper code will also be taken into account.
+    def __call__(self, x):
+        outputs = self.ui_runtime.infer(x)
+        return outputs
+
+
+# TODO: It would be better to refactor the code to make `HPInfer` a higher-level
+# class that uses `PaddleInfer`.
+class HPInfer(StaticInfer):
+    def __init__(
+        self,
+        model_dir: str,
+        model_file_prefix: str,
+        config: HPIConfig,
+    ) -> None:
+        super().__init__()
+        self._model_dir = model_dir
+        self._model_file_prefix = model_file_prefix
+        self._config = config
+        backend, backend_config = self._determine_backend_and_config()
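+        # Paddle Inference is handled by `PaddleInfer`; all other backends go
+        # through the `ultra_infer` runtime via `MultiBackendInfer`.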
+        if backend == "paddle":
+            self._use_paddle = True
+            self._paddle_infer = self._build_paddle_infer(backend_config)
+        else:
+            self._use_paddle = False
+            ui_runtime = self._build_ui_runtime(backend, backend_config)
+            self._multi_backend_infer = MultiBackendInfer(ui_runtime)
+            num_inputs = ui_runtime.num_inputs()
+            self._input_names = [
+                ui_runtime.get_input_info(i).name for i in range(num_inputs)
+            ]
+
+    @property
+    def model_dir(self) -> str:
+        return self._model_dir
+
+    @property
+    def model_file_prefix(self) -> str:
+        return self._model_file_prefix
+
+    @property
+    def config(self) -> HPIConfig:
+        return self._config
+
+    def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
+        if self._use_paddle:
+            return self._call_paddle_infer(x)
+        else:
+            return self._call_multi_backend_infer(x)
+
+    def _call_paddle_infer(self, x):
+        return self._paddle_infer(x)
+
+    def _call_multi_backend_infer(self, x):
+        num_inputs = len(self._input_names)
+        if len(x) != num_inputs:
+            raise ValueError(f"Expected {num_inputs} inputs but got {len(x)} instead")
+        x = _sort_inputs(x, self._input_names)
+        inputs = {}
+        for name, input_ in zip(self._input_names, x):
+            inputs[name] = np.ascontiguousarray(input_)
+        return self._multi_backend_infer(inputs)
+
+    def _determine_backend_and_config(self):
+        from ultra_infer import (
+            is_built_with_om,
+            is_built_with_openvino,
+            is_built_with_ort,
+            is_built_with_trt,
+        )
+
+        model_paths = get_model_paths(self._model_dir, self._model_file_prefix)
+        is_onnx_model_available = "onnx" in model_paths
+        # TODO: Give a warning if Paddle2ONNX is not available but can be used
+        # to select a better backend.
+        if self._config.auto_paddle2onnx:
+            if self._check_paddle2onnx():
+                is_onnx_model_available = (
+                    is_onnx_model_available or "paddle" in model_paths
+                )
+            else:
+                logging.debug(
+                    "Paddle2ONNX is not available. Automatic model conversion will not be performed."
+                )
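+        # Enumerate the backends that are usable given the installed build and
+        # the model formats available on disk.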
+        available_backends = []
+        if "paddle" in model_paths:
+            available_backends.append("paddle")
+        if is_built_with_openvino() and is_onnx_model_available:
+            available_backends.append("openvino")
+        if is_built_with_ort() and is_onnx_model_available:
+            available_backends.append("onnxruntime")
+        if is_built_with_trt() and is_onnx_model_available:
+            available_backends.append("tensorrt")
+        if is_built_with_om() and "om" in model_paths:
+            available_backends.append("om")
+
+        if not available_backends:
+            raise RuntimeError("No inference backend is available")
+
+        if (
+            self._config.backend is not None
+            and self._config.backend not in available_backends
+        ):
+            raise RuntimeError(
+                f"Inference backend {repr(self._config.backend)} is unavailable"
+            )
+
+        if self._config.auto_config:
+            # Should we use the strategy pattern here to allow extensible
+            # strategies?
+            ret = suggest_inference_backend_and_config(
+                self._config, available_backends=available_backends
+            )
+            if ret[0] is None:
+                # Should I use a custom exception?
+                raise RuntimeError(
+                    f"No inference backend and configuration could be suggested. Reason: {ret[1]}"
+                )
+            backend, backend_config = ret
+        else:
+            backend = self._config.backend
+            if backend is None:
+                raise RuntimeError(
+                    "When automatic configuration is not used, the inference backend must be specified manually."
+                )
+            backend_config = self._config.backend_config or {}
+
+        if backend == "paddle" and not backend_config:
+            logging.warning(
+                "The Paddle Inference backend is selected with the default configuration. This may not provide optimal performance."
+            )
+
+        return backend, backend_config
+
+    def _build_paddle_infer(self, backend_config):
+        kwargs = {
+            "device_type": self._config.device_type,
+            "device_id": self._config.device_id,
+            **backend_config,
+        }
+        # TODO: This is probably redundant. Can we reuse the code in the
+        # predictor class?
+        paddle_info = self._config.hpi_info.backend_configs.paddle_infer
+        if paddle_info is not None:
+            if (
+                kwargs.get("trt_dynamic_shapes") is None
+                and paddle_info.trt_dynamic_shapes is not None
+            ):
+                trt_dynamic_shapes = paddle_info.trt_dynamic_shapes
+                logging.debug("TensorRT dynamic shapes set to %s", trt_dynamic_shapes)
+                kwargs["trt_dynamic_shapes"] = trt_dynamic_shapes
+            if (
+                kwargs.get("trt_dynamic_shape_input_data") is None
+                and paddle_info.trt_dynamic_shape_input_data is not None
+            ):
+                trt_dynamic_shape_input_data = paddle_info.trt_dynamic_shape_input_data
+                logging.debug(
+                    "TensorRT dynamic shape input data set to %s",
+                    trt_dynamic_shape_input_data,
+                )
+                kwargs["trt_dynamic_shape_input_data"] = trt_dynamic_shape_input_data
+        pp_option = PaddlePredictorOption(self._config.pdx_model_name, **kwargs)
+        logging.info("Using Paddle Inference backend")
+        logging.info("Paddle predictor option: %s", pp_option)
+        return PaddleInfer(self._model_dir, self._model_file_prefix, option=pp_option)
+
+    def _build_ui_runtime(self, backend, backend_config, ui_option=None):
+        from ultra_infer import ModelFormat, Runtime, RuntimeOption
+
+        if ui_option is None:
+            ui_option = RuntimeOption()
+
+        if self._config.device_type == "cpu":
+            pass
+        elif self._config.device_type == "gpu":
+            ui_option.use_gpu(self._config.device_id or 0)
+        elif self._config.device_type == "npu":
+            ui_option.use_ascend()
+        else:
+            raise RuntimeError(
+                f"Unsupported device type {repr(self._config.device_type)}"
+            )
+
+        model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
+        if backend in ("openvino", "onnxruntime", "tensorrt"):
+            # XXX: This introduces side effects.
+            if "onnx" not in model_paths:
+                if self._config.auto_paddle2onnx:
+                    if "paddle" not in model_paths:
+                        raise RuntimeError("PaddlePaddle model required")
+                    # The CLI is used here since there is currently no API.
+                    logging.info(
+                        "Automatically converting PaddlePaddle model to ONNX format"
+                    )
+                    subprocess.check_call(
+                        [
+                            "paddlex",
+                            "--paddle2onnx",
+                            "--paddle_model_dir",
+                            self._model_dir,
+                            "--onnx_model_dir",
+                            self._model_dir,
+                        ]
+                    )
+                    model_paths = get_model_paths(
+                        self.model_dir, self.model_file_prefix
+                    )
+                    assert "onnx" in model_paths
+                else:
+                    raise RuntimeError("ONNX model required")
+            ui_option.set_model_path(str(model_paths["onnx"]), "", ModelFormat.ONNX)
+        elif backend == "om":
+            if "om" not in model_paths:
+                raise RuntimeError("OM model required")
+            ui_option.set_model_path(str(model_paths["om"]), "", ModelFormat.OM)
+        else:
+            raise ValueError(f"Unsupported inference backend {repr(backend)}")
+
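+        # Attach and configure the selected backend on the runtime option.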
+        if backend == "openvino":
+            backend_config = OpenVINOConfig.model_validate(backend_config)
+            ui_option.use_openvino_backend()
+            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
+        elif backend == "onnxruntime":
+            backend_config = ONNXRuntimeConfig.model_validate(backend_config)
+            ui_option.use_ort_backend()
+            ui_option.set_cpu_thread_num(backend_config.cpu_num_threads)
+        elif backend == "tensorrt":
+            if (
+                backend_config.get("use_dynamic_shapes", True)
+                and backend_config.get("dynamic_shapes") is None
+            ):
+                trt_info = self._config.hpi_info.backend_configs.tensorrt
+                if trt_info is not None and trt_info.dynamic_shapes is not None:
+                    trt_dynamic_shapes = trt_info.dynamic_shapes
+                    logging.debug(
+                        "TensorRT dynamic shapes set to %s", trt_dynamic_shapes
+                    )
+                    backend_config = {
+                        **backend_config,
+                        "dynamic_shapes": trt_dynamic_shapes,
+                    }
+            backend_config = TensorRTConfig.model_validate(backend_config)
+            ui_option.use_trt_backend()
+            cache_dir = self.model_dir / CACHE_DIR / "tensorrt"
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            ui_option.trt_option.serialize_file = str(cache_dir / "trt_serialized.trt")
+            if backend_config.precision == "FP16":
+                ui_option.trt_option.enable_fp16 = True
+            if not backend_config.use_dynamic_shapes:
+                raise RuntimeError(
+                    "TensorRT static shape inference is currently not supported"
+                )
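+            # Register dynamic shapes only when no serialized engine file exists
+            # yet; an existing engine file is reused as is (see the warning below).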
+            if backend_config.dynamic_shapes is not None:
+                if not Path(ui_option.trt_option.serialize_file).exists():
+                    for name, shapes in backend_config.dynamic_shapes.items():
+                        ui_option.trt_option.set_shape(name, *shapes)
+                else:
+                    logging.warning(
+                        "TensorRT dynamic shapes will be loaded from the file."
+                    )
+        elif backend == "om":
+            backend_config = OMConfig.model_validate(backend_config)
+            ui_option.use_om_backend()
+        else:
+            raise ValueError(f"Unsupported inference backend {repr(backend)}")
+
+        logging.info("Inference backend: %s", backend)
+        logging.info("Inference backend config: %s", backend_config)
+
+        ui_runtime = Runtime(ui_option)
+
+        return ui_runtime
+
+    def _check_paddle2onnx(self):
+        # HACK
+        return importlib.util.find_spec("paddle2onnx") is not None