readers.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import cv2
  17. import numpy as np
  18. import six
  19. import glob
  20. from .data_path_utils import _find_classes
  21. from PIL import Image
  22. import paddlex.utils.logging as logging
  def resize_short(img, target_size, interpolation=None):
      """Resize an image so that its shorter side equals ``target_size``.

      Both sides are scaled by the same factor, so the aspect ratio is kept
      (up to rounding to whole pixels).

      Args:
          img: image data indexed as [height, width, ...].
          target_size: desired length of the shorter side, in pixels.
          interpolation: OpenCV interpolation flag passed to ``cv2.resize``.
              When None, ``cv2.resize`` is called without the argument and
              uses its own default.

      Returns:
          The resized image data.
      """
      height, width = img.shape[:2]
      percent = float(target_size) / min(height, width)
      # cv2.resize expects the destination size as (width, height).
      new_size = (int(round(width * percent)), int(round(height * percent)))
      # Compare against None explicitly: cv2.INTER_NEAREST has the value 0 and
      # would be ignored by a plain truthiness test.
      if interpolation is not None:
          return cv2.resize(img, new_size, interpolation=interpolation)
      return cv2.resize(img, new_size)
  def crop_image(img, target_size, center=True):
      """Crop a square ``target_size`` x ``target_size`` patch from an image.

      Args:
          img: image data indexed as [height, width, ...]; any trailing axes
              (e.g. channels) are kept untouched.
          target_size: side length of the square crop, in pixels.
          center: if True, take the central patch; otherwise the patch
              position is drawn with ``np.random.randint``.

      Returns:
          img: cropped image data (a slice of the input array).

      Raises:
          ValueError: if ``target_size`` exceeds the image height or width.
      """
      height, width = img.shape[:2]
      size = target_size
      if size > height or size > width:
          # Without this guard the center branch computes negative start
          # offsets, which wrap around and silently return a wrong patch.
          raise ValueError(
              "Crop size {} is larger than image size {}x{} (height x width).".
              format(size, height, width))
      if center:
          w_start = (width - size) // 2
          h_start = (height - size) // 2
      else:
          w_start = np.random.randint(0, width - size + 1)
          h_start = np.random.randint(0, height - size + 1)
      return img[h_start:h_start + size, w_start:w_start + size]
  def preprocess_image(img, random_mirror=False):
      """Scale a batch of images by 1/255 and normalize each channel.

      The batch is converted to float32, moved to channel-first layout,
      divided by 255, then the per-channel mean is subtracted and the result
      is divided by the per-channel std.

      :param img: np.array: shape: [ns, h, w, 3], color order: rgb.
      :param random_mirror: if True, the whole batch is mirrored left-right
          with probability 1/2 (drawn from ``np.random.uniform``).
      :return: np.array: float32, shape: [ns, 3, h, w]
      """
      mean = [0.485, 0.456, 0.406]
      std = [0.229, 0.224, 0.225]

      # transpose to [ns, 3, h, w]
      img = img.astype('float32').transpose((0, 3, 1, 2)) / 255

      # Shape (3, 1, 1) broadcasts over the channel axis of [ns, 3, h, w].
      img_mean = np.array(mean).reshape((3, 1, 1))
      img_std = np.array(std).reshape((3, 1, 1))
      img -= img_mean
      img /= img_std

      if random_mirror:
          mirror = int(np.random.uniform(0, 2))
          if mirror == 1:
              # The layout is already [ns, 3, h, w] here, so a left-right
              # mirror reverses the last (width) axis. Reversing axis 2
              # would flip the images upside down instead.
              img = img[:, :, :, ::-1]
      return img
  def read_image(img_path, target_size=256, crop_size=224):
      """Load one image as a single-element RGB batch.

      A path is opened with PIL, converted to RGB, resized so that its short
      side equals ``target_size`` and then center-cropped to ``crop_size``.
      An array that is already a 4-D batch is returned unchanged.

      :param img_path: one image path (str), or a 4-D np.ndarray batch.
      :param target_size: short-side length used for the resize step.
      :param crop_size: side length of the square center crop.
      :return: np.array: shape: [1, crop_size, crop_size, 3], color order: rgb,
          for a path; the input array itself for an np.ndarray.
      :raises ValueError: if ``img_path`` is neither a str nor an np.ndarray.
      """
      if isinstance(img_path, str):
          with open(img_path, 'rb') as f:
              img = Image.open(f)
              # convert() forces the pixel data to be decoded while the file
              # is still open.
              img = img.convert('RGB')
              img = np.array(img)
          img = resize_short(img, target_size, interpolation=None)
          img = crop_image(img, target_size=crop_size, center=True)
          # Add the leading batch axis.
          img = np.expand_dims(img, axis=0)
          return img
      elif isinstance(img_path, np.ndarray):
          assert len(img_path.shape) == 4
          return img_path
      else:
          # The exception used to be constructed but never raised, so
          # unsupported input silently produced None.
          raise ValueError("Not recognized data type {}.".format(
              type(img_path)))
  class ReaderConfig(object):
      """
      A generic data loader where the images are arranged in this way:

          root/train/dog/xxy.jpg
          root/train/dog/xxz.jpg
          ...
          root/train/cat/nsdf3.jpg
          root/train/cat/asd932_.jpg
          ...
          root/test/dog/xxx.jpg
          ...
          root/test/cat/123.jpg
          ...

      On construction the image paths and labels of the selected subset are
      collected and shuffled together; ``get_reader`` returns a generator
      function that yields preprocessed single-image batches with labels.
      """

      # File suffixes (compared lower-cased) that are treated as images.
      IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                        '.tif', '.tiff', '.webp')

      def __init__(self, dataset_dir, is_test):
          """Collect and shuffle the samples of one dataset subset.

          Args:
              dataset_dir: dataset root containing 'train' and 'test' folders.
              is_test: if True use the 'test' subset, otherwise 'train'.
          """
          image_paths, labels, self.num_classes = self.get_dataset_info(
              dataset_dir, is_test)
          # Pass the length rather than a range: an empty range would produce
          # a non-integer index array and make the indexing below fail.
          random_per = np.random.permutation(len(image_paths))
          # The same permutation is applied to both arrays so that every
          # path stays aligned with its label.
          self.image_paths = image_paths[random_per]
          self.labels = labels[random_per]
          self.is_test = is_test

      def get_reader(self):
          """Return a generator function yielding ``(image_batch, label)``.

          Each image is read with OpenCV, resized to a short side of 256,
          cropped to 224 (center crop for test, random crop otherwise),
          converted from BGR to RGB and normalized by ``preprocess_image``.
          Files that are not images or cannot be read are skipped.
          """

          def reader():
              target_size = 256
              crop_size = 224

              for i, img_path in enumerate(self.image_paths):
                  if not img_path.lower().endswith(self.IMG_EXTENSIONS):
                      continue

                  img = cv2.imread(img_path)
                  if img is None:
                      # cv2.imread returns None when the file cannot be read;
                      # log the offending path and move on.
                      logging.info(img_path)
                      continue
                  img = resize_short(img, target_size, interpolation=None)
                  img = crop_image(img, crop_size, center=self.is_test)
                  # OpenCV loads BGR; reverse the channel axis to get RGB.
                  img = img[:, :, ::-1]
                  img = np.expand_dims(img, axis=0)

                  # Random mirroring is only applied outside of test mode.
                  img = preprocess_image(img, not self.is_test)

                  yield img, self.labels[i]

          return reader

      def get_dataset_info(self, dataset_dir, is_test=False):
          """List the image files and labels of one dataset subset.

          Args:
              dataset_dir: dataset root containing 'train' and 'test' folders.
              is_test: if True scan the 'test' subset, otherwise 'train'.

          Returns:
              A tuple ``(image_paths, labels, num_classes)`` where the first
              two are np.ndarrays of equal length and ``num_classes`` is the
              number of class names reported by ``_find_classes``.
          """
          subset = 'test' if is_test else 'train'
          datasubset_dir = os.path.join(dataset_dir, subset)

          class_names, class_to_idx = _find_classes(datasubset_dir)

          image_paths = []
          labels = []
          for class_name in class_names:
              classes_dir = os.path.join(datasubset_dir, class_name)
              for img_path in glob.glob(os.path.join(classes_dir, '*')):
                  if not img_path.lower().endswith(self.IMG_EXTENSIONS):
                      continue
                  image_paths.append(img_path)
                  labels.append(class_to_idx[class_name])

          image_paths = np.array(image_paths)
          labels = np.array(labels)
          return image_paths, labels, len(class_names)
  def create_reader(list_image_path, list_label=None, is_test=False):
      """Build a generator function over an explicit list of image files.

      Args:
          list_image_path: sequence of image file paths.
          list_label: optional sequence of labels indexed like
              ``list_image_path``; when None every sample gets label 0.
          is_test: if True use a center crop and no mirroring; otherwise a
              random crop with random mirroring.

      Returns:
          A function that, when called, yields tuples
          ``(img_show, img, label)``: the cropped RGB image with a leading
          batch axis, its normalized version from ``preprocess_image``, and
          the label. Paths without an image suffix and files that OpenCV
          cannot read are skipped.
      """
      valid_suffixes = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                        '.tif', '.tiff', '.webp')
      short_side = 256
      patch_side = 224

      def reader():
          for idx, path in enumerate(list_image_path):
              if path.lower().endswith(valid_suffixes):
                  image = cv2.imread(path)
                  if image is None:
                      # Unreadable file: log its path and skip it.
                      logging.info(path)
                  else:
                      image = resize_short(
                          image, short_side, interpolation=None)
                      image = crop_image(image, patch_side, center=is_test)
                      # Reverse the channel axis (OpenCV BGR -> RGB).
                      image = image[:, :, ::-1]
                      img_show = np.expand_dims(image, axis=0)
                      normalized = preprocess_image(img_show, not is_test)
                      if list_label is None:
                          label = 0
                      else:
                          label = list_label[idx]
                      yield img_show, normalized, label

      return reader