浏览代码

fix get params bug

wangsiyuan06 4 年之前
父节点
当前提交
86a2fda6bd
共有 2 个文件被更改,包括 11 次插入 和 11 次删除
  1. 1 1
      paddlex/restful/project/task.py
  2. 10 10
      paddlex/restful/project/train/params_v2.py

+ 1 - 1
paddlex/restful/project/task.py

@@ -207,7 +207,7 @@ def get_default_params(data, workspace, machine_info):
         per_gpu_memory = 0
         gpu_list = None
     else:
-        if gpu_list in data:
+        if 'gpu_list' in data:
             gpu_list = data['gpu_list']
             gpu_num = len(gpu_list)
             per_gpu_memory = None

+ 10 - 10
paddlex/restful/project/train/params_v2.py

@@ -25,36 +25,36 @@ def get_base_params(model_type, per_gpu_memory, num_train_samples, num_gpu,
         params['cuda_visible_devices'] = str(gpu_list).strip("[]")
         if model_type.startswith('MobileNet'):
             batch_size = (per_gpu_memory - 513) // 57 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 125)
+            batch_size = min(batch_size, num_gpu * 125)
         elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
             or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
             or model_type.startswith('ShuffleNet'):
             batch_size = (per_gpu_memory - 739) // 211 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 16)
+            batch_size = min(batch_size, num_gpu * 16)
         elif model_type.startswith('YOLOv3'):
             batch_size = (per_gpu_memory - 1555) // 943 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 8)
+            batch_size = min(batch_size, num_gpu * 8)
         elif model_type.startswith('PPYOLO'):
             batch_size = (per_gpu_memory - 1691) // 1025 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 8)
+            batch_size = min(batch_size, num_gpu * 8)
         elif model_type.startswith('FasterRCNN'):
             batch_size = (per_gpu_memory - 1755) // 915 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 2)
+            batch_size = min(batch_size, num_gpu * 2)
         elif model_type.startswith('MaskRCNN'):
             batch_size = (per_gpu_memory - 2702) // 1188 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 2)
+            batch_size = min(batch_size, num_gpu * 2)
         elif model_type.startswith('DeepLab'):
             batch_size = (per_gpu_memory - 1469) // 1605 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 4)
+            batch_size = min(batch_size, num_gpu * 4)
         elif model_type.startswith('UNet'):
             batch_size = (per_gpu_memory - 1275) // 1256 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 4)
+            batch_size = min(batch_size, num_gpu * 4)
         elif model_type.startswith('HRNet_W18'):
             batch_size = (per_gpu_memory - 800) // 682 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 4)
+            batch_size = min(batch_size, num_gpu * 4)
         elif model_type.startswith('FastSCNN'):
             batch_size = (per_gpu_memory - 636) // 144 * num_gpu
-            batch_size = min(batch_size, gpu_nums * 4)
+            batch_size = min(batch_size, num_gpu * 4)
     if batch_size > num_train_samples // 2:
         batch_size = num_train_samples // 2
     if batch_size < 1: