瀏覽代碼

fix RT-DETR series model

gaotingquan 1 年之前
父節點
當前提交
c07fdbd12a
共有 2 個文件被更改,包括 2 次插入14 次删除
  1. 1 7
      paddlex/inference/models/instance_segmentation.py
  2. 1 7
      paddlex/inference/models/object_detection.py

+ 1 - 7
paddlex/inference/models/instance_segmentation.py

@@ -41,13 +41,7 @@ class InstanceSegPredictor(DetPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        if self.model_name in [
-            "Mask-RT-DETR-S",
-            "Mask-RT-DETR-M",
-            "Mask-RT-DETR-L",
-            "Mask-RT-DETR-H",
-            "Mask-RT-DETR-X",
-        ]:
+        if "RT-DETR" in self.model_name:
             predictor.set_inputs(
                 {"img": "img", "scale_factors": "scale_factors", "img_size": "img_size"}
             )

+ 1 - 7
paddlex/inference/models/object_detection.py

@@ -44,13 +44,7 @@ class DetPredictor(CVPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        if self.model_name in [
-            "RT-DETR-R18",
-            "RT-DETR-R50",
-            "RT-DETR-L",
-            "RT-DETR-H",
-            "RT-DETR-X",
-        ]:
+        if "RT-DETR" in self.model_name:
             predictor.set_inputs(
                 {
                     "img": "img",