a31413510 8 meses atrás
pai
commit
7b5288bd6f
1 arquivos alterados com 8 adições e 1 exclusões
  1. 8 1
      paddlex/inference/models/common/static_infer.py

+ 8 - 1
paddlex/inference/models/common/static_infer.py

@@ -236,12 +236,19 @@ class StaticInfer:
                     )
         elif self.option.device == "npu":
             config.enable_custom_device("npu")
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
         elif self.option.device == "xpu":
-            pass
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
         elif self.option.device == "mlu":
             config.enable_custom_device("mlu")
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
         elif self.option.device == "dcu":
             config.enable_use_gpu(100, self.option.device_id)
+            if hasattr(config, "enable_new_executor"):
+                config.enable_new_executor()
             # XXX: is_compiled_with_rocm() must be True on dcu platform ?
             if paddle.is_compiled_with_rocm():
                 # Delete unsupported passes in dcu