Bläddra i källkod

add precision option (#2721)

* add precision option

* fix

* fix
zhang-prog 10 månader sedan
förälder
incheckning
3ac057ff21
1 ändrad fil med 14 tillägg och 8 borttagningar
  1. 14 8
      libs/paddlex-hpi/src/paddlex_hpi/_config.py

+ 14 - 8
libs/paddlex-hpi/src/paddlex_hpi/_config.py

@@ -14,7 +14,7 @@
 
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
+from typing import Any, Dict, Literal, List, Mapping, Optional, Tuple, Type, Union
 
 import ultra_infer as ui
 from paddlex.utils import logging
@@ -36,6 +36,7 @@ class PaddleInferConfig(_BackendConfig):
     enable_trt: bool = False
     trt_dynamic_shapes: Optional[Dict[str, List[List[int]]]] = None
     trt_dynamic_shape_input_data: Optional[Dict[str, List[List[float]]]] = None
+    trt_precision: Literal["FP32", "FP16"] = "FP32"
     enable_log_info: bool = False
 
     def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
@@ -43,16 +44,18 @@ class PaddleInferConfig(_BackendConfig):
         option.set_cpu_thread_num(self.cpu_num_threads)
         option.paddle_infer_option.enable_mkldnn = self.enable_mkldnn
         option.paddle_infer_option.enable_trt = self.enable_trt
-        option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
-        if self.trt_dynamic_shapes is not None:
-            for name, shapes in self.trt_dynamic_shapes.items():
-                option.trt_option.set_shape(name, *shapes)
-        if self.trt_dynamic_shape_input_data is not None:
-            for name, data in self.trt_dynamic_shape_input_data.items():
-                option.trt_option.set_input_data(name, *data)
         if self.enable_trt:
+            option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
             option.paddle_infer_option.collect_trt_shape = True
             option.paddle_infer_option.collect_trt_shape_by_device = True
+            if self.trt_dynamic_shapes is not None:
+                for name, shapes in self.trt_dynamic_shapes.items():
+                    option.trt_option.set_shape(name, *shapes)
+            if self.trt_dynamic_shape_input_data is not None:
+                for name, data in self.trt_dynamic_shape_input_data.items():
+                    option.trt_option.set_input_data(name, *data)
+            if self.trt_precision == "FP16":
+                option.trt_option.enable_fp16 = True
         option.paddle_infer_option.enable_log_info = self.enable_log_info
 
 
@@ -73,11 +76,14 @@ class ONNXRuntimeConfig(_BackendConfig):
 
 
 class TensorRTConfig(_BackendConfig):
+    precision: Literal["FP32", "FP16"] = "FP32"
     dynamic_shapes: Optional[Dict[str, List[List[int]]]] = None
 
     def update_ui_option(self, option: ui.RuntimeOption, model_dir: Path) -> None:
         option.use_trt_backend()
         option.trt_option.serialize_file = str(model_dir / "trt_serialized.trt")
+        if self.precision == "FP16":
+            option.trt_option.enable_fp16 = True
         if self.dynamic_shapes is not None:
             for name, shapes in self.dynamic_shapes.items():
                 option.trt_option.set_shape(name, *shapes)