@@ -60,6 +60,7 @@ model.train(
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
+ pretrain_weights=None,
learning_rate=0.001 / 8,
warmup_steps=1000,
warmup_start_lr=0.0,
@@ -56,5 +56,6 @@ model.train(
train_batch_size=4,
learning_rate=0.01,
save_dir='output/unet/prune')