소스 검색

replace dcu with gpu when using toolkits

gaotingquan 9 달 전
부모
커밋
0011f594ba
3개의 변경된 파일24개의 추가작업 그리고 9개의 파일을 삭제
  1. 8 3
      paddlex/modules/base/evaluator.py
  2. 8 3
      paddlex/modules/base/exportor.py
  3. 8 3
      paddlex/modules/base/trainer.py

+ 8 - 3
paddlex/modules/base/evaluator.py

@@ -150,9 +150,14 @@ evaling!"
         """
         check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
-        if using_device_number:
-            return update_device_num(self.global_config.device, using_device_number)
-        return self.global_config.device
+        device_setting = (
+            update_device_num(self.global_config.device, using_device_number)
+            if using_device_number
+            else self.global_config.device
+        )
+        # replace "dcu" with "gpu"
+        device_setting = device_setting.replace("dcu", "gpu")
+        return device_setting
 
     @abstractmethod
     def update_config(self):

+ 8 - 3
paddlex/modules/base/exportor.py

@@ -120,9 +120,14 @@ exporting!"
         """
         check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
-        if using_device_number:
-            return update_device_num(self.global_config.device, using_device_number)
-        return self.global_config.device
+        device_setting = (
+            update_device_num(self.global_config.device, using_device_number)
+            if using_device_number
+            else self.global_config.device
+        )
+        # replace "dcu" with "gpu"
+        device_setting = device_setting.replace("dcu", "gpu")
+        return device_setting
 
     def update_config(self):
         """update export config"""

+ 8 - 3
paddlex/modules/base/trainer.py

@@ -114,9 +114,14 @@ training!"
         """
         check_supported_device(self.global_config.device, self.global_config.model)
         set_env_for_device(self.global_config.device)
-        if using_device_number:
-            return update_device_num(self.global_config.device, using_device_number)
-        return self.global_config.device
+        device_setting = (
+            update_device_num(self.global_config.device, using_device_number)
+            if using_device_number
+            else self.global_config.device
+        )
+        # replace "dcu" with "gpu"
+        device_setting = device_setting.replace("dcu", "gpu")
+        return device_setting
 
     @abstractmethod
     def update_config(self):