|
|
@@ -20,6 +20,7 @@ import copy
|
|
|
import math
|
|
|
import yaml
|
|
|
import json
|
|
|
+import numpy as np
|
|
|
import paddle
|
|
|
from paddle.io import DataLoader, DistributedBatchSampler
|
|
|
from paddleslim import QAT
|
|
|
@@ -249,7 +250,9 @@ class BaseModel:
|
|
|
collate_fn=dataset.batch_transforms,
|
|
|
num_workers=dataset.num_workers,
|
|
|
return_list=True,
|
|
|
- use_shared_memory=use_shared_memory)
|
|
|
+ use_shared_memory=use_shared_memory,
|
|
|
+ worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)
|
|
|
+ )
|
|
|
|
|
|
return loader
|
|
|
|