Pārlūkot izejas kodu

support white list for custom devices (#2690)

Tingquan Gao 9 mēneši atpakaļ
vecāks
revīzija
e037507f4c

+ 1 - 1
docs/support_list/model_list_xpu.md

@@ -258,7 +258,7 @@ PaddleX 内置了多条产线,每条产线都包含了若干模块,每个模
 <td>4.2 M</td>
 <td><a href="https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-OCRv4_mobile_det_infer.tar">推理模型</a>/<a href="https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv4_mobile_det_pretrained.pdparams">训练模型</a></td></tr>
 <tr>
-<td>PP-OCRv4_server_det</td>
+<td>PP-OCRv4_server_det</td>
 <td>82.69</td>
 <td>100.1M</td>
 <td><a href="https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-OCRv4_server_det )_infer.tar">推理模型</a>/<a href="https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv4_server_det )_pretrained.pdparams">训练模型</a></td></tr>

+ 9 - 2
paddlex/inference/utils/pp_option.py

@@ -15,7 +15,12 @@
 import os
 from typing import Dict, List
 
-from ...utils.device import parse_device, set_env_for_device, get_default_device
+from ...utils.device import (
+    parse_device,
+    set_env_for_device,
+    get_default_device,
+    check_supported_device,
+)
 from ...utils import logging
 from .new_ir_blacklist import NEWIR_BLOCKLIST
 
@@ -107,7 +112,9 @@ class PaddlePredictorOption(object):
 
     @property
     def device(self):
-        return self._cfg["device"]
+        device = self._cfg["device"]
+        check_supported_device(device, self.model_name)
+        return device
 
     @device.setter
     def device(self, device: str):

+ 6 - 1
paddlex/modules/base/evaluator.py

@@ -17,7 +17,11 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 
 from .build_model import build_model
-from ...utils.device import update_device_num, set_env_for_device
+from ...utils.device import (
+    update_device_num,
+    set_env_for_device,
+    check_supported_device,
+)
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils.logging import *
@@ -144,6 +148,7 @@ evaling!"
         Returns:
             str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
         """
+        check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
         if using_device_number:
             return update_device_num(self.global_config.device, using_device_number)

+ 6 - 1
paddlex/modules/base/exportor.py

@@ -17,7 +17,11 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 
 from .build_model import build_model
-from ...utils.device import update_device_num, set_env_for_device
+from ...utils.device import (
+    update_device_num,
+    set_env_for_device,
+    check_supported_device,
+)
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils import logging
@@ -114,6 +118,7 @@ exporting!"
         Returns:
             str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
         """
+        check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
         if using_device_number:
             return update_device_num(self.global_config.device, using_device_number)

+ 6 - 1
paddlex/modules/base/trainer.py

@@ -16,7 +16,11 @@ import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from .build_model import build_model
-from ...utils.device import update_device_num, set_env_for_device
+from ...utils.device import (
+    update_device_num,
+    set_env_for_device,
+    check_supported_device,
+)
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils.logging import info
@@ -108,6 +112,7 @@ training!"
         Returns:
             str: device setting, such as: `gpu:0,1`, `npu:0,1` `cpu`.
         """
+        check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
         if using_device_number:
             return update_device_num(self.global_config.device, using_device_number)

+ 285 - 0
paddlex/utils/custom_device_whitelist.py

@@ -0,0 +1,285 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+DCU_WHITELIST = [
+    "ResNet18",
+    "ResNet34",
+    "ResNet50",
+    "ResNet101",
+    "ResNet152",
+    "Deeplabv3_Plus-R50",
+    "Deeplabv3_Plus-R101",
+]
+
+MLU_WHITELIST = [
+    "MobileNetV3_large_x0_5",
+    "MobileNetV3_large_x0_35",
+    "MobileNetV3_large_x0_75",
+    "MobileNetV3_large_x1_0",
+    "MobileNetV3_large_x1_25",
+    "MobileNetV3_small_x0_5",
+    "MobileNetV3_small_x0_35",
+    "MobileNetV3_small_x0_75",
+    "MobileNetV3_small_x1_0",
+    "MobileNetV3_small_x1_25",
+    "PP-HGNet_small",
+    "PP-LCNet_x0_5",
+    "PP-LCNet_x0_25",
+    "PP-LCNet_x0_35",
+    "PP-LCNet_x0_75",
+    "PP-LCNet_x1_0",
+    "PP-LCNet_x1_5",
+    "PP-LCNet_x2_0",
+    "PP-LCNet_x2_5",
+    "ResNet18",
+    "ResNet34",
+    "ResNet50",
+    "ResNet101",
+    "ResNet152",
+    "PicoDet-L",
+    "PicoDet-S",
+    "PP-YOLOE_plus-L",
+    "PP-YOLOE_plus-M",
+    "PP-YOLOE_plus-S",
+    "PP-YOLOE_plus-X",
+    "PP-LiteSeg-T",
+    "PP-OCRv4_mobile_det",
+    "PP-OCRv4_server_det",
+    "PP-OCRv4_mobile_rec",
+    "PP-OCRv4_server_rec",
+    "PicoDet_layout_1x",
+    "DLinear",
+    "NLinear",
+    "RLinear",
+]
+
+NPU_WHITELIST = [
+    "CLIP_vit_base_patch16_224",
+    "CLIP_vit_large_patch14_224",
+    "ConvNeXt_base_224",
+    "ConvNeXt_base_384",
+    "ConvNeXt_large_224",
+    "ConvNeXt_large_384",
+    "ConvNeXt_small",
+    "ConvNeXt_tiny",
+    "MobileNetV1_x0_5",
+    "MobileNetV1_x0_25",
+    "MobileNetV1_x0_75",
+    "MobileNetV1_x1_0",
+    "MobileNetV2_x0_5",
+    "MobileNetV2_x0_25",
+    "MobileNetV2_x1_0",
+    "MobileNetV2_x1_5",
+    "MobileNetV2_x2_0",
+    "MobileNetV3_large_x0_5",
+    "MobileNetV3_large_x0_35",
+    "MobileNetV3_large_x0_75",
+    "MobileNetV3_large_x1_0",
+    "MobileNetV3_large_x1_25",
+    "MobileNetV3_small_x0_5",
+    "MobileNetV3_small_x0_35",
+    "MobileNetV3_small_x0_75",
+    "MobileNetV3_small_x1_0",
+    "MobileNetV3_small_x1_25",
+    "MobileNetV4_conv_large",
+    "MobileNetV4_conv_medium",
+    "MobileNetV4_conv_small",
+    "PP-HGNet_base",
+    "PP-HGNet_small",
+    "PP-HGNet_tiny",
+    "PP-HGNetV2-B0",
+    "PP-HGNetV2-B1",
+    "PP-HGNetV2-B2",
+    "PP-HGNetV2-B3",
+    "PP-HGNetV2-B4",
+    "PP-HGNetV2-B5",
+    "PP-HGNetV2-B6",
+    "PP-LCNet_x0_5",
+    "PP-LCNet_x0_25",
+    "PP-LCNet_x0_35",
+    "PP-LCNet_x0_75",
+    "PP-LCNet_x1_0",
+    "PP-LCNet_x1_5",
+    "PP-LCNet_x2_0",
+    "PP-LCNet_x2_5",
+    "PP-LCNetV2_base",
+    "PP-LCNetV2_large",
+    "PP-LCNetV2_small",
+    "ResNet18_vd",
+    "ResNet18",
+    "ResNet34_vd",
+    "ResNet34",
+    "ResNet50_vd",
+    "ResNet50",
+    "ResNet101_vd",
+    "ResNet101",
+    "ResNet152_vd",
+    "ResNet152",
+    "ResNet200_vd",
+    "SwinTransformer_base_patch4_window7_224",
+    "SwinTransformer_base_patch4_window12_384",
+    "SwinTransformer_large_patch4_window7_224",
+    "SwinTransformer_large_patch4_window12_384",
+    "SwinTransformer_small_patch4_window7_224",
+    "SwinTransformer_tiny_patch4_window7_224",
+    "CLIP_vit_base_patch16_448_ML",
+    "PP-HGNetV2-B0_ML",
+    "PP-HGNetV2-B4_ML",
+    "PP-HGNetV2-B6_ML",
+    "Cascade-FasterRCNN-ResNet50-FPN",
+    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "CenterNet-DLA-34",
+    "CenterNet-ResNet50",
+    "DETR-R50",
+    "FasterRCNN-ResNet34-FPN",
+    "FasterRCNN-ResNet50",
+    "FasterRCNN-ResNet50-FPN",
+    "FasterRCNN-ResNet50-vd-FPN",
+    "FasterRCNN-ResNet50-vd-SSLDv2-FPN",
+    "FasterRCNN-ResNet101",
+    "FasterRCNN-ResNet101-FPN",
+    "FasterRCNN-ResNeXt101-vd-FPN",
+    "FasterRCNN-Swin-Tiny-FPN",
+    "FCOS-ResNet50",
+    "PicoDet-L",
+    "PicoDet-M",
+    "PicoDet-S",
+    "PicoDet-XS",
+    "PP-YOLOE_plus-L",
+    "PP-YOLOE_plus-M",
+    "PP-YOLOE_plus-S",
+    "PP-YOLOE_plus-X",
+    "RT-DETR-H",
+    "RT-DETR-L",
+    "RT-DETR-R18",
+    "RT-DETR-R50",
+    "RT-DETR-X",
+    "YOLOv3-DarkNet53",
+    "YOLOv3-MobileNetV3",
+    "YOLOv3-ResNet50_vd_DCN",
+    "PP-YOLOE_plus_SOD-S",
+    "PP-YOLOE_plus_SOD-L",
+    "PP-YOLOE_plus_SOD-largesize-L",
+    "PP-YOLOE-L_human",
+    "PP-YOLOE-S_human",
+    "Deeplabv3_Plus-R50",
+    "Deeplabv3_Plus-R101",
+    "Deeplabv3-R50",
+    "Deeplabv3-R101",
+    "OCRNet_HRNet-W48",
+    "PP-LiteSeg-T",
+    "Mask-RT-DETR-H",
+    "Mask-RT-DETR-L",
+    "Mask-RT-DETR-M",
+    "Mask-RT-DETR-S",
+    "Mask-RT-DETR-X",
+    "Cascade-MaskRCNN-ResNet50-FPN",
+    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN",
+    "MaskRCNN-ResNet50-FPN",
+    "MaskRCNN-ResNet50-vd-FPN",
+    "MaskRCNN-ResNet50",
+    "MaskRCNN-ResNet101-FPN",
+    "MaskRCNN-ResNet101-vd-FPN",
+    "MaskRCNN-ResNeXt101-vd-FPN",
+    "PP-YOLOE_seg-S",
+    "PP-ShiTuV2_rec_CLIP_vit_base",
+    "PP-ShiTuV2_rec_CLIP_vit_large",
+    "PP-ShiTuV2_det",
+    "PP-YOLOE-L_vehicle",
+    "PP-YOLOE-S_vehicle",
+    "STFPM",
+    "PP-OCRv4_mobile_det",
+    "PP-OCRv4_server_det",
+    "PP-OCRv4_mobile_rec",
+    "PP-OCRv4_server_rec",
+    "ch_SVTRv2_rec",
+    "ch_RepSVTR_rec",
+    "SLANet",
+    "PicoDet_layout_1x",
+    "PicoDet-L_layout_3cls",
+    "RT-DETR-H_layout_3cls",
+    "RT-DETR-H_layout_17cls",
+    "DLinear",
+    "NLinear",
+    "Nonstationary",
+    "PatchTST",
+    "RLinear",
+    "TiDE",
+    "TimesNet",
+    "AutoEncoder_ad",
+    "DLinear_ad",
+    "Nonstationary_ad",
+    "PatchTST_ad",
+    "TimesNet_ad",
+    "TimesNet_cls",
+]
+
+XPU_WHITELIST = [
+    "MobileNetV3_large_x0_5",
+    "MobileNetV3_large_x0_35",
+    "MobileNetV3_large_x0_75",
+    "MobileNetV3_large_x1_0",
+    "MobileNetV3_large_x1_25",
+    "MobileNetV3_small_x0_5",
+    "MobileNetV3_small_x0_35",
+    "MobileNetV3_small_x0_75",
+    "MobileNetV3_small_x1_0",
+    "MobileNetV3_small_x1_25",
+    "PP-HGNet_small",
+    "PP-LCNet_x0_5",
+    "PP-LCNet_x0_25",
+    "PP-LCNet_x0_35",
+    "PP-LCNet_x0_75",
+    "PP-LCNet_x1_0",
+    "PP-LCNet_x1_5",
+    "PP-LCNet_x2_0",
+    "PP-LCNet_x2_5",
+    "ResNet18",
+    "ResNet34",
+    "ResNet50",
+    "ResNet101",
+    "ResNet152",
+    "PicoDet-L",
+    "PicoDet-S",
+    "PP-YOLOE_plus-L",
+    "PP-YOLOE_plus-M",
+    "PP-YOLOE_plus-S",
+    "PP-YOLOE_plus-X",
+    "PP-LiteSeg-T",
+    "PP-OCRv4_mobile_det",
+    "PP-OCRv4_server_det",
+    "PP-OCRv4_mobile_rec",
+    "PP-OCRv4_server_rec",
+    "PicoDet_layout_1x",
+    "DLinear",
+    "NLinear",
+    "RLinear",
+]
+
+GCU_WHITELIST = [
+    "ResNet50",
+    "PP-YOLOE_plus-L",
+    "PP-YOLOE_plus-M",
+    "PP-YOLOE_plus-S",
+    "PP-YOLOE_plus-X",
+    "RT-DETR-H",
+    "RT-DETR-L",
+    "RT-DETR-R18",
+    "RT-DETR-R50",
+    "RT-DETR-X",
+    "PP-OCRv4_mobile_det",
+    "PP-OCRv4_server_det",
+    "PP-OCRv4_mobile_rec",
+    "PP-OCRv4_server_rec",
+]

+ 31 - 0
paddlex/utils/device.py

@@ -18,6 +18,13 @@ import GPUtil
 import lazy_paddle as paddle
 from . import logging
 from .errors import raise_unsupported_device_error
+from .custom_device_whitelist import (
+    DCU_WHITELIST,
+    MLU_WHITELIST,
+    NPU_WHITELIST,
+    XPU_WHITELIST,
+    GCU_WHITELIST,
+)
 
 SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu", "gcu"]
 
@@ -112,3 +119,27 @@ def set_env_for_device(device):
         if device_type.lower() == "gcu":
             envs = {"FLAGS_use_stride_kernel": "0"}
             _set(envs)
+
+
+def check_supported_device(device, model_name):
+    device_type, device_ids = parse_device(device)
+    if device_type == "dcu":
+        assert (
+            model_name in DCU_WHITELIST
+        ), f"The DCU device does not yet support `{model_name}` model!"
+    elif device_type == "mlu":
+        assert (
+            model_name in MLU_WHITELIST
+        ), f"The MLU device does not yet support `{model_name}` model!"
+    elif device_type == "npu":
+        assert (
+            model_name in NPU_WHITELIST
+        ), f"The NPU device does not yet support `{model_name}` model!"
+    elif device_type == "xpu":
+        assert (
+            model_name in XPU_WHITELIST
+        ), f"The XPU device does not yet support `{model_name}` model!"
+    elif device_type == "gcu":
+        assert (
+            model_name in GCU_WHITELIST
+        ), f"The GCU device does not yet support `{model_name}` model!"