|
|
@@ -14,7 +14,7 @@
|
|
|
|
|
|
import warnings
|
|
|
from pathlib import Path
|
|
|
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
|
|
|
+from typing import Any, Dict, Literal, List, Mapping, Optional, Tuple, Type, Union
|
|
|
|
|
|
import ultra_infer as ui
|
|
|
from paddlex.utils import logging
|
|
|
@@ -36,6 +36,7 @@ class PaddleInferConfig(_BackendConfig):
|
|
|
enable_trt: bool = False
|
|
|
trt_dynamic_shapes: Optional[Dict[str, List[List[int]]]] = None
|
|
|
trt_dynamic_shape_input_data: Optional[Dict[str, List[List[float]]]] = None
|
|
|
+ trt_precision: Literal["FP32", "FP16"] = "FP32"
|
|
|
enable_log_info: bool = False
|
|
|
|
|
|
def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
|
|
|
@@ -43,16 +44,18 @@ class PaddleInferConfig(_BackendConfig):
|
|
|
option.set_cpu_thread_num(self.cpu_num_threads)
|
|
|
option.paddle_infer_option.enable_mkldnn = self.enable_mkldnn
|
|
|
option.paddle_infer_option.enable_trt = self.enable_trt
|
|
|
- option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
|
|
|
- if self.trt_dynamic_shapes is not None:
|
|
|
- for name, shapes in self.trt_dynamic_shapes.items():
|
|
|
- option.trt_option.set_shape(name, *shapes)
|
|
|
- if self.trt_dynamic_shape_input_data is not None:
|
|
|
- for name, data in self.trt_dynamic_shape_input_data.items():
|
|
|
- option.trt_option.set_input_data(name, *data)
|
|
|
if self.enable_trt:
|
|
|
+ option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
|
|
|
option.paddle_infer_option.collect_trt_shape = True
|
|
|
option.paddle_infer_option.collect_trt_shape_by_device = True
|
|
|
+ if self.trt_dynamic_shapes is not None:
|
|
|
+ for name, shapes in self.trt_dynamic_shapes.items():
|
|
|
+ option.trt_option.set_shape(name, *shapes)
|
|
|
+ if self.trt_dynamic_shape_input_data is not None:
|
|
|
+ for name, data in self.trt_dynamic_shape_input_data.items():
|
|
|
+ option.trt_option.set_input_data(name, *data)
|
|
|
+ if self.trt_precision == "FP16":
|
|
|
+ option.trt_option.enable_fp16 = True
|
|
|
option.paddle_infer_option.enable_log_info = self.enable_log_info
|
|
|
|
|
|
|
|
|
@@ -73,11 +76,14 @@ class ONNXRuntimeConfig(_BackendConfig):
|
|
|
|
|
|
|
|
|
class TensorRTConfig(_BackendConfig):
|
|
|
+ precision: Literal["FP32", "FP16"] = "FP32"
|
|
|
dynamic_shapes: Optional[Dict[str, List[List[int]]]] = None
|
|
|
|
|
|
def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
|
|
|
option.use_trt_backend()
|
|
|
option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
|
|
|
+ if self.precision == "FP16":
|
|
|
+ option.trt_option.enable_fp16 = True
|
|
|
if self.dynamic_shapes is not None:
|
|
|
for name, shapes in self.dynamic_shapes.items():
|
|
|
option.trt_option.set_shape(name, *shapes)
|