Просмотр исходного кода

Merge pull request #835 from will-jl944/develop_jf

Refinement for non-linux systems
FlyingQianMM 4 лет назад
Родитель
Сommit
ca90902815
2 измененных файлов с 12 добавлено и 4 удалено
  1. 7 4
      dygraph/paddlex/cv/models/base.py
  2. 5 0
      dygraph/paddlex/utils/env.py

+ 7 - 4
dygraph/paddlex/cv/models/base.py

@@ -191,11 +191,14 @@ class BaseModel:
             shuffle=dataset.shuffle,
             drop_last=mode == 'train')
 
-        shm_size = _get_shared_memory_size_in_M()
-        if shm_size is None or shm_size < 1024.:
-            use_shared_memory = False
+        if dataset.num_workers > 0:
+            shm_size = _get_shared_memory_size_in_M()
+            if shm_size is None or shm_size < 1024.:
+                use_shared_memory = False
+            else:
+                use_shared_memory = True
         else:
-            use_shared_memory = True
+            use_shared_memory = False
 
         loader = DataLoader(
             dataset,

+ 5 - 0
dygraph/paddlex/utils/env.py

@@ -16,6 +16,7 @@ import sys
 import glob
 import os
 import os.path as osp
+import platform
 import random
 import numpy as np
 import multiprocessing as mp
@@ -47,6 +48,10 @@ def get_environ_info():
 
 
 def get_num_workers(num_workers):
+    if not platform.system() == 'Linux':
+        # Dataloader with multi-process model is not supported
+        # on MacOS and Windows currently.
+        return 0
     if num_workers == 'auto':
         num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 2 else 2
     return num_workers