gaotingquan 8 months ago
parent
commit
19863acc7f

+ 13 - 10
paddlex/inference/models/common/static_infer.py

@@ -28,7 +28,6 @@ from ....utils.flags import (
 from ...utils.benchmark import benchmark, set_inference_operations
 from ...utils.hpi import get_model_paths
 from ...utils.pp_option import PaddlePredictorOption
-from ...utils.trt_config import TRT_CFG
 
 
 CACHE_DIR = ".cache"
@@ -139,7 +138,7 @@ def _collect_trt_shape_range_info(
 
 # pir trt
 def _convert_trt(
-    trt_cfg,
+    trt_cfg_setting,
     pp_model_file,
     pp_params_file,
     trt_save_path,
@@ -153,10 +152,11 @@ def _convert_trt(
     )
 
     def _set_trt_config():
-        for attr_name in trt_cfg:
-            if not hasattr(trt_config, attr_name):
-                logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
-            setattr(trt_config, attr_name, trt_cfg[attr_name])
+        for attr_name in trt_cfg_setting:
+            assert hasattr(
+                trt_config, attr_name
+            ), f"The `{type(trt_config)}` don't have the attribute `{attr_name}`!"
+            setattr(trt_config, attr_name, trt_cfg_setting[attr_name])
 
     def _get_predictor(model_file, params_file):
         # HACK
@@ -457,7 +457,7 @@ class StaticInfer(object):
         if USE_PIR_TRT:
             trt_save_path = cache_dir / "trt" / self.model_file_prefix
             _convert_trt(
-                self._option.trt_cfg,
+                self._option.trt_cfg_setting,
                 model_file,
                 params_file,
                 trt_save_path,
@@ -472,12 +472,15 @@ class StaticInfer(object):
 
             config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
             config.enable_use_gpu(100, self._option.device_id)
-            for func_name in self._option.trt_cfg:
+            for func_name in self._option.trt_cfg_setting:
                 assert hasattr(
                     config, func_name
                 ), f"The `{type(config)}` don't have function `{func_name}`!"
-                kwargs = self._option.trt_cfg[func_name]
-                getattr(config, func_name)(**kwargs)
+                args = self._option.trt_cfg_setting[func_name]
+                if isinstance(args, list):
+                    getattr(config, func_name)(*args)
+                else:
+                    getattr(config, func_name)(**args)
 
             if self._option.trt_use_dynamic_shapes:
                 if self._option.trt_collect_shape_range_info:

+ 17 - 13
paddlex/inference/utils/pp_option.py

@@ -15,6 +15,7 @@
 import os
 from typing import Dict, List
 
+from ...utils.flags import USE_PIR_TRT
 from ...utils import logging
 from ...utils.device import (
     check_supported_device_type,
@@ -24,7 +25,7 @@ from ...utils.device import (
 )
 from .new_ir_blacklist import NEWIR_BLOCKLIST
 from .trt_blacklist import TRT_BLOCKLIST
-from .trt_config import TRT_PRECISION_MAP, TRT_CFG
+from .trt_config import TRT_PRECISION_MAP, TRT_CFG_SETTING
 
 
 class PaddlePredictorOption(object):
@@ -72,11 +73,14 @@ class PaddlePredictorOption(object):
 
         # for trt
         if self.run_mode in TRT_PRECISION_MAP:
-            trt_cfg = TRT_CFG[self.model_name]
-            trt_cfg["enable_tensorrt_engine"]["precision_mode"] = TRT_PRECISION_MAP[
-                self.run_mode
-            ]
-            self.trt_cfg = trt_cfg
+            trt_cfg_setting = TRT_CFG_SETTING[self.model_name]
+            if USE_PIR_TRT:
+                trt_cfg_setting["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]
+            else:
+                trt_cfg_setting["enable_tensorrt_engine"]["precision_mode"] = (
+                    TRT_PRECISION_MAP[self.run_mode]
+                )
+            self.trt_cfg_setting = trt_cfg_setting
 
     def _get_default_config(self):
         """get default config"""
@@ -89,7 +93,7 @@ class PaddlePredictorOption(object):
             "cpu_threads": 8,
             "delete_pass": [],
             "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
-            "trt_cfg": {},
+            "trt_cfg_setting": {},
             "trt_use_dynamic_shapes": True,  # only for trt
             "trt_collect_shape_range_info": True,  # only for trt
             "trt_discard_cached_shape_range_info": False,  # only for trt
@@ -180,16 +184,16 @@ class PaddlePredictorOption(object):
         self._update("enable_new_ir", enable_new_ir)
 
     @property
-    def trt_cfg(self):
-        return self._cfg["trt_cfg"]
+    def trt_cfg_setting(self):
+        return self._cfg["trt_cfg_setting"]
 
-    @trt_cfg.setter
-    def trt_cfg(self, config: Dict):
+    @trt_cfg_setting.setter
+    def trt_cfg_setting(self, config: Dict):
         """set trt config"""
         assert isinstance(
             config, dict
-        ), f"The trt_cfg must be `dict` type, but recived `{type(config)}` type!"
-        self._update("trt_cfg", config)
+        ), f"The trt_cfg_setting must be `dict` type, but recived `{type(config)}` type!"
+        self._update("trt_cfg_setting", config)
 
     @property
     def trt_use_dynamic_shapes(self):

+ 33 - 21
paddlex/inference/utils/trt_config.py

@@ -64,7 +64,7 @@ class PIR_TRT_PRECISION_MAP_CLASS(LazyLoadDict):
 ############ old ir trt ############
 OLD_IR_TRT_PRECISION_MAP = OLD_IR_TRT_PRECISION_MAP_CLASS()
 
-OLD_IR_TRT_DEFAULT_CFG = {
+OLD_IR_TRT_CFG_DEFAULT_SETTING = {
     "workspace_size": 1 << 30,
     "max_batch_size": 32,
     "min_subgraph_size": 3,
@@ -72,20 +72,31 @@ OLD_IR_TRT_DEFAULT_CFG = {
     "use_calib_mode": False,
 }
 
-OLD_IR_TRT_CFG = {
+OLD_IR_TRT_CFG_SETTING = {
     "SegFormer-B3": {
-        "enable_tensorrt_engine": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31}
+        "enable_tensorrt_engine": {
+            **OLD_IR_TRT_CFG_DEFAULT_SETTING,
+            "workspace_size": 1 << 31,
+        }
     },
     "SegFormer-B4": {
-        "enable_tensorrt_engine": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31}
+        "enable_tensorrt_engine": {
+            **OLD_IR_TRT_CFG_DEFAULT_SETTING,
+            "workspace_size": 1 << 31,
+        }
     },
     "SegFormer-B5": {
-        "enable_tensorrt_engine": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31}
+        "enable_tensorrt_engine": {
+            **OLD_IR_TRT_CFG_DEFAULT_SETTING,
+            "workspace_size": 1 << 31,
+        }
     },
     "SLANeXt_wired": {
-        "enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG,
-        "exp_disable_tensorrt_ops": {
-            "ops": [
+        "enable_tensorrt_engine": OLD_IR_TRT_CFG_DEFAULT_SETTING,
+        # the exp_disable_tensorrt_ops() func don't support to be pass argument by keyword
+        # therefore, using list instead of dict
+        "exp_disable_tensorrt_ops": [
+            [
                 "linear_0.tmp_0",
                 "linear_4.tmp_0",
                 "linear_12.tmp_0",
@@ -95,12 +106,12 @@ OLD_IR_TRT_CFG = {
                 "linear_36.tmp_0",
                 "linear_40.tmp_0",
             ]
-        },
+        ],
     },
     "SLANeXt_wireless": {
-        "enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG,
-        "exp_disable_tensorrt_ops": {
-            "ops": [
+        "enable_tensorrt_engine": OLD_IR_TRT_CFG_DEFAULT_SETTING,
+        "exp_disable_tensorrt_ops": [
+            [
                 "linear_0.tmp_0",
                 "linear_4.tmp_0",
                 "linear_12.tmp_0",
@@ -110,20 +121,20 @@ OLD_IR_TRT_CFG = {
                 "linear_36.tmp_0",
                 "linear_40.tmp_0",
             ]
-        },
+        ],
     },
     "PP-YOLOE_seg-S": {
-        "enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG,
-        "exp_disable_tensorrt_ops": {
-            "ops": ["bilinear_interp_v2_1.tmp_0", "bilinear_interp_v2_1.tmp_0_slice_0"]
-        },
+        "enable_tensorrt_engine": OLD_IR_TRT_CFG_DEFAULT_SETTING,
+        "exp_disable_tensorrt_ops": [
+            ["bilinear_interp_v2_1.tmp_0", "bilinear_interp_v2_1.tmp_0_slice_0"]
+        ],
     },
 }
 
 ############ pir trt ############
 PIR_TRT_PRECISION_MAP = PIR_TRT_PRECISION_MAP_CLASS()
 
-PIR_TRT_CFG = {
+PIR_TRT_CFG_SETTING = {
     "DETR-R50": {"optimization_level": 4, "workspace_size": 1 << 32},
     "SegFormer-B0": {"optimization_level": 4, "workspace_size": 1 << 32},
     "SegFormer-B1": {"optimization_level": 4, "workspace_size": 1 << 32},
@@ -146,9 +157,10 @@ PIR_TRT_CFG = {
 
 if USE_PIR_TRT:
     TRT_PRECISION_MAP = PIR_TRT_PRECISION_MAP
-    TRT_CFG = defaultdict(dict, PIR_TRT_CFG)
+    TRT_CFG_SETTING = defaultdict(dict, PIR_TRT_CFG_SETTING)
 else:
     TRT_PRECISION_MAP = OLD_IR_TRT_PRECISION_MAP
-    TRT_CFG = defaultdict(
-        lambda: {"enable_tensorrt_engine": OLD_IR_TRT_DEFAULT_CFG}, OLD_IR_TRT_CFG
+    TRT_CFG_SETTING = defaultdict(
+        lambda: {"enable_tensorrt_engine": OLD_IR_TRT_CFG_DEFAULT_SETTING},
+        OLD_IR_TRT_CFG_SETTING,
     )