|
@@ -105,7 +105,7 @@ class UnimernetModel(object):
|
|
|
# Create dataset with sorted images
|
|
# Create dataset with sorted images
|
|
|
dataset = MathDataset(sorted_images, transform=self.model.transform)
|
|
dataset = MathDataset(sorted_images, transform=self.model.transform)
|
|
|
|
|
|
|
|
- # 如果batch_size> len(sorted_images),则设置为不超过len(sorted_images)的2的幂
|
|
|
|
|
|
|
+ # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂
|
|
|
batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
|
|
batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
|
|
|
|
|
|
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|
|
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
|