Quellcode durchsuchen

replace dcu with gpu when using toolkits

gaotingquan vor 9 Monaten
Ursprung
Commit
0011f594ba

+ 8 - 3
paddlex/modules/base/evaluator.py

@@ -150,9 +150,14 @@ evaling!"
         """
         check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
-        if using_device_number:
-            return update_device_num(self.global_config.device, using_device_number)
-        return self.global_config.device
+        device_setting = (
+            update_device_num(self.global_config.device, using_device_number)
+            if using_device_number
+            else self.global_config.device
+        )
+        # replace "dcu" with "gpu"
+        device_setting = device_setting.replace("dcu", "gpu")
+        return device_setting
 
     @abstractmethod
     def update_config(self):

+ 8 - 3
paddlex/modules/base/exportor.py

@@ -120,9 +120,14 @@ exporting!"
         """
         check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
-        if using_device_number:
-            return update_device_num(self.global_config.device, using_device_number)
-        return self.global_config.device
+        device_setting = (
+            update_device_num(self.global_config.device, using_device_number)
+            if using_device_number
+            else self.global_config.device
+        )
+        # replace "dcu" with "gpu"
+        device_setting = device_setting.replace("dcu", "gpu")
+        return device_setting
 
     def update_config(self):
         """update export config"""

+ 8 - 3
paddlex/modules/base/trainer.py

@@ -114,9 +114,14 @@ training!"
         """
         check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
-        if using_device_number:
-            return update_device_num(self.global_config.device, using_device_number)
-        return self.global_config.device
+        device_setting = (
+            update_device_num(self.global_config.device, using_device_number)
+            if using_device_number
+            else self.global_config.device
+        )
+        # replace "dcu" with "gpu"
+        device_setting = device_setting.replace("dcu", "gpu")
+        return device_setting
 
     @abstractmethod
     def update_config(self):