ソースを参照

Merge pull request #829 from will-jl944/windows_bug_fix

Forcibily convert image labels to int64 in ImagenetDataset
FlyingQianMM 4 年 前
コミット
e42dc31c0f
1 ファイル変更3 行追加1 行削除
  1. 3 1
      dygraph/paddlex/cv/datasets/imagenet.py

+ 3 - 1
dygraph/paddlex/cv/datasets/imagenet.py

@@ -14,6 +14,7 @@
 
 import os.path as osp
 import copy
+import numpy as np
 from paddle.io import Dataset
 from paddlex.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 
@@ -70,7 +71,8 @@ class ImageNet(Dataset):
                         full_path))
                 self.file_list.append({
                     'image': full_path,
-                    'label': int(items[1])
+                    'label': np.asarray(
+                        items[1], dtype=np.int64)
                 })
         self.num_samples = len(self.file_list)
         logging.info("{} samples in file {}".format(