Jelajahi Sumber

use dcu device type

gaotingquan 9 bulan lalu
induk
melakukan
a5b997876d
1 mengubah file dengan 6 tambahan dan 6 penghapusan
  1. 6 6
      paddlex/inference/models/common/static_infer.py

+ 6 - 6
paddlex/inference/models/common/static_infer.py

@@ -216,12 +216,7 @@ class StaticInfer:
                 params_file = trt_save_path + ".pdiparams"
 
         config = Config(model_file, params_file)
-        # The DCU also uses "GPU" as the flag
-        if self.option.device == "gpu" and paddle.is_compiled_with_rocm():
-            # Delete unsupported passes in dcu
-            config.delete_pass("conv2d_add_act_fuse_pass")
-            config.delete_pass("conv2d_add_fuse_pass")
-        elif self.option.device == "gpu":
+        if self.option.device == "gpu":
             config.exp_disable_mixed_precision_ops({"feed", "fetch"})
             config.enable_use_gpu(100, self.option.device_id)
             if not self.option.run_mode.startswith("trt"):
@@ -255,6 +250,11 @@ class StaticInfer:
             pass
         elif self.option.device == "mlu":
             config.enable_custom_device("mlu")
+        elif self.option.device == "dcu":
+            if paddle.is_compiled_with_rocm():
+                # Delete unsupported passes in dcu
+                config.delete_pass("conv2d_add_act_fuse_pass")
+                config.delete_pass("conv2d_add_fuse_pass")
         else:
             assert self.option.device == "cpu"
             config.disable_gpu()