|
|
@@ -251,6 +251,8 @@ class StaticInfer:
|
|
|
elif self.option.device == "mlu":
|
|
|
config.enable_custom_device("mlu")
|
|
|
elif self.option.device == "dcu":
|
|
|
+ config.enable_use_gpu(100, self.option.device_id)
|
|
|
+ # XXX: is_compiled_with_rocm() must be True on dcu platform ?
|
|
|
if paddle.is_compiled_with_rocm():
|
|
|
# Delete unsupported passes in dcu
|
|
|
config.delete_pass("conv2d_add_act_fuse_pass")
|