Quellcode durchsuchen

support trt blocklist

gaotingquan vor 9 Monaten
Ursprung
Commit
a57079893d
2 geänderte Dateien mit 48 neuen und 0 gelöschten Zeilen
  1. 8 0
      paddlex/inference/utils/pp_option.py
  2. 40 0
      paddlex/inference/utils/trt_blacklist.py

+ 8 - 0
paddlex/inference/utils/pp_option.py

@@ -23,6 +23,7 @@ from ...utils.device import (
 )
 from ...utils import logging
 from .new_ir_blacklist import NEWIR_BLOCKLIST
+from .trt_blacklist import TRT_BLOCKLIST
 
 
 class PaddlePredictorOption(object):
@@ -100,6 +101,13 @@ class PaddlePredictorOption(object):
             raise ValueError(
                 f"`run_mode` must be {support_run_mode_str}, but received {repr(run_mode)}."
             )
+        # TRT Blocklist
+        if run_mode.startswith("trt") and self.model_name in TRT_BLOCKLIST:
+            logging.warning(
+                f"The model({self.model_name}) is not supported to run in trt mode! Using `paddle` instead!"
+            )
+            run_mode = "paddle"
+
         self._update("run_mode", run_mode)
 
     @property

+ 40 - 0
paddlex/inference/utils/trt_blacklist.py

@@ -0,0 +1,40 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+TRT_BLOCKLIST = [
+    "TimesNet_cls",
+    "TimesNet",
+    "TimesNet_ad",
+    "MaskRCNN-ResNet50-FPN",
+    "FasterRCNN-ResNeXt101-vd-FPN",
+    "Cascade-FasterRCNN-ResNet50-FPN",
+    "MaskRCNN-ResNet101-vd-FPN",
+    "FasterRCNN-ResNet50",
+    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
+    "FasterRCNN-ResNet50-FPN",
+    "FasterRCNN-ResNet101",
+    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "MaskRCNN-ResNeXt101-vd-FPN",
+    "MaskRCNN-ResNet50",
+    "FasterRCNN-ResNet50-vd-FPN",
+    "FasterRCNN-Swin-Tiny-FPN",
+    "FasterRCNN-ResNet34-FPN",
+    "MaskRCNN-ResNet101-FPN",
+    "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "FasterRCNN-ResNet101-FPN",
+    "MaskRCNN-ResNet50-vd-FPN",
+    "Cascade-MaskRCNN-ResNet50-FPN",
+    "MaskRCNN-ResNet50-vd-FPN",
+    "Cascade-MaskRCNN-ResNet50-FPN",
+]