|
|
@@ -216,7 +216,12 @@ class StaticInfer:
|
|
|
params_file = trt_save_path + ".pdiparams"
|
|
|
|
|
|
config = Config(model_file, params_file)
|
|
|
- if self.option.device == "gpu":
|
|
|
+ # The DCU also uses "GPU" as the flag
|
|
|
+ if self.option.device == "gpu" and paddle.is_compiled_with_rocm():
|
|
|
+ # Delete unsupported passes in dcu
|
|
|
+ config.delete_pass("conv2d_add_act_fuse_pass")
|
|
|
+ config.delete_pass("conv2d_add_fuse_pass")
|
|
|
+ elif self.option.device == "gpu":
|
|
|
config.exp_disable_mixed_precision_ops({"feed", "fetch"})
|
|
|
config.enable_use_gpu(100, self.option.device_id)
|
|
|
if not self.option.run_mode.startswith("trt"):
|
|
|
@@ -250,11 +255,6 @@ class StaticInfer:
|
|
|
pass
|
|
|
elif self.option.device == "mlu":
|
|
|
config.enable_custom_device("mlu")
|
|
|
- elif self.option.device == "dcu":
|
|
|
- if paddle.is_compiled_with_rocm():
|
|
|
- # Delete unsupported passes in dcu
|
|
|
- config.delete_pass("conv2d_add_act_fuse_pass")
|
|
|
- config.delete_pass("conv2d_add_fuse_pass")
|
|
|
else:
|
|
|
assert self.option.device == "cpu"
|
|
|
config.disable_gpu()
|