dataset.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. try:
  17. from collections.abc import Sequence
  18. except Exception:
  19. from collections import Sequence
  20. from paddle.io import Dataset
  21. from paddlex.ppdet.core.workspace import register, serializable
  22. from paddlex.ppdet.utils.download import get_dataset_path
  23. import copy
  24. @serializable
  25. class DetDataset(Dataset):
  26. """
  27. Load detection dataset.
  28. Args:
  29. dataset_dir (str): root directory for dataset.
  30. image_dir (str): directory for images.
  31. anno_path (str): annotation file path.
  32. data_fields (list): key name of data dictionary, at least have 'image'.
  33. sample_num (int): number of samples to load, -1 means all.
  34. use_default_label (bool): whether to load default label list.
  35. """
  36. def __init__(self,
  37. dataset_dir=None,
  38. image_dir=None,
  39. anno_path=None,
  40. data_fields=['image'],
  41. sample_num=-1,
  42. use_default_label=None,
  43. **kwargs):
  44. super(DetDataset, self).__init__()
  45. self.dataset_dir = dataset_dir if dataset_dir is not None else ''
  46. self.anno_path = anno_path
  47. self.image_dir = image_dir if image_dir is not None else ''
  48. self.data_fields = data_fields
  49. self.sample_num = sample_num
  50. self.use_default_label = use_default_label
  51. self._epoch = 0
  52. self._curr_iter = 0
  53. def __len__(self, ):
  54. return len(self.roidbs)
  55. def __getitem__(self, idx):
  56. # data batch
  57. roidb = copy.deepcopy(self.roidbs[idx])
  58. if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
  59. n = len(self.roidbs)
  60. idx = np.random.randint(n)
  61. roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
  62. elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
  63. n = len(self.roidbs)
  64. idx = np.random.randint(n)
  65. roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
  66. elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
  67. n = len(self.roidbs)
  68. roidb = [roidb, ] + [
  69. copy.deepcopy(self.roidbs[np.random.randint(n)])
  70. for _ in range(3)
  71. ]
  72. if isinstance(roidb, Sequence):
  73. for r in roidb:
  74. r['curr_iter'] = self._curr_iter
  75. else:
  76. roidb['curr_iter'] = self._curr_iter
  77. self._curr_iter += 1
  78. return self.transform(roidb)
  79. def check_or_download_dataset(self):
  80. self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
  81. self.image_dir)
  82. def set_kwargs(self, **kwargs):
  83. self.mixup_epoch = kwargs.get('mixup_epoch', -1)
  84. self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
  85. self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
  86. def set_transform(self, transform):
  87. self.transform = transform
  88. def set_epoch(self, epoch_id):
  89. self._epoch = epoch_id
  90. def parse_dataset(self, ):
  91. raise NotImplementedError(
  92. "Need to implement parse_dataset method of Dataset")
  93. def get_anno(self):
  94. if self.anno_path is None:
  95. return
  96. return os.path.join(self.dataset_dir, self.anno_path)
  97. def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
  98. return f.lower().endswith(extensions)
  99. def _make_dataset(dir):
  100. dir = os.path.expanduser(dir)
  101. if not os.path.isdir(dir):
  102. raise ('{} should be a dir'.format(dir))
  103. images = []
  104. for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
  105. for fname in sorted(fnames):
  106. path = os.path.join(root, fname)
  107. if _is_valid_file(path):
  108. images.append(path)
  109. return images
  110. @register
  111. @serializable
  112. class ImageFolder(DetDataset):
  113. def __init__(self,
  114. dataset_dir=None,
  115. image_dir=None,
  116. anno_path=None,
  117. sample_num=-1,
  118. use_default_label=None,
  119. **kwargs):
  120. super(ImageFolder, self).__init__(
  121. dataset_dir,
  122. image_dir,
  123. anno_path,
  124. sample_num=sample_num,
  125. use_default_label=use_default_label)
  126. self._imid2path = {}
  127. self.roidbs = None
  128. self.sample_num = sample_num
  129. def check_or_download_dataset(self):
  130. if self.dataset_dir:
  131. # NOTE: ImageFolder is only used for prediction, in
  132. # infer mode, image_dir is set by set_images
  133. # so we only check anno_path here
  134. self.dataset_dir = get_dataset_path(self.dataset_dir,
  135. self.anno_path, None)
  136. def parse_dataset(self, ):
  137. if not self.roidbs:
  138. self.roidbs = self._load_images()
  139. def _parse(self):
  140. image_dir = self.image_dir
  141. if not isinstance(image_dir, Sequence):
  142. image_dir = [image_dir]
  143. images = []
  144. for im_dir in image_dir:
  145. if os.path.isdir(im_dir):
  146. im_dir = os.path.join(self.dataset_dir, im_dir)
  147. images.extend(_make_dataset(im_dir))
  148. elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
  149. images.append(im_dir)
  150. return images
  151. def _load_images(self):
  152. images = self._parse()
  153. ct = 0
  154. records = []
  155. for image in images:
  156. assert image != '' and os.path.isfile(image), \
  157. "Image {} not found".format(image)
  158. if self.sample_num > 0 and ct >= self.sample_num:
  159. break
  160. rec = {'im_id': np.array([ct]), 'im_file': image}
  161. self._imid2path[ct] = image
  162. ct += 1
  163. records.append(rec)
  164. assert len(records) > 0, "No image file found"
  165. return records
  166. def get_imid2path(self):
  167. return self._imid2path
  168. def set_images(self, images):
  169. self.image_dir = images
  170. self.roidbs = self._load_images()