|
|
@@ -28,7 +28,7 @@ class PaddlePredictorOption(object):
|
|
|
"mkldnn",
|
|
|
"mkldnn_bf16",
|
|
|
)
|
|
|
- SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu")
|
|
|
+ SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu")
|
|
|
|
|
|
def __init__(self, model_name=None, **kwargs):
|
|
|
super().__init__()
|
|
|
@@ -95,12 +95,12 @@ class PaddlePredictorOption(object):
|
|
|
if not device:
|
|
|
return
|
|
|
device_type, device_ids = parse_device(device)
|
|
|
- self._cfg["device"] = device_type
|
|
|
if device_type not in self.SUPPORT_DEVICE:
|
|
|
support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
|
|
|
raise ValueError(
|
|
|
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
|
|
|
)
|
|
|
+ self._cfg["device"] = device_type
|
|
|
device_id = device_ids[0] if device_ids is not None else 0
|
|
|
self._cfg["device_id"] = device_id
|
|
|
set_env_for_device(device)
|