Răsfoiți Sursa

support to disable trt half ops

gaotingquan 8 luni în urmă
părinte
comite
c5c819d47b

+ 5 - 0
paddlex/inference/models/common/static_infer.py

@@ -28,6 +28,7 @@ from ....utils.flags import (
 from ...utils.benchmark import benchmark, set_inference_operations
 from ...utils.hpi import get_model_paths
 from ...utils.pp_option import PaddlePredictorOption
+from ...utils.trt_config import DISABLE_TRT_HALF_OPS_CONFIG
 
 
 CACHE_DIR = ".cache"
@@ -533,6 +534,10 @@ class StaticInfer(object):
                             self._option.trt_dynamic_shapes,
                             self._option.trt_dynamic_shape_input_data,
                         )
+                    if self._option.model_name in DISABLE_TRT_HALF_OPS_CONFIG:
+                        lazy_paddle.inference.InternalUtils.disable_tensorrt_half_ops(
+                            config, DISABLE_TRT_HALF_OPS_CONFIG[self._option.model_name]
+                        )
                     config.enable_tuned_tensorrt_dynamic_shape(
                         str(trt_shape_range_info_path),
                         self._option.trt_allow_rebuild_at_runtime,

+ 5 - 0
paddlex/inference/utils/trt_config.py

@@ -189,6 +189,11 @@ OLD_IR_TRT_CFG_SETTING = {
     },
 }
 
+DISABLE_TRT_HALF_OPS_CONFIG = {
+    # TODO: just for example
+    "model_name": {"layer_norm"}
+}
+
 ############ pir trt ############
 PIR_TRT_PRECISION_MAP = PIR_TRT_PRECISION_MAP_CLASS()