readers.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import os
  2. import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
  3. import cv2
  4. import numpy as np
  5. import six
  6. import glob
  7. from as_data_reader.data_path_utils import _find_classes
  8. from PIL import Image
  9. def resize_short(img, target_size, interpolation=None):
  10. """resize image
  11. Args:
  12. img: image data
  13. target_size: resize short target size
  14. interpolation: interpolation mode
  15. Returns:
  16. resized image data
  17. """
  18. percent = float(target_size) / min(img.shape[0], img.shape[1])
  19. resized_width = int(round(img.shape[1] * percent))
  20. resized_height = int(round(img.shape[0] * percent))
  21. if interpolation:
  22. resized = cv2.resize(
  23. img, (resized_width, resized_height), interpolation=interpolation)
  24. else:
  25. resized = cv2.resize(img, (resized_width, resized_height))
  26. return resized
  27. def crop_image(img, target_size, center=True):
  28. """crop image
  29. Args:
  30. img: images data
  31. target_size: crop target size
  32. center: crop mode
  33. Returns:
  34. img: cropped image data
  35. """
  36. height, width = img.shape[:2]
  37. size = target_size
  38. if center:
  39. w_start = (width - size) // 2
  40. h_start = (height - size) // 2
  41. else:
  42. w_start = np.random.randint(0, width - size + 1)
  43. h_start = np.random.randint(0, height - size + 1)
  44. w_end = w_start + size
  45. h_end = h_start + size
  46. img = img[h_start:h_end, w_start:w_end, :]
  47. return img
  48. def preprocess_image(img, random_mirror=False):
  49. """
  50. centered, scaled by 1/255.
  51. :param img: np.array: shape: [ns, h, w, 3], color order: rgb.
  52. :return: np.array: shape: [ns, h, w, 3]
  53. """
  54. mean = [0.485, 0.456, 0.406]
  55. std = [0.229, 0.224, 0.225]
  56. # transpose to [ns, 3, h, w]
  57. img = img.astype('float32').transpose((0, 3, 1, 2)) / 255
  58. img_mean = np.array(mean).reshape((3, 1, 1))
  59. img_std = np.array(std).reshape((3, 1, 1))
  60. img -= img_mean
  61. img /= img_std
  62. if random_mirror:
  63. mirror = int(np.random.uniform(0, 2))
  64. if mirror == 1:
  65. img = img[:, :, ::-1, :]
  66. return img
  67. def read_image(img_path, target_size=256, crop_size=224):
  68. """
  69. resize_short to 256, then center crop to 224.
  70. :param img_path: one image path
  71. :return: np.array: shape: [1, h, w, 3], color order: rgb.
  72. """
  73. if isinstance(img_path, str):
  74. with open(img_path, 'rb') as f:
  75. img = Image.open(f)
  76. img = img.convert('RGB')
  77. img = np.array(img)
  78. # img = cv2.imread(img_path)
  79. img = resize_short(img, target_size, interpolation=None)
  80. img = crop_image(img, target_size=crop_size, center=True)
  81. # img = img[:, :, ::-1]
  82. img = np.expand_dims(img, axis=0)
  83. return img
  84. elif isinstance(img_path, np.ndarray):
  85. assert len(img_path.shape) == 4
  86. return img_path
  87. else:
  88. ValueError(f"Not recognized data type {type(img_path)}.")
  89. class ReaderConfig(object):
  90. """
  91. A generic data loader where the images are arranged in this way:
  92. root/train/dog/xxy.jpg
  93. root/train/dog/xxz.jpg
  94. ...
  95. root/train/cat/nsdf3.jpg
  96. root/train/cat/asd932_.jpg
  97. ...
  98. root/test/dog/xxx.jpg
  99. ...
  100. root/test/cat/123.jpg
  101. ...
  102. """
  103. def __init__(self, dataset_dir, is_test):
  104. image_paths, labels, self.num_classes = self.get_dataset_info(dataset_dir, is_test)
  105. random_per = np.random.permutation(range(len(image_paths)))
  106. self.image_paths = image_paths[random_per]
  107. self.labels = labels[random_per]
  108. self.is_test = is_test
  109. def get_reader(self):
  110. def reader():
  111. IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
  112. target_size = 256
  113. crop_size = 224
  114. for i, img_path in enumerate(self.image_paths):
  115. if not img_path.lower().endswith(IMG_EXTENSIONS):
  116. continue
  117. img = cv2.imread(img_path)
  118. if img is None:
  119. print(img_path)
  120. continue
  121. img = resize_short(img, target_size, interpolation=None)
  122. img = crop_image(img, crop_size, center=self.is_test)
  123. img = img[:, :, ::-1]
  124. img = np.expand_dims(img, axis=0)
  125. img = preprocess_image(img, not self.is_test)
  126. yield img, self.labels[i]
  127. return reader
  128. def get_dataset_info(self, dataset_dir, is_test=False):
  129. IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
  130. # read
  131. if is_test:
  132. datasubset_dir = os.path.join(dataset_dir, 'test')
  133. else:
  134. datasubset_dir = os.path.join(dataset_dir, 'train')
  135. class_names, class_to_idx = _find_classes(datasubset_dir)
  136. # num_classes = len(class_names)
  137. image_paths = []
  138. labels = []
  139. for class_name in class_names:
  140. classes_dir = os.path.join(datasubset_dir, class_name)
  141. for img_path in glob.glob(os.path.join(classes_dir, '*')):
  142. if not img_path.lower().endswith(IMG_EXTENSIONS):
  143. continue
  144. image_paths.append(img_path)
  145. labels.append(class_to_idx[class_name])
  146. image_paths = np.array(image_paths)
  147. labels = np.array(labels)
  148. return image_paths, labels, len(class_names)
  149. def create_reader(list_image_path, list_label=None, is_test=False):
  150. def reader():
  151. IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
  152. target_size = 256
  153. crop_size = 224
  154. for i, img_path in enumerate(list_image_path):
  155. if not img_path.lower().endswith(IMG_EXTENSIONS):
  156. continue
  157. img = cv2.imread(img_path)
  158. if img is None:
  159. print(img_path)
  160. continue
  161. img = resize_short(img, target_size, interpolation=None)
  162. img = crop_image(img, crop_size, center=is_test)
  163. img = img[:, :, ::-1]
  164. img_show = np.expand_dims(img, axis=0)
  165. img = preprocess_image(img_show, not is_test)
  166. label = 0 if list_label is None else list_label[i]
  167. yield img_show, img, label
  168. return reader