|
|
@@ -236,12 +236,19 @@ class StaticInfer:
|
|
|
)
|
|
|
elif self.option.device == "npu":
|
|
|
config.enable_custom_device("npu")
|
|
|
+ if hasattr(config, "enable_new_executor"):
|
|
|
+ config.enable_new_executor()
|
|
|
elif self.option.device == "xpu":
|
|
|
- pass
|
|
|
+ if hasattr(config, "enable_new_executor"):
|
|
|
+ config.enable_new_executor()
|
|
|
elif self.option.device == "mlu":
|
|
|
config.enable_custom_device("mlu")
|
|
|
+ if hasattr(config, "enable_new_executor"):
|
|
|
+ config.enable_new_executor()
|
|
|
elif self.option.device == "dcu":
|
|
|
config.enable_use_gpu(100, self.option.device_id)
|
|
|
+ if hasattr(config, "enable_new_executor"):
|
|
|
+ config.enable_new_executor()
|
|
|
# XXX: is_compiled_with_rocm() must be True on dcu platform ?
|
|
|
if paddle.is_compiled_with_rocm():
|
|
|
# Delete unsupported passes in dcu
|