瀏覽代碼

support multy device set device id in inference (#3925)

a31413510 6 月之前
父節點
當前提交
8e822d163c
共有 1 個文件被更改,包括 5 次插入3 次删除
  1. 5 3
      paddlex/inference/models/common/static_infer.py

+ 5 - 3
paddlex/inference/models/common/static_infer.py

@@ -373,7 +373,7 @@ class PaddleInfer(StaticInfer):
             logging.debug("`device_id` has been set to None")
 
         if (
-            self._option.device_type in ("gpu", "dcu")
+            self._option.device_type in ("gpu", "dcu", "npu", "mlu", "gcu", "xpu")
             and self._option.device_id is None
         ):
             self._option.device_id = 0
@@ -417,12 +417,14 @@ class PaddleInfer(StaticInfer):
                 if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
             elif self._option.device_type == "xpu":
+                config.enable_xpu()
+                config.set_xpu_device_id(self._option.device_id)
                 if hasattr(config, "enable_new_ir"):
                     config.enable_new_ir(self._option.enable_new_ir)
                 if hasattr(config, "enable_new_executor"):
                     config.enable_new_executor()
             elif self._option.device_type == "mlu":
-                config.enable_custom_device("mlu")
+                config.enable_custom_device("mlu", self._option.device_id)
                 if hasattr(config, "enable_new_ir"):
                     config.enable_new_ir(self._option.enable_new_ir)
                 if hasattr(config, "enable_new_executor"):
@@ -431,7 +433,7 @@ class PaddleInfer(StaticInfer):
                 from paddle_custom_device.gcu import passes as gcu_passes
 
                 gcu_passes.setUp()
-                config.enable_custom_device("gcu")
+                config.enable_custom_device("gcu", self._option.device_id)
                 if hasattr(config, "enable_new_ir"):
                     config.enable_new_ir()
                 if hasattr(config, "enable_new_executor"):