Procházet zdrojové kódy

Merge pull request #829 from will-jl944/windows_bug_fix

Forcibily convert image labels to int64 in ImagenetDataset
FlyingQianMM před 4 roky
rodič
revize
e42dc31c0f
1 změnil soubory, kde provedl 3 přidání a 1 odebrání
  1. 3 1
      dygraph/paddlex/cv/datasets/imagenet.py

+ 3 - 1
dygraph/paddlex/cv/datasets/imagenet.py

@@ -14,6 +14,7 @@
 
 import os.path as osp
 import copy
+import numpy as np
 from paddle.io import Dataset
 from paddlex.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
@@ -70,7 +71,8 @@ class ImageNet(Dataset):
                         full_path))
                 self.file_list.append({
                     'image': full_path,
-                    'label': int(items[1])
+                    'label': np.asarray(
+                        items[1], dtype=np.int64)
                 })
         self.num_samples = len(self.file_list)
         logging.info("{} samples in file {}".format(