Explorar o código

Merge pull request #906 from will-jl944/develop_jf

set dataloader worker_init_fn explictly
FlyingQianMM %!s(int64=4) %!d(string=hai) anos
pai
achega
085d593cb3
Modificáronse 1 ficheiros con 4 adicións e 1 borrados
  1. 4 1
      dygraph/paddlex/cv/models/base.py

+ 4 - 1
dygraph/paddlex/cv/models/base.py

@@ -20,6 +20,7 @@ import copy
 import math
 import yaml
 import json
+import numpy as np
 import paddle
 from paddle.io import DataLoader, DistributedBatchSampler
 from paddleslim import QAT
@@ -249,7 +250,9 @@ class BaseModel:
             collate_fn=dataset.batch_transforms,
             num_workers=dataset.num_workers,
             return_list=True,
-            use_shared_memory=use_shared_memory)
+            use_shared_memory=use_shared_memory,
+            worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)
+        )
 
         return loader