Jelajahi Sumber

support to set more trt config by model

set exp_disable_tensorrt_ops for SLANeXt_wireless, SLANeXt_wired and PP-YOLOE_seg-S
gaotingquan 8 bulan lalu
induk
melakukan
db1bbfe948

+ 6 - 1
paddlex/inference/models/common/static_infer.py

@@ -472,7 +472,12 @@ class StaticInfer(object):
 
             config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
             config.enable_use_gpu(100, self._option.device_id)
-            config.enable_tensorrt_engine(**self._option.trt_cfg)
+            for func_name in self._option.trt_cfg:
+                assert hasattr(
+                    config, func_name
+                ), f"The `{type(config)}` don't have function `{func_name}`!"
+                kwargs = self._option.trt_cfg[func_name]
+                getattr(config, func_name)(**kwargs)
 
             if self._option.trt_use_dynamic_shapes:
                 if self._option.trt_collect_shape_range_info:

+ 3 - 1
paddlex/inference/utils/pp_option.py

@@ -73,7 +73,9 @@ class PaddlePredictorOption(object):
         # for trt
         if self.run_mode in TRT_PRECISION_MAP:
             trt_cfg = TRT_CFG[self.model_name]
-            trt_cfg["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]
+            trt_cfg["enable_tensorrt_engine"]["precision_mode"] = TRT_PRECISION_MAP[
+                self.run_mode
+            ]
             self.trt_cfg = trt_cfg
 
     def _get_default_config(self):

+ 52 - 9
paddlex/inference/utils/trt_config.py

@@ -41,12 +41,12 @@ class LazyLoadDict(dict):
 
 class OLD_IR_TRT_PRECISION_MAP_CLASS(LazyLoadDict):
     def _load(self):
-        from lazy_paddle.inference.Config import Precision
+        from lazy_paddle.inference import PrecisionType
 
         return {
-            "trt_int8": Precision.Int8,
-            "trt_fp32": Precision.Float32,
-            "trt_fp16": Precision.Half,
+            "trt_int8": PrecisionType.Int8,
+            "trt_fp32": PrecisionType.Float32,
+            "trt_fp16": PrecisionType.Half,
         }
 
 
@@ -73,12 +73,53 @@ OLD_IR_TRT_DEFAULT_CFG = {
 }
 
 OLD_IR_TRT_CFG = {
-    "SegFormer-B3": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31},
-    "SegFormer-B4": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31},
-    "SegFormer-B5": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31},
+    "SegFormer-B3": {
+        "enable_tensorrt_engine": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31}
+    },
+    "SegFormer-B4": {
+        "enable_tensorrt_engine": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31}
+    },
+    "SegFormer-B5": {
+        "enable_tensorrt_engine": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31}
+    },
+    "SLANeXt_wired": {
+        "enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG,
+        "exp_disable_tensorrt_ops": {
+            "ops": [
+                "linear_0.tmp_0",
+                "linear_4.tmp_0",
+                "linear_12.tmp_0",
+                "linear_16.tmp_0",
+                "linear_24.tmp_0",
+                "linear_28.tmp_0",
+                "linear_36.tmp_0",
+                "linear_40.tmp_0",
+            ]
+        },
+    },
+    "SLANeXt_wireless": {
+        "enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG,
+        "exp_disable_tensorrt_ops": {
+            "ops": [
+                "linear_0.tmp_0",
+                "linear_4.tmp_0",
+                "linear_12.tmp_0",
+                "linear_16.tmp_0",
+                "linear_24.tmp_0",
+                "linear_28.tmp_0",
+                "linear_36.tmp_0",
+                "linear_40.tmp_0",
+            ]
+        },
+    },
+    "PP-YOLOE_seg-S": {
+        "enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG,
+        "exp_disable_tensorrt_ops": {
+            "ops": ["bilinear_interp_v2_1.tmp_0", "bilinear_interp_v2_1.tmp_0_slice_0"]
+        },
+    },
 }
 
-
 ############ pir trt ############
 PIR_TRT_PRECISION_MAP = PIR_TRT_PRECISION_MAP_CLASS()
 
@@ -108,4 +149,6 @@ if USE_PIR_TRT:
     TRT_CFG = defaultdict(dict, PIR_TRT_CFG)
 else:
     TRT_PRECISION_MAP = OLD_IR_TRT_PRECISION_MAP
-    TRT_CFG = defaultdict(lambda: OLD_IR_TRT_DEFAULT_CFG, OLD_IR_TRT_CFG)
+    TRT_CFG = defaultdict(
+        lambda: {"enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG}, OLD_IR_TRT_CFG
+    )