|
@@ -25,36 +25,36 @@ def get_base_params(model_type, per_gpu_memory, num_train_samples, num_gpu,
|
|
|
params['cuda_visible_devices'] = str(gpu_list).strip("[]")
|
|
params['cuda_visible_devices'] = str(gpu_list).strip("[]")
|
|
|
if model_type.startswith('MobileNet'):
|
|
if model_type.startswith('MobileNet'):
|
|
|
batch_size = (per_gpu_memory - 513) // 57 * num_gpu
|
|
batch_size = (per_gpu_memory - 513) // 57 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 125)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 125)
|
|
|
elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
|
|
elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
|
|
|
or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
|
|
or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
|
|
|
or model_type.startswith('ShuffleNet'):
|
|
or model_type.startswith('ShuffleNet'):
|
|
|
batch_size = (per_gpu_memory - 739) // 211 * num_gpu
|
|
batch_size = (per_gpu_memory - 739) // 211 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 16)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 16)
|
|
|
elif model_type.startswith('YOLOv3'):
|
|
elif model_type.startswith('YOLOv3'):
|
|
|
batch_size = (per_gpu_memory - 1555) // 943 * num_gpu
|
|
batch_size = (per_gpu_memory - 1555) // 943 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 8)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 8)
|
|
|
elif model_type.startswith('PPYOLO'):
|
|
elif model_type.startswith('PPYOLO'):
|
|
|
batch_size = (per_gpu_memory - 1691) // 1025 * num_gpu
|
|
batch_size = (per_gpu_memory - 1691) // 1025 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 8)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 8)
|
|
|
elif model_type.startswith('FasterRCNN'):
|
|
elif model_type.startswith('FasterRCNN'):
|
|
|
batch_size = (per_gpu_memory - 1755) // 915 * num_gpu
|
|
batch_size = (per_gpu_memory - 1755) // 915 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 2)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 2)
|
|
|
elif model_type.startswith('MaskRCNN'):
|
|
elif model_type.startswith('MaskRCNN'):
|
|
|
batch_size = (per_gpu_memory - 2702) // 1188 * num_gpu
|
|
batch_size = (per_gpu_memory - 2702) // 1188 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 2)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 2)
|
|
|
elif model_type.startswith('DeepLab'):
|
|
elif model_type.startswith('DeepLab'):
|
|
|
batch_size = (per_gpu_memory - 1469) // 1605 * num_gpu
|
|
batch_size = (per_gpu_memory - 1469) // 1605 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 4)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 4)
|
|
|
elif model_type.startswith('UNet'):
|
|
elif model_type.startswith('UNet'):
|
|
|
batch_size = (per_gpu_memory - 1275) // 1256 * num_gpu
|
|
batch_size = (per_gpu_memory - 1275) // 1256 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 4)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 4)
|
|
|
elif model_type.startswith('HRNet_W18'):
|
|
elif model_type.startswith('HRNet_W18'):
|
|
|
batch_size = (per_gpu_memory - 800) // 682 * num_gpu
|
|
batch_size = (per_gpu_memory - 800) // 682 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 4)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 4)
|
|
|
elif model_type.startswith('FastSCNN'):
|
|
elif model_type.startswith('FastSCNN'):
|
|
|
batch_size = (per_gpu_memory - 636) // 144 * num_gpu
|
|
batch_size = (per_gpu_memory - 636) // 144 * num_gpu
|
|
|
- batch_size = min(batch_size, gpu_nums * 4)
|
|
|
|
|
|
|
+ batch_size = min(batch_size, num_gpu * 4)
|
|
|
if batch_size > num_train_samples // 2:
|
|
if batch_size > num_train_samples // 2:
|
|
|
batch_size = num_train_samples // 2
|
|
batch_size = num_train_samples // 2
|
|
|
if batch_size < 1:
|
|
if batch_size < 1:
|