# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import random
import imghdr
import os
import signal

from paddle.io import Dataset, DataLoader, DistributedBatchSampler

from . import imaug
from .imaug import transform
from paddlex.ppcls.utils import logger

trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))


class ModeException(Exception):
    """
    ModeException
    """

    def __init__(self, message='', mode=''):
        message += "\nOnly the following 3 modes are supported: " \
                   "train, valid, test. Given mode is {}".format(mode)
        super(ModeException, self).__init__(message)


class SampleNumException(Exception):
    """
    SampleNumException
    """

    def __init__(self, message='', sample_num=0, batch_size=1):
        message += "\nError: The number of the whole data ({}) " \
                   "is smaller than the batch_size ({}), and drop_last " \
                   "is turned on, so nothing will be fed into the program. " \
                   "Terminating now. Please reset batch_size to a smaller " \
                   "number or feed more data!".format(sample_num, batch_size)
        super(SampleNumException, self).__init__(message)


class ShuffleSeedException(Exception):
    """
    ShuffleSeedException
    """

    def __init__(self, message=''):
        message += "\nIf trainers_num > 1, the shuffle_seed must be set, " \
                   "because the order of batch data generated by reader " \
                   "must be the same in the respective processes."
        super(ShuffleSeedException, self).__init__(message)


def check_params(params):
    """
    Check params to avoid unexpected errors.
    Args:
        params(dict):
    """
    if 'shuffle_seed' not in params:
        params['shuffle_seed'] = None

    if trainers_num > 1 and params['shuffle_seed'] is None:
        raise ShuffleSeedException()

    data_dir = params.get('data_dir', '')
    assert os.path.isdir(data_dir), \
        "{} doesn't exist, please check the data_dir path".format(data_dir)

    if params['mode'] != 'test':
        file_list = params.get('file_list', '')
        assert os.path.isfile(file_list), \
            "{} doesn't exist, please check the file_list path".format(file_list)


def create_file_list(params):
    """
    If mode is test, create the file list.
    Args:
        params(dict):
    """
    data_dir = params.get('data_dir', '')
    params['file_list'] = ".tmp.txt"
    imgtype_list = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
    with open(params['file_list'], "w") as fout:
        tmp_file_list = os.listdir(data_dir)
        for file_name in tmp_file_list:
            file_path = os.path.join(data_dir, file_name)
            # Skip anything that is not a recognized image file.
            if imghdr.what(file_path) not in imgtype_list:
                continue
            fout.write(file_name + " 0" + "\n")


def shuffle_lines(full_lines, seed=None):
    """
    Randomly shuffle the lines.
    Args:
        full_lines(list):
        seed(int): random seed
    """
    if seed is not None:
        np.random.RandomState(seed).shuffle(full_lines)
    else:
        np.random.shuffle(full_lines)

    return full_lines
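

# With a fixed seed, every trainer process gets the identical permutation,
# which is exactly what ShuffleSeedException demands for trainers_num > 1:
#
#     shuffle_lines(list(lines), seed=42) == shuffle_lines(list(lines), seed=42)
#
# (each call builds a fresh np.random.RandomState(42), so the order matches;
# the copies are needed because shuffle_lines shuffles in place).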


def get_file_list(params):
    """
    Read the label list from file and shuffle it.
    Args:
        params(dict):
    """
    if params['mode'] == 'test':
        create_file_list(params)

    with open(params['file_list']) as flist:
        full_lines = [line.strip() for line in flist]

    if params["mode"] == "train":
        full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])

    return full_lines


def create_operators(params):
    """
    Create operators based on the config.
    Args:
        params(list): a dict list, used to create some operators
    """
    assert isinstance(params, list), ('operator config should be a list')
    ops = []
    for operator in params:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        op = getattr(imaug, op_name)(**param)
        ops.append(op)

    return ops
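

# A hedged example of the ``transforms`` config that create_operators() expects:
# a list of single-key dicts, where each key names a class exported by
# ``imaug`` and the value holds its keyword arguments (or None for no
# arguments). The operator names and arguments below are assumptions based on
# the ops typically shipped with ppcls and may differ between versions. In
# YAML form:
#
#     transforms:
#       - DecodeImage:
#           to_rgb: True
#       - RandCropImage:
#           size: 224
#       - NormalizeImage:
#           mean: [0.485, 0.456, 0.406]
#           std: [0.229, 0.224, 0.225]
#       - ToCHWImage:
#
# which arrives here, after YAML parsing, as
#     [{'DecodeImage': {'to_rgb': True}}, {'RandCropImage': {'size': 224}}, ...]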


def term_mp(sig_num, frame):
    """ kill all child processes
    """
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    logger.info("main proc {} exit, kill process group "
                "{}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)
    return


class CommonDataset(Dataset):
    def __init__(self, params):
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get('delimiter', ' ')
        self.ops = create_operators(params['transforms'])
        self.num_samples = len(self.full_lines)
        return

    def __getitem__(self, idx):
        try:
            line = self.full_lines[idx]
            img_path, label = line.split(self.delimiter)
            img_path = os.path.join(self.params['data_dir'], img_path)
            with open(img_path, 'rb') as f:
                img = f.read()
            return (transform(img, self.ops), int(label))
        except Exception as e:
            logger.error("data read failed: {}, exception info: {}".format(
                line, e))
            # Fall back to a random valid index; randint is inclusive on both
            # ends, so the upper bound must be len(self) - 1.
            return self.__getitem__(random.randint(0, len(self) - 1))

    def __len__(self):
        return self.num_samples
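

# CommonDataset expects every line of ``file_list`` to hold an image path
# (relative to ``data_dir``) and an integer label joined by ``delimiter``
# (a single space by default). The paths below are purely illustrative:
#
#     train/cat_0001.jpg 0
#     train/dog_0002.jpg 1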


class MultiLabelDataset(Dataset):
    """
    Define dataset class for multilabel image classification
    """

    def __init__(self, params):
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get("delimiter", "\t")
        self.ops = create_operators(params["transforms"])
        self.num_samples = len(self.full_lines)
        return

    def __getitem__(self, idx):
        try:
            line = self.full_lines[idx]
            img_path, label_str = line.split(self.delimiter)
            img_path = os.path.join(self.params["data_dir"], img_path)
            with open(img_path, "rb") as f:
                img = f.read()
            labels = label_str.split(',')
            labels = [int(i) for i in labels]
            return (transform(img, self.ops),
                    np.array(labels).astype("float32"))
        except Exception as e:
            logger.error("data read failed: {}, exception info: {}".format(
                line, e))
            # Same fallback as CommonDataset: resample a random valid index.
            return self.__getitem__(random.randint(0, len(self) - 1))

    def __len__(self):
        return self.num_samples
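

# MultiLabelDataset expects tab-separated lines (default delimiter "\t") whose
# second field is a comma-separated multi-hot label string that is parsed into
# a float32 vector. The path below is illustrative only:
#
#     train/img_0001.jpg<TAB>0,1,0,1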


class Reader:
    """
    Create a reader for training/validation/test.
    Args:
        config(dict): arguments
        mode(str): train or val or test
        places: the places on which the DataLoader puts the batch data
    Returns:
        the specific reader
    """

    def __init__(self, config, mode='train', places=None):
        try:
            self.params = config[mode.upper()]
        except KeyError:
            raise ModeException(mode=mode)

        use_mix = config.get('use_mix')
        self.params['mode'] = mode
        self.shuffle = mode == "train"

        self.collate_fn = None
        self.batch_ops = []
        if use_mix and mode == "train":
            self.batch_ops = create_operators(self.params['mix'])
            self.collate_fn = self.mix_collate_fn

        self.places = places
        self.use_xpu = config.get("use_xpu", False)
        self.multilabel = config.get("multilabel", False)

    def mix_collate_fn(self, batch):
        batch = transform(batch, self.batch_ops)
        # batch each field
        slots = []
        for items in batch:
            for i, item in enumerate(items):
                if len(slots) < len(items):
                    slots.append([item])
                else:
                    slots[i].append(item)

        return [np.stack(slot, axis=0) for slot in slots]
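
    # For example, if the mix batch ops produce per-sample tuples of the form
    # (img, label_a, label_b, lam) -- an assumption about the configured mix
    # ops, not something this class enforces -- mix_collate_fn regroups them
    # field-wise, roughly:
    #
    #     [(img0, a0, b0, l0), (img1, a1, b1, l1), ...]
    #         -> [stack(imgs), stack(a_s), stack(b_s), stack(lams)]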

    def __call__(self):
        batch_size = int(self.params['batch_size']) // trainers_num
        if self.multilabel:
            dataset = MultiLabelDataset(self.params)
        else:
            dataset = CommonDataset(self.params)

        if (self.params['mode'] != "train") and self.use_xpu:
            loader = DataLoader(
                dataset,
                places=self.places,
                batch_size=batch_size,
                drop_last=False,
                return_list=True,
                shuffle=False,
                num_workers=self.params["num_workers"])
        else:
            is_train = self.params['mode'] == "train"
            batch_sampler = DistributedBatchSampler(
                dataset,
                batch_size=batch_size,
                shuffle=self.shuffle and is_train,
                drop_last=is_train)
            loader = DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=self.collate_fn if is_train else None,
                places=self.places,
                return_list=True,
                num_workers=self.params["num_workers"])

        return loader


signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
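

# ---------------------------------------------------------------------------
# Usage sketch (only runs when the module is executed directly). It builds a
# Reader from a hypothetical config dict and iterates the resulting
# DataLoader; the paths, batch settings and operator names below are
# illustrative assumptions, not values shipped with this module, and the
# operator names must match classes actually exported by ``imaug``.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "use_mix": False,
        "use_xpu": False,
        "multilabel": False,
        "TRAIN": {
            "data_dir": "./dataset/images",      # hypothetical image root
            "file_list": "./dataset/train.txt",  # hypothetical label list
            "batch_size": 32,
            "num_workers": 4,
            "shuffle_seed": 0,
            "transforms": [
                {"DecodeImage": {"to_rgb": True}},
                {"RandCropImage": {"size": 224}},
                {"ToCHWImage": None},
            ],
        },
    }
    train_loader = Reader(example_config, mode="train")()
    for step, (images, labels) in enumerate(train_loader):
        print(step, images.shape, labels.shape)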