# datasetbase.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import pickle
import random

from .utils import copy_directory
  18. class DatasetBase(object):
  19. def __init__(self, dataset_id, path):
  20. self.id = dataset_id
  21. self.path = path
  22. self.all_files = list()
  23. self.file_info = dict()
  24. self.label_info = dict()
  25. self.labels = list()
  26. self.train_files = list()
  27. self.val_files = list()
  28. self.test_files = list()
  29. self.class_train_file_list = dict()
  30. self.class_val_file_list = dict()
  31. self.class_test_file_list = dict()
  32. def copy_dataset(self, source_path, files):
  33. # 将原数据集拷贝至目标路径
  34. copy_directory(source_path, self.path, files)
  35. def dump_statis_info(self):
  36. # info['fields']指定了需要dump的信息
  37. info = dict()
  38. info['fields'] = [
  39. 'file_info', 'label_info', 'labels', 'train_files', 'val_files',
  40. 'test_files', 'class_train_file_list', 'class_val_file_list',
  41. 'class_test_file_list'
  42. ]
  43. for field in info['fields']:
  44. if hasattr(self, field):
  45. info[field] = getattr(self, field)
  46. with open(osp.join(self.path, 'statis.pkl'), 'wb') as f:
  47. pickle.dump(info, f)
  48. def load_statis_info(self):
  49. with open(osp.join(self.path, 'statis.pkl'), 'rb') as f:
  50. info = pickle.load(f)
  51. for field in info['fields']:
  52. if field in info:
  53. setattr(self, field, info[field])
  54. def split(self, val_split, test_split):
  55. all_files = list(self.file_info.keys())
  56. random.shuffle(all_files)
  57. val_num = int(len(all_files) * val_split)
  58. test_num = int(len(all_files) * test_split)
  59. train_num = len(all_files) - val_num - test_num
  60. assert train_num > 0, "训练集样本数量需大于0"
  61. assert val_num > 0, "验证集样本数量需大于0"
  62. self.train_files = all_files[:train_num]
  63. self.val_files = all_files[train_num:train_num + val_num]
  64. self.test_files = all_files[train_num + val_num:]
  65. self.train_set = set(self.train_files)
  66. self.val_set = set(self.val_files)
  67. self.test_set = set(self.test_files)
  68. for label, file_list in self.label_info.items():
  69. self.class_train_file_list[label] = list()
  70. self.class_val_file_list[label] = list()
  71. self.class_test_file_list[label] = list()
  72. for f in file_list:
  73. if f in self.test_set:
  74. self.class_test_file_list[label].append(f)
  75. if f in self.val_set:
  76. self.class_val_file_list[label].append(f)
  77. if f in self.train_set:
  78. self.class_train_file_list[label].append(f)