Bläddra i källkod

mv setting_environ_flags to paddlex/__init__

FlyingQianMM 5 år sedan
förälder
incheckning
e57c4bf739
2 ändrade filer med 14 tillägg och 13 borttagningar
  1. 14 2
      paddlex/__init__.py
  2. 0 11
      paddlex/utils/utils.py

+ 14 - 2
paddlex/__init__.py

@@ -13,6 +13,14 @@
 # limitations under the License.
 
 from __future__ import absolute_import
+import os
+if 'FLAGS_eager_delete_tensor_gb' not in os.environ:
+    os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
+if 'FLAGS_allocator_strategy' not in os.environ:
+    os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
+if "CUDA_VISIBLE_DEVICES" in os.environ:
+    if os.environ["CUDA_VISIBLE_DEVICES"].count("-1") > 0:
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 from .utils.utils import get_environ_info
 from . import cv
 from . import det
@@ -23,8 +31,12 @@ from . import slim
 try:
     import pycocotools
 except:
-    print("[WARNING] pycocotools is not installed, detection model is not available now.")
-    print("[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md")
+    print(
+        "[WARNING] pycocotools is not installed, detection model is not available now."
+    )
+    print(
+        "[WARNING] pycocotools install: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/install.md"
+    )
 
 env_info = get_environ_info()
 load_model = cv.models.load_model

+ 0 - 11
paddlex/utils/utils.py

@@ -31,18 +31,7 @@ def seconds_to_hms(seconds):
     return hms_str
 
 
-def setting_environ_flags():
-    if 'FLAGS_eager_delete_tensor_gb' not in os.environ:
-        os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
-    if 'FLAGS_allocator_strategy' not in os.environ:
-        os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        if os.environ["CUDA_VISIBLE_DEVICES"].count("-1") > 0:
-            os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-
 def get_environ_info():
-    setting_environ_flags()
     import paddle.fluid as fluid
     info = dict()
     info['place'] = 'cpu'