Browse Source

dcu also use gpu as the flag

gaotingquan 9 months ago
parent
commit
fcbf6b11d2
1 changed files with 6 additions and 6 deletions
  1. 6 6
      paddlex/inference/models/common/static_infer.py

+ 6 - 6
paddlex/inference/models/common/static_infer.py

@@ -216,7 +216,12 @@ class StaticInfer:
                 params_file = trt_save_path + ".pdiparams"
 
         config = Config(model_file, params_file)
-        if self.option.device == "gpu":
+        # The DCU also uses "GPU" as the flag
+        if self.option.device == "gpu" and paddle.is_compiled_with_rocm():
+            # Delete unsupported passes in dcu
+            config.delete_pass("conv2d_add_act_fuse_pass")
+            config.delete_pass("conv2d_add_fuse_pass")
+        elif self.option.device == "gpu":
             config.exp_disable_mixed_precision_ops({"feed", "fetch"})
             config.enable_use_gpu(100, self.option.device_id)
             if not self.option.run_mode.startswith("trt"):
@@ -250,11 +255,6 @@ class StaticInfer:
             pass
         elif self.option.device == "mlu":
             config.enable_custom_device("mlu")
-        elif self.option.device == "dcu":
-            if paddle.is_compiled_with_rocm():
-                # Delete unsupported passes in dcu
-                config.delete_pass("conv2d_add_act_fuse_pass")
-                config.delete_pass("conv2d_add_fuse_pass")
         else:
             assert self.option.device == "cpu"
             config.disable_gpu()