gaotingquan 1 gadu atpakaļ
vecāks
revīzija
49448a9e42

+ 1 - 1
paddlex/inference/components/paddle_predictor/predictor.py

@@ -62,7 +62,7 @@ class BasePaddlePredictor(BaseComponent, PPEngineMixin):
         params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
         config = Config(model_file, params_file)
 
-        if self.option.device == "gpu":
+        if self.option.device in ("gpu", "dcu"):
             config.enable_use_gpu(200, self.option.device_id)
             if paddle.is_compiled_with_rocm():
                 os.environ["FLAGS_conv_workspace_size_limit"] = "2000"

+ 2 - 2
paddlex/inference/utils/pp_option.py

@@ -28,7 +28,7 @@ class PaddlePredictorOption(object):
         "mkldnn",
         "mkldnn_bf16",
     )
-    SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu")
+    SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu")
 
     def __init__(self, model_name=None, **kwargs):
         super().__init__()
@@ -95,12 +95,12 @@ class PaddlePredictorOption(object):
         if not device:
             return
         device_type, device_ids = parse_device(device)
-        self._cfg["device"] = device_type
         if device_type not in self.SUPPORT_DEVICE:
             support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
             raise ValueError(
                 f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
             )
+        self._cfg["device"] = device_type
         device_id = device_ids[0] if device_ids is not None else 0
         self._cfg["device_id"] = device_id
         set_env_for_device(device)