operators.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from __future__ import unicode_literals
  20. import six
  21. import math
  22. import random
  23. import cv2
  24. import numpy as np
  25. from .autoaugment import ImageNetPolicy
  26. class OperatorParamError(ValueError):
  27. """ OperatorParamError
  28. """
  29. pass
  30. class DecodeImage(object):
  31. """ decode image """
  32. def __init__(self, to_rgb=True, to_np=False, channel_first=False):
  33. self.to_rgb = to_rgb
  34. self.to_np = to_np # to numpy
  35. self.channel_first = channel_first # only enabled when to_np is True
  36. def __call__(self, img):
  37. if six.PY2:
  38. assert type(img) is str and len(
  39. img) > 0, "invalid input 'img' in DecodeImage"
  40. else:
  41. assert type(img) is bytes and len(
  42. img) > 0, "invalid input 'img' in DecodeImage"
  43. data = np.frombuffer(img, dtype='uint8')
  44. img = cv2.imdecode(data, 1)
  45. if self.to_rgb:
  46. assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
  47. img.shape)
  48. img = img[:, :, ::-1]
  49. if self.channel_first:
  50. img = img.transpose((2, 0, 1))
  51. return img
  52. class ResizeImage(object):
  53. """ resize image """
  54. def __init__(self, size=None, resize_short=None, interpolation=-1):
  55. self.interpolation = interpolation if interpolation >= 0 else None
  56. if resize_short is not None and resize_short > 0:
  57. self.resize_short = resize_short
  58. self.w = None
  59. self.h = None
  60. elif size is not None:
  61. self.resize_short = None
  62. self.w = size if type(size) is int else size[0]
  63. self.h = size if type(size) is int else size[1]
  64. else:
  65. raise OperatorParamError("invalid params for ReisizeImage for '\
  66. 'both 'size' and 'resize_short' are None")
  67. def __call__(self, img):
  68. img_h, img_w = img.shape[:2]
  69. if self.resize_short is not None:
  70. percent = float(self.resize_short) / min(img_w, img_h)
  71. w = int(round(img_w * percent))
  72. h = int(round(img_h * percent))
  73. else:
  74. w = self.w
  75. h = self.h
  76. if self.interpolation is None:
  77. return cv2.resize(img, (w, h))
  78. else:
  79. return cv2.resize(img, (w, h), interpolation=self.interpolation)
  80. class CropImage(object):
  81. """ crop image """
  82. def __init__(self, size):
  83. if type(size) is int:
  84. self.size = (size, size)
  85. else:
  86. self.size = size # (h, w)
  87. def __call__(self, img):
  88. w, h = self.size
  89. img_h, img_w = img.shape[:2]
  90. w_start = (img_w - w) // 2
  91. h_start = (img_h - h) // 2
  92. w_end = w_start + w
  93. h_end = h_start + h
  94. return img[h_start:h_end, w_start:w_end, :]
  95. class RandCropImage(object):
  96. """ random crop image """
  97. def __init__(self, size, scale=None, ratio=None, interpolation=-1):
  98. self.interpolation = interpolation if interpolation >= 0 else None
  99. if type(size) is int:
  100. self.size = (size, size) # (h, w)
  101. else:
  102. self.size = size
  103. self.scale = [0.08, 1.0] if scale is None else scale
  104. self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
  105. def __call__(self, img):
  106. size = self.size
  107. scale = self.scale
  108. ratio = self.ratio
  109. aspect_ratio = math.sqrt(random.uniform(*ratio))
  110. w = 1. * aspect_ratio
  111. h = 1. / aspect_ratio
  112. img_h, img_w = img.shape[:2]
  113. bound = min((float(img_w) / img_h) / (w**2),
  114. (float(img_h) / img_w) / (h**2))
  115. scale_max = min(scale[1], bound)
  116. scale_min = min(scale[0], bound)
  117. target_area = img_w * img_h * random.uniform(scale_min, scale_max)
  118. target_size = math.sqrt(target_area)
  119. w = int(target_size * w)
  120. h = int(target_size * h)
  121. i = random.randint(0, img_w - w)
  122. j = random.randint(0, img_h - h)
  123. img = img[j:j + h, i:i + w, :]
  124. if self.interpolation is None:
  125. return cv2.resize(img, size)
  126. else:
  127. return cv2.resize(img, size, interpolation=self.interpolation)
  128. class RandFlipImage(object):
  129. """ random flip image
  130. flip_code:
  131. 1: Flipped Horizontally
  132. 0: Flipped Vertically
  133. -1: Flipped Horizontally & Vertically
  134. """
  135. def __init__(self, flip_code=1):
  136. assert flip_code in [-1, 0, 1
  137. ], "flip_code should be a value in [-1, 0, 1]"
  138. self.flip_code = flip_code
  139. def __call__(self, img):
  140. if random.randint(0, 1) == 1:
  141. return cv2.flip(img, self.flip_code)
  142. else:
  143. return img
  144. class AutoAugment(object):
  145. def __init__(self):
  146. self.policy = ImageNetPolicy()
  147. def __call__(self, img):
  148. from PIL import Image
  149. img = np.ascontiguousarray(img)
  150. img = Image.fromarray(img)
  151. img = self.policy(img)
  152. img = np.asarray(img)
  153. class NormalizeImage(object):
  154. """ normalize image such as substract mean, divide std
  155. """
  156. def __init__(self,
  157. scale=None,
  158. mean=None,
  159. std=None,
  160. order='chw',
  161. output_fp16=False,
  162. channel_num=3):
  163. if isinstance(scale, str):
  164. scale = eval(scale)
  165. assert channel_num in [
  166. 3, 4
  167. ], "channel number of input image should be set to 3 or 4."
  168. self.channel_num = channel_num
  169. self.output_dtype = 'float16' if output_fp16 else 'float32'
  170. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  171. self.order = order
  172. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  173. std = std if std is not None else [0.229, 0.224, 0.225]
  174. shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
  175. self.mean = np.array(mean).reshape(shape).astype('float32')
  176. self.std = np.array(std).reshape(shape).astype('float32')
  177. def __call__(self, img):
  178. from PIL import Image
  179. if isinstance(img, Image.Image):
  180. img = np.array(img)
  181. assert isinstance(img,
  182. np.ndarray), "invalid input 'img' in NormalizeImage"
  183. img = (img.astype('float32') * self.scale - self.mean) / self.std
  184. if self.channel_num == 4:
  185. img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
  186. img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
  187. pad_zeros = np.zeros(
  188. (1, img_h, img_w)) if self.order == 'chw' else np.zeros(
  189. (img_h, img_w, 1))
  190. img = (np.concatenate(
  191. (img, pad_zeros), axis=0)
  192. if self.order == 'chw' else np.concatenate(
  193. (img, pad_zeros), axis=2))
  194. return img.astype(self.output_dtype)
  195. class ToCHWImage(object):
  196. """ convert hwc image to chw image
  197. """
  198. def __init__(self):
  199. pass
  200. def __call__(self, img):
  201. from PIL import Image
  202. if isinstance(img, Image.Image):
  203. img = np.array(img)
  204. return img.transpose((2, 0, 1))