Browse Source

support set trt config by model

gaotingquan 8 months ago
parent
commit
9a4a934d47

+ 13 - 1
paddlex/inference/models/common/static_infer.py

@@ -23,8 +23,10 @@ import numpy as np
 from ....utils.flags import DEBUG, FLAGS_json_format_model, USE_PIR_TRT
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
+from ...utils.trt_config import TRT_CFG
 
 
+# old trt
 def collect_trt_shapes(
     model_file, model_params, gpu_id, shape_range_info_path, trt_dynamic_shapes
 ):
@@ -48,7 +50,15 @@ def collect_trt_shapes(
         predictor.run()
 
 
-def convert_trt(mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
+# pir trt
+def convert_trt(model_name, mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
+    def _set_trt_config():
+        if settings := TRT_CFG.get(model_name):
+            for attr_name in settings:
+                if not hasattr(trt_config, attr_name):
+                    logging.warning(f"The TensorRTConfig don't have the `{attr_name}`!")
+                setattr(trt_config, attr_name, settings[attr_name])
+
     from lazy_paddle.tensorrt.export import (
         Input,
         TensorRTConfig,
@@ -73,6 +83,7 @@ def convert_trt(mode, pp_model_path, trt_save_path, trt_dynamic_shapes):
 
     # Create TensorRTConfig
     trt_config = TensorRTConfig(inputs=trt_inputs)
+    _set_trt_config()
     trt_config.precision_mode = precision_map[mode]
     trt_config.save_model_dir = trt_save_path
     convert(pp_model_path, trt_config)
@@ -197,6 +208,7 @@ class StaticInfer:
                 ).as_posix()
                 pp_model_path = (Path(self.model_dir) / self.model_prefix).as_posix()
                 convert_trt(
+                    self.option.model_name,
                     self.option.run_mode,
                     pp_model_path,
                     trt_save_path,

+ 23 - 0
paddlex/inference/utils/trt_config.py

@@ -0,0 +1,23 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+TRT_CFG = {
+    "DETR-R50": {"optimization_level": 4, "workspace_size": 1 << 32},
+    "SegFormer-B0": {"optimization_level": 4, "workspace_size": 1 << 32},
+    "SegFormer-B1": {"optimization_level": 4, "workspace_size": 1 << 32},
+    "SegFormer-B2": {"optimization_level": 4, "workspace_size": 1 << 32},
+    "SegFormer-B3": {"optimization_level": 4, "workspace_size": 1 << 32},
+    "SegFormer-B4": {"optimization_level": 4, "workspace_size": 1 << 32},
+    "SegFormer-B5": {"optimization_level": 4, "workspace_size": 1 << 32},
+}