ソースを参照

support to set trt config by model

set trt config for SegFormer-B3 B4 B5
gaotingquan 8 ヶ月 前
コミット
d3dc02cb7a

+ 7 - 30
paddlex/inference/models/common/static_infer.py

@@ -139,8 +139,7 @@ def _collect_trt_shape_range_info(
 
 # pir trt
 def _convert_trt(
-    model_name,
-    mode,
+    trt_cfg,
     pp_model_file,
     pp_params_file,
     trt_save_path,
@@ -151,15 +150,13 @@ def _convert_trt(
         Input,
         TensorRTConfig,
         convert,
-        PrecisionMode,
     )
 
     def _set_trt_config():
-        if settings := TRT_CFG.get(model_name):
-            for attr_name in settings:
-                if not hasattr(trt_config, attr_name):
-                    logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
-                setattr(trt_config, attr_name, settings[attr_name])
+        for attr_name in trt_cfg:
+            if not hasattr(trt_config, attr_name):
+                logging.warning(f"The TensorRTConfig doesn't have the `{attr_name}`!")
+            setattr(trt_config, attr_name, trt_cfg[attr_name])
 
     def _get_predictor(model_file, params_file):
         # HACK
@@ -187,11 +184,6 @@ def _convert_trt(
                 f"Invalid input name {repr(name)} found in `dynamic_shape_input_data`"
             )
 
-    precision_map = {
-        "trt_int8": PrecisionMode.INT8,
-        "trt_fp32": PrecisionMode.FP32,
-        "trt_fp16": PrecisionMode.FP16,
-    }
     trt_inputs = []
     for name, candidate_shapes in dynamic_shapes.items():
         # XXX: Currently we have no way to get the data type of the tensor
@@ -221,7 +213,6 @@ def _convert_trt(
     # Create TensorRTConfig
     trt_config = TensorRTConfig(inputs=trt_inputs)
     _set_trt_config()
-    trt_config.precision_mode = precision_map[mode]
     trt_config.save_model_dir = str(trt_save_path)
     pp_model_path = str(pp_model_file.with_suffix(""))
     convert(pp_model_path, trt_config)
@@ -466,8 +457,7 @@ class StaticInfer(object):
         if USE_PIR_TRT:
             trt_save_path = cache_dir / "trt" / self.model_file_prefix
             _convert_trt(
-                self._option.model_name,
-                self._option.run_mode,
+                self._option.trt_cfg,
                 model_file,
                 params_file,
                 trt_save_path,
@@ -478,24 +468,11 @@ class StaticInfer(object):
             params_file = trt_save_path.with_suffix(".pdiparams")
             config = lazy_paddle.inference.Config(str(model_file), str(params_file))
         else:
-            PRECISION_MAP = {
-                "trt_int8": lazy_paddle.inference.Config.Precision.Int8,
-                "trt_fp32": lazy_paddle.inference.Config.Precision.Float32,
-                "trt_fp16": lazy_paddle.inference.Config.Precision.Half,
-            }
-
             config = lazy_paddle.inference.Config(str(model_file), str(params_file))
 
             config.set_optim_cache_dir(str(cache_dir / "optim_cache"))
             config.enable_use_gpu(100, self._option.device_id)
-            config.enable_tensorrt_engine(
-                workspace_size=self._option.trt_max_workspace_size,
-                max_batch_size=self._option.trt_max_batch_size,
-                min_subgraph_size=self._option.trt_min_subgraph_size,
-                precision_mode=PRECISION_MAP[self._option.run_mode],
-                use_static=self._option.trt_use_static,
-                use_calib_mode=self._option.trt_use_calib_mode,
-            )
+            config.enable_tensorrt_engine(**self._option.trt_cfg)
 
             if self._option.trt_use_dynamic_shapes:
                 if self._option.trt_collect_shape_range_info:

+ 20 - 72
paddlex/inference/utils/pp_option.py

@@ -24,6 +24,7 @@ from ...utils.device import (
 )
 from .new_ir_blacklist import NEWIR_BLOCKLIST
 from .trt_blacklist import TRT_BLOCKLIST
+from .trt_config import TRT_PRECISION_MAP, TRT_CFG
 
 
 class PaddlePredictorOption(object):
@@ -69,21 +70,24 @@ class PaddlePredictorOption(object):
         for k, v in self._get_default_config().items():
             self._cfg.setdefault(k, v)
 
+        # for trt
+        if self.run_mode in TRT_PRECISION_MAP:
+            trt_cfg = dict(TRT_CFG[self.model_name])  # copy: don't mutate the shared module-level entry
+            trt_cfg["precision_mode"] = TRT_PRECISION_MAP[self.run_mode]
+            self.trt_cfg = trt_cfg
+
     def _get_default_config(self):
         """get default config"""
         device_type, device_ids = parse_device(get_default_device())
-        return {
+
+        default_config = {
             "run_mode": "paddle",
             "device_type": device_type,
             "device_id": None if device_ids is None else device_ids[0],
             "cpu_threads": 8,
             "delete_pass": [],
             "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
-            "trt_max_workspace_size": 1 << 30,  # only for trt
-            "trt_max_batch_size": 32,  # only for trt
-            "trt_min_subgraph_size": 3,  # only for trt
-            "trt_use_static": True,  # only for trt
-            "trt_use_calib_mode": False,  # only for trt
+            "trt_cfg": {},
             "trt_use_dynamic_shapes": True,  # only for trt
             "trt_collect_shape_range_info": True,  # only for trt
             "trt_discard_cached_shape_range_info": False,  # only for trt
@@ -92,6 +96,7 @@ class PaddlePredictorOption(object):
             "trt_shape_range_info_path": None,  # only for trt
             "trt_allow_rebuild_at_runtime": True,  # only for trt
         }
+        return default_config
 
     def _update(self, k, v):
         self._cfg[k] = v
@@ -173,49 +178,16 @@ class PaddlePredictorOption(object):
         self._update("enable_new_ir", enable_new_ir)
 
     @property
-    def trt_max_workspace_size(self):
-        return self._cfg["trt_max_workspace_size"]
-
-    @trt_max_workspace_size.setter
-    def trt_max_workspace_size(self, trt_max_workspace_size):
-        self._update("trt_max_workspace_size", trt_max_workspace_size)
-
-    @property
-    def trt_max_batch_size(self):
-        return self._cfg["trt_max_batch_size"]
-
-    @trt_max_batch_size.setter
-    def trt_max_batch_size(self, trt_max_batch_size):
-        self._update("trt_max_batch_size", trt_max_batch_size)
-
-    @property
-    def trt_min_subgraph_size(self):
-        return self._cfg["trt_min_subgraph_size"]
-
-    @trt_min_subgraph_size.setter
-    def trt_min_subgraph_size(self, trt_min_subgraph_size: int):
-        """set min subgraph size"""
-        if not isinstance(trt_min_subgraph_size, int):
-            raise Exception()
-        self._update("trt_min_subgraph_size", trt_min_subgraph_size)
-
-    @property
-    def trt_use_static(self):
-        return self._cfg["trt_use_static"]
+    def trt_cfg(self):
+        return self._cfg["trt_cfg"]
 
-    @trt_use_static.setter
-    def trt_use_static(self, trt_use_static):
-        """set trt use static"""
-        self._update("trt_use_static", trt_use_static)
-
-    @property
-    def trt_use_calib_mode(self):
-        return self._cfg["trt_use_calib_mode"]
-
-    @trt_use_calib_mode.setter
-    def trt_use_calib_mode(self, trt_use_calib_mode):
-        """set trt calib mode"""
-        self._update("trt_use_calib_mode", trt_use_calib_mode)
+    @trt_cfg.setter
+    def trt_cfg(self, config: Dict):
+        """set trt config"""
+        assert isinstance(
+            config, dict
+        ), f"The trt_cfg must be `dict` type, but received `{type(config)}` type!"
+        self._update("trt_cfg", config)
 
     @property
     def trt_use_dynamic_shapes(self):
@@ -284,14 +256,6 @@ class PaddlePredictorOption(object):
     # For backward compatibility
     # TODO: Issue deprecation warnings
     @property
-    def min_subgraph_size(self):
-        return self.trt_min_subgraph_size
-
-    @min_subgraph_size.setter
-    def min_subgraph_size(self, min_subgraph_size):
-        self.trt_min_subgraph_size = min_subgraph_size
-
-    @property
     def shape_info_filename(self):
         return self.trt_shape_range_info_path
 
@@ -299,22 +263,6 @@ class PaddlePredictorOption(object):
     def shape_info_filename(self, shape_info_filename):
         self.trt_shape_range_info_path = shape_info_filename
 
-    @property
-    def trt_calib_mode(self):
-        return self.trt_use_calib_mode
-
-    @trt_calib_mode.setter
-    def trt_calib_mode(self, trt_calib_mode):
-        self.trt_use_calib_mode = trt_calib_mode
-
-    @property
-    def batch_size(self):
-        return self.trt_max_batch_size
-
-    @batch_size.setter
-    def batch_size(self, batch_size):
-        self.trt_max_batch_size = batch_size
-
     def set_device(self, device: str):
         """set device"""
         if not device:

+ 79 - 1
paddlex/inference/utils/trt_config.py

@@ -12,7 +12,77 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-TRT_CFG = {
+from collections import defaultdict
+import lazy_paddle
+from ...utils.flags import USE_PIR_TRT
+
+
+class LazyLoadDict(dict):
+    def __init__(self, *args, **kwargs):
+        self._initialized = False
+        super().__init__(*args, **kwargs)
+
+    def _initialize(self):
+        if not self._initialized:
+            self.update(self._load())
+            self._initialized = True
+
+    def __getitem__(self, key):
+        self._initialize()
+        return super().__getitem__(key)
+
+    def __contains__(self, key):
+        self._initialize()
+        return super().__contains__(key)
+
+    def _load(self):
+        raise NotImplementedError
+
+
+class OLD_IR_TRT_PRECISION_MAP_CLASS(LazyLoadDict):
+    def _load(self):
+        Precision = lazy_paddle.inference.Config.Precision  # Config is a class, not a module; `from ... import` would fail
+
+        return {
+            "trt_int8": Precision.Int8,
+            "trt_fp32": Precision.Float32,
+            "trt_fp16": Precision.Half,
+        }
+
+
+class PIR_TRT_PRECISION_MAP_CLASS(LazyLoadDict):
+    def _load(self):
+        from lazy_paddle.tensorrt.export import PrecisionMode
+
+        return {
+            "trt_int8": PrecisionMode.INT8,
+            "trt_fp32": PrecisionMode.FP32,
+            "trt_fp16": PrecisionMode.FP16,
+        }
+
+
+############ old ir trt ############
+OLD_IR_TRT_PRECISION_MAP = OLD_IR_TRT_PRECISION_MAP_CLASS()
+
+OLD_IR_TRT_DEFAULT_CFG = {
+    "workspace_size": 1 << 30,
+    "max_batch_size": 32,
+    "min_subgraph_size": 3,
+    "use_static": True,
+    "use_calib_mode": False,
+}
+
+OLD_IR_TRT_CFG = {
+    "SegFormer-B3": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31},
+    "SegFormer-B4": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31},
+    "SegFormer-B5": {**OLD_IR_TRT_DEFAULT_CFG, "workspace_size": 1 << 31},
+}
+
+
+############ pir trt ############
+PIR_TRT_PRECISION_MAP = PIR_TRT_PRECISION_MAP_CLASS()
+
+PIR_TRT_CFG = {
     "DETR-R50": {"optimization_level": 4, "workspace_size": 1 << 32},
     "SegFormer-B0": {"optimization_level": 4, "workspace_size": 1 << 32},
     "SegFormer-B1": {"optimization_level": 4, "workspace_size": 1 << 32},
@@ -31,3 +101,11 @@ TRT_CFG = {
         "workspace_size": 1 << 32,
     },
 }
+
+
+if USE_PIR_TRT:
+    TRT_PRECISION_MAP = PIR_TRT_PRECISION_MAP
+    TRT_CFG = defaultdict(dict, PIR_TRT_CFG)
+else:
+    TRT_PRECISION_MAP = OLD_IR_TRT_PRECISION_MAP
+    TRT_CFG = defaultdict(lambda: dict(OLD_IR_TRT_DEFAULT_CFG), OLD_IR_TRT_CFG)  # fresh copy per key; callers set precision_mode on the result