dataset.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from threading import Thread
import multiprocessing
import collections
import numpy as np
import six
import sys
import copy
import random
import platform
import chardet

import paddlex.utils.logging as logging
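
# An EndSignal instance is queued as a sentinel: each worker puts one on the
# queue when it has no more samples, and consumers count the sentinels to know
# when every worker has finished.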
class EndSignal:
    pass


def is_pic(img_name):
    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
    suffix = img_name.split('.')[-1]
    if suffix not in valid_suffix:
        return False
    return True
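
# For example, is_pic('xxx.jpg') is True and is_pic('xxx.txt') is False; only
# the filename suffix is checked, the file content is never inspected.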


def is_valid(sample):
    if sample is None:
        return False
    if isinstance(sample, tuple):
        for s in sample:
            if s is None:
                return False
            elif isinstance(s, np.ndarray) and s.size == 0:
                return False
            elif isinstance(s, collections.abc.Sequence) and len(s) == 0:
                return False
    return True
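
# A sample is rejected if it, or any element of a sample tuple, is None, an
# empty numpy array, or an empty sequence, e.g. is_valid((im, [])) is False.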


def get_encoding(path):
    with open(path, 'rb') as f:
        data = f.read()
    return chardet.detect(data).get('encoding')
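
# Typical use: open(path, encoding=get_encoding(path)) for annotation or label
# files whose text encoding is not known in advance.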


def multithread_reader(mapper,
                       reader,
                       num_workers=4,
                       buffer_size=1024,
                       batch_size=8,
                       drop_last=True):
    from queue import Queue
    end = EndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, EndSignal):
            if len(sample) == 2:
                r = mapper(sample[0], sample[1])
            elif len(sample) == 3:
                r = mapper(sample[0], sample[1], sample[2])
            else:
                raise Exception("The sample's length must be 2 or 3.")
            if is_valid(r):
                out_queue.put(r)
            sample = in_queue.get()
        # pass the end signal on for the other handle workers,
        # then report this worker as finished
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        # start a read worker in a thread
        t = Thread(target=read_worker, args=(reader, in_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        workers = []
        for i in range(num_workers):
            worker = Thread(
                target=handle_worker, args=(in_queue, out_queue, mapper))
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        batch_data = []
        sample = out_queue.get()
        while not isinstance(sample, EndSignal):
            batch_data.append(sample)
            if len(batch_data) == batch_size:
                batch_data = generate_minibatch(batch_data)
                yield batch_data
                batch_data = []
            sample = out_queue.get()
        # drain the remaining samples until every handle worker has finished
        finish = 1
        while finish < num_workers:
            sample = out_queue.get()
            if isinstance(sample, EndSignal):
                finish += 1
            else:
                batch_data.append(sample)
                if len(batch_data) == batch_size:
                    batch_data = generate_minibatch(batch_data)
                    yield batch_data
                    batch_data = []
        if not drop_last and len(batch_data) != 0:
            batch_data = generate_minibatch(batch_data)
            yield batch_data
            batch_data = []

    return xreader
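
# Usage sketch (illustrative; `transforms` and `sample_reader` are placeholders
# for a PaddleX transform pipeline and a callable that yields 2- or 3-element
# samples):
#
#     batch_reader = multithread_reader(transforms, sample_reader, batch_size=8)
#     for batch in batch_reader():
#         ...  # a list of mapped samples, padded by generate_minibatch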


def multiprocess_reader(mapper,
                        reader,
                        num_workers=4,
                        buffer_size=1024,
                        batch_size=8,
                        drop_last=True):
    from .shared_queue import SharedQueue as Queue

    def _read_into_queue(samples, mapper, queue):
        end = EndSignal()
        try:
            for sample in samples:
                if sample is None:
                    raise ValueError("sample has None")
                if len(sample) == 2:
                    result = mapper(sample[0], sample[1])
                elif len(sample) == 3:
                    result = mapper(sample[0], sample[1], sample[2])
                else:
                    raise Exception("The sample's length must be 2 or 3.")
                if is_valid(result):
                    queue.put(result)
            queue.put(end)
        except:
            queue.put("")
            six.reraise(*sys.exc_info())

    def queue_reader():
        queue = Queue(buffer_size, memsize=3 * 1024**3)
        total_samples = [[] for i in range(num_workers)]
        for i, sample in enumerate(reader()):
            index = i % num_workers
            total_samples[index].append(sample)
        for i in range(num_workers):
            p = multiprocessing.Process(
                target=_read_into_queue,
                args=(total_samples[i], mapper, queue))
            p.start()

        finish_num = 0
        batch_data = list()
        while finish_num < num_workers:
            sample = queue.get()
            if isinstance(sample, EndSignal):
                finish_num += 1
            elif sample == "":
                raise ValueError("multiprocess reader raises an exception")
            else:
                batch_data.append(sample)
                if len(batch_data) == batch_size:
                    batch_data = generate_minibatch(batch_data)
                    yield batch_data
                    batch_data = []
        if len(batch_data) != 0 and not drop_last:
            batch_data = generate_minibatch(batch_data)
            yield batch_data
            batch_data = []

    return queue_reader
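
# Design note: unlike multithread_reader, which streams samples through an
# input queue, this reader materializes all samples up front and shards them
# round-robin across `num_workers` processes. The empty string put on the queue
# acts as an error sentinel, so a crash in any worker surfaces in the consuming
# process as a ValueError.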


def generate_minibatch(batch_data, label_padding_value=255):
    # if batch_size is 1, do not pad the image
    if len(batch_data) == 1:
        return batch_data
    width = [data[0].shape[2] for data in batch_data]
    height = [data[0].shape[1] for data in batch_data]
    # if the sizes of images in a mini-batch are equal,
    # do not pad the image
    if len(set(width)) == 1 and len(set(height)) == 1:
        return batch_data
    max_shape = np.array([data[0].shape for data in batch_data]).max(axis=0)
    padding_batch = []
    for data in batch_data:
        # pad the image to the same size
        im_c, im_h, im_w = data[0].shape[:]
        padding_im = np.zeros(
            (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = data[0]
        if len(data) > 1:
            if isinstance(data[1], np.ndarray):
                # pad the image and label of segmentation
                # during the training and evaluating phase
                padding_label = np.zeros(
                    (1, max_shape[1],
                     max_shape[2])).astype('int64') + label_padding_value
                _, label_h, label_w = data[1].shape
                padding_label[:, :label_h, :label_w] = data[1]
                padding_batch.append((padding_im, padding_label))
            elif len(data[1]) == 0 or isinstance(
                    data[1][0],
                    tuple) and data[1][0][0] in ['resize', 'padding']:
                # pad the image and insert 'padding' into `im_info`
                # of segmentation during the inference phase
                if len(data[1]) == 0 or 'padding' not in [
                        data[1][i][0] for i in range(len(data[1]))
                ]:
                    data[1].append(('padding', [im_h, im_w]))
                padding_batch.append((padding_im, ) + tuple(data[1:]))
            else:
                # pad the image of detection, or
                # pad the image of classification during the training
                # and evaluating phase
                padding_batch.append((padding_im, ) + tuple(data[1:]))
        else:
            # pad the image of classification during the inference phase
            padding_batch.append(padding_im)
    return padding_batch
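
# Example: two CHW images of shapes (3, 200, 300) and (3, 250, 280) are both
# zero-padded to (3, 250, 300); a segmentation label is padded with
# label_padding_value (255 by default) so the padded pixels can be ignored
# when computing the loss.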


class Dataset:
    def __init__(self,
                 transforms=None,
                 num_workers='auto',
                 buffer_size=100,
                 parallel_method='process',
                 shuffle=False):
        if num_workers == 'auto':
            import multiprocessing as mp
            num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
        if platform.platform().startswith("Darwin") or platform.platform(
        ).startswith("Windows"):
            parallel_method = 'thread'
        if transforms is None:
            raise Exception("transforms must be defined.")
        self.transforms = transforms
        self.num_workers = num_workers
        self.buffer_size = buffer_size
        self.parallel_method = parallel_method
        self.shuffle = shuffle

    def generator(self, batch_size=1, drop_last=True):
        self.batch_size = batch_size
        parallel_reader = multithread_reader
        if self.parallel_method == "process":
            if platform.platform().startswith("Windows"):
                logging.debug(
                    "multiprocess_reader is not supported on Windows, falling back to multithread_reader."
                )
            else:
                parallel_reader = multiprocess_reader
        return parallel_reader(
            self.transforms,
            self.iterator,
            num_workers=self.num_workers,
            buffer_size=self.buffer_size,
            batch_size=batch_size,
            drop_last=drop_last)

    def set_num_samples(self, num_samples):
        if num_samples > len(self.file_list):
            logging.warning(
                "num_samples is set to {}, but the dataset only has {} samples; keeping num_samples at {}."
                .format(num_samples, len(self.file_list), len(self.file_list)))
            num_samples = len(self.file_list)
        self.num_samples = num_samples
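
# Usage sketch (hypothetical subclass; the concrete datasets shipped with
# PaddleX supply `file_list` and an `iterator` method themselves):
#
#     class MyDataset(Dataset):
#         def __init__(self, file_list, transforms):
#             super().__init__(transforms=transforms)
#             self.file_list = file_list
#
#         def iterator(self):
#             for im_file, label in self.file_list:
#                 yield [im_file, label]
#
#     reader = MyDataset(file_list, transforms).generator(batch_size=8)
#     for batch in reader():
#         ...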