Prechádzať zdrojové kódy

fix env setting & logging debug when set env

gaotingquan 1 rok pred
rodič
commit
c43a3119bc

+ 2 - 1
paddlex/modules/base/evaluator.py

@@ -17,7 +17,7 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 
 from .build_model import build_model
-from ...utils.device import update_device_num
+from ...utils.device import update_device_num, set_env_for_device
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils.logging import *
@@ -140,6 +140,7 @@ evaling!"
         """
         if using_device_number:
             return update_device_num(self.global_config.device, using_device_number)
+        set_env_for_device(self.global_config.device)
         return self.global_config.device
 
     @abstractmethod

+ 2 - 1
paddlex/modules/base/exportor.py

@@ -17,7 +17,7 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 
 from .build_model import build_model
-from ...utils.device import update_device_num
+from ...utils.device import update_device_num, set_env_for_device
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 from ...utils.logging import *
@@ -105,6 +105,7 @@ exporting!"
         """
         if using_device_number:
             return update_device_num(self.global_config.device, using_device_number)
+        set_env_for_device(self.global_config.device)
         return self.global_config.device
 
     def update_config(self):

+ 2 - 1
paddlex/modules/base/trainer.py

@@ -16,7 +16,7 @@ import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from .build_model import build_model
-from ...utils.device import update_device_num
+from ...utils.device import update_device_num, set_env_for_device
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
 
@@ -90,6 +90,7 @@ training!"
         """
         if using_device_number:
             return update_device_num(self.global_config.device, using_device_number)
+        set_env_for_device(self.global_config.device)
         return self.global_config.device
 
     @abstractmethod

+ 26 - 13
paddlex/utils/device.py

@@ -15,6 +15,8 @@
 import os
 import GPUtil
 import lazy_paddle as paddle
+
+from . import logging
 from .errors import raise_unsupported_device_error
 
 SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
@@ -69,22 +71,33 @@ def update_device_num(device, num):
 
 
 def set_env_for_device(device):
+    def _set(envs):
+        for key, val in envs.items():
+            os.environ[key] = val
+            logging.debug(f"{key} has been set to {val}.")
+
     device_type, device_ids = parse_device(device)
     if device_type.lower() in ["gpu", "xpu", "npu", "mlu"]:
         if device_type.lower() == "gpu" and paddle.is_compiled_with_rocm():
-            os.environ["FLAGS_conv_workspace_size_limit"] = "2000"
+            envs = {"FLAGS_conv_workspace_size_limit": "2000"}
+            _set(envs)
         if device_type.lower() == "npu":
-            os.environ["FLAGS_npu_jit_compile"] = "0"
-            os.environ["FLAGS_use_stride_kernel"] = "0"
-            os.environ["FLAGS_allocator_strategy"] = "auto_growth"
-            os.environ["CUSTOM_DEVICE_BLACK_LIST"] = (
-                "pad3d,pad3d_grad,set_value,set_value_with_tensor"
-            )
-            os.environ["FLAGS_npu_scale_aclnn"] = "True"
-            os.environ["FLAGS_npu_split_aclnn"] = "True"
+            envs = {
+                "FLAGS_npu_jit_compile": "0",
+                "FLAGS_use_stride_kernel": "0",
+                "FLAGS_allocator_strategy": "auto_growth",
+                "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
+                "FLAGS_npu_scale_aclnn": "True",
+                "FLAGS_npu_split_aclnn": "True",
+            }
+            _set(envs)
         if device_type.lower() == "xpu":
-            os.environ["BKCL_FORCE_SYNC"] = "1"
-            os.environ["BKCL_TIMEOUT"] = "1800"
-            os.environ["FLAGS_use_stride_kernel"] = "0"
+            envs = {
+                "BKCL_FORCE_SYNC": "1",
+                "BKCL_TIMEOUT": "1800",
+                "FLAGS_use_stride_kernel": "0",
+            }
+            _set(envs)
         if device_type.lower() == "mlu":
-            os.environ["FLAGS_use_stride_kernel"] = "0"
+            envs = {"FLAGS_use_stride_kernel": "0"}
+            _set(envs)