cls_transforms.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .ops import *
  15. import random
  16. import os.path as osp
  17. import numpy as np
  18. from PIL import Image, ImageEnhance
  19. class Compose:
  20. """根据数据预处理/增强算子对输入数据进行操作。
  21. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  22. Args:
  23. transforms (list): 数据预处理/增强算子。
  24. Raises:
  25. TypeError: 形参数据类型不满足需求。
  26. ValueError: 数据长度不匹配。
  27. """
  28. def __init__(self, transforms):
  29. if not isinstance(transforms, list):
  30. raise TypeError('The transforms must be a list!')
  31. if len(transforms) < 1:
  32. raise ValueError('The length of transforms ' + \
  33. 'must be equal or larger than 1!')
  34. self.transforms = transforms
  35. def __call__(self, im, label=None):
  36. """
  37. Args:
  38. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  39. label (int): 每张图像所对应的类别序号。
  40. Returns:
  41. tuple: 根据网络所需字段所组成的tuple;
  42. 字段由transforms中的最后一个数据预处理操作决定。
  43. """
  44. try:
  45. im = cv2.imread(im).astype('float32')
  46. except:
  47. raise TypeError('Can\'t read The image file {}!'.format(im))
  48. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  49. for op in self.transforms:
  50. outputs = op(im, label)
  51. im = outputs[0]
  52. if len(outputs) == 2:
  53. label = outputs[1]
  54. return outputs
  55. class RandomCrop:
  56. """对图像进行随机剪裁,模型训练时的数据增强操作。
  57. 1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
  58. 2. 根据随机剪裁的高、宽随机选取剪裁的起始点。
  59. 3. 剪裁图像。
  60. 4. 调整剪裁后的图像的大小到crop_size*crop_size。
  61. Args:
  62. crop_size (int): 随机裁剪后重新调整的目标边长。默认为224。
  63. lower_scale (float): 裁剪面积相对原面积比例的最小限制。默认为0.88。
  64. lower_ratio (float): 宽变换比例的最小限制。默认为3. / 4。
  65. upper_ratio (float): 宽变换比例的最大限制。默认为4. / 3。
  66. """
  67. def __init__(self,
  68. crop_size=224,
  69. lower_scale=0.88,
  70. lower_ratio=3. / 4,
  71. upper_ratio=4. / 3):
  72. self.crop_size = crop_size
  73. self.lower_scale = lower_scale
  74. self.lower_ratio = lower_ratio
  75. self.upper_ratio = upper_ratio
  76. def __call__(self, im, label=None):
  77. """
  78. Args:
  79. im (np.ndarray): 图像np.ndarray数据。
  80. label (int): 每张图像所对应的类别序号。
  81. Returns:
  82. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  83. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  84. """
  85. im = random_crop(im, self.crop_size, self.lower_scale,
  86. self.lower_ratio, self.upper_ratio)
  87. if label is None:
  88. return (im, )
  89. else:
  90. return (im, label)
  91. class RandomHorizontalFlip:
  92. """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
  93. Args:
  94. prob (float): 随机水平翻转的概率。默认为0.5。
  95. """
  96. def __init__(self, prob=0.5):
  97. self.prob = prob
  98. def __call__(self, im, label=None):
  99. """
  100. Args:
  101. im (np.ndarray): 图像np.ndarray数据。
  102. label (int): 每张图像所对应的类别序号。
  103. Returns:
  104. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  105. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  106. """
  107. if random.random() < self.prob:
  108. im = horizontal_flip(im)
  109. if label is None:
  110. return (im, )
  111. else:
  112. return (im, label)
  113. class RandomVerticalFlip:
  114. """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
  115. Args:
  116. prob (float): 随机垂直翻转的概率。默认为0.5。
  117. """
  118. def __init__(self, prob=0.5):
  119. self.prob = prob
  120. def __call__(self, im, label=None):
  121. """
  122. Args:
  123. im (np.ndarray): 图像np.ndarray数据。
  124. label (int): 每张图像所对应的类别序号。
  125. Returns:
  126. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  127. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  128. """
  129. if random.random() < self.prob:
  130. im = vertical_flip(im)
  131. if label is None:
  132. return (im, )
  133. else:
  134. return (im, label)
  135. class Normalize:
  136. """对图像进行标准化。
  137. 1. 对图像进行归一化到区间[0.0, 1.0]。
  138. 2. 对图像进行减均值除以标准差操作。
  139. Args:
  140. mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
  141. std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
  142. """
  143. def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
  144. self.mean = mean
  145. self.std = std
  146. def __call__(self, im, label=None):
  147. """
  148. Args:
  149. im (np.ndarray): 图像np.ndarray数据。
  150. label (int): 每张图像所对应的类别序号。
  151. Returns:
  152. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  153. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  154. """
  155. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  156. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  157. im = normalize(im, mean, std)
  158. if label is None:
  159. return (im, )
  160. else:
  161. return (im, label)
  162. class ResizeByShort:
  163. """根据图像短边对图像重新调整大小(resize)。
  164. 1. 获取图像的长边和短边长度。
  165. 2. 根据短边与short_size的比例,计算长边的目标长度,
  166. 此时高、宽的resize比例为short_size/原图短边长度。
  167. 3. 如果max_size>0,调整resize比例:
  168. 如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度;
  169. 4. 根据调整大小的比例对图像进行resize。
  170. Args:
  171. short_size (int): 调整大小后的图像目标短边长度。默认为256。
  172. max_size (int): 长边目标长度的最大限制。默认为-1。
  173. """
  174. def __init__(self, short_size=256, max_size=-1):
  175. self.short_size = short_size
  176. self.max_size = max_size
  177. def __call__(self, im, label=None):
  178. """
  179. Args:
  180. im (np.ndarray): 图像np.ndarray数据。
  181. label (int): 每张图像所对应的类别序号。
  182. Returns:
  183. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  184. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  185. """
  186. im_short_size = min(im.shape[0], im.shape[1])
  187. im_long_size = max(im.shape[0], im.shape[1])
  188. scale = float(self.short_size) / im_short_size
  189. if self.max_size > 0 and np.round(
  190. scale * im_long_size) > self.max_size:
  191. scale = float(self.max_size) / float(im_long_size)
  192. resized_width = int(round(im.shape[1] * scale))
  193. resized_height = int(round(im.shape[0] * scale))
  194. im = cv2.resize(
  195. im, (resized_width, resized_height),
  196. interpolation=cv2.INTER_LINEAR)
  197. if label is None:
  198. return (im, )
  199. else:
  200. return (im, label)
  201. class CenterCrop:
  202. """以图像中心点扩散裁剪长宽为`crop_size`的正方形
  203. 1. 计算剪裁的起始点。
  204. 2. 剪裁图像。
  205. Args:
  206. crop_size (int): 裁剪的目标边长。默认为224。
  207. """
  208. def __init__(self, crop_size=224):
  209. self.crop_size = crop_size
  210. def __call__(self, im, label=None):
  211. """
  212. Args:
  213. im (np.ndarray): 图像np.ndarray数据。
  214. label (int): 每张图像所对应的类别序号。
  215. Returns:
  216. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  217. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  218. """
  219. im = center_crop(im, self.crop_size)
  220. if label is None:
  221. return (im, )
  222. else:
  223. return (im, label)
  224. class RandomRotate:
  225. def __init__(self, rotate_range=30, prob=0.5):
  226. """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
  227. Args:
  228. rotate_range (int): 旋转度数的范围。默认为30。
  229. prob (float): 随机旋转的概率。默认为0.5。
  230. """
  231. self.rotate_range = rotate_range
  232. self.prob = prob
  233. def __call__(self, im, label=None):
  234. """
  235. Args:
  236. im (np.ndarray): 图像np.ndarray数据。
  237. label (int): 每张图像所对应的类别序号。
  238. Returns:
  239. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  240. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  241. """
  242. rotate_lower = -self.rotate_range
  243. rotate_upper = self.rotate_range
  244. im = im.astype('uint8')
  245. im = Image.fromarray(im)
  246. if np.random.uniform(0, 1) < self.prob:
  247. im = rotate(im, rotate_lower, rotate_upper)
  248. im = np.asarray(im).astype('float32')
  249. if label is None:
  250. return (im, )
  251. else:
  252. return (im, label)
  253. class RandomDistort:
  254. """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
  255. 1. 对变换的操作顺序进行随机化操作。
  256. 2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。
  257. Args:
  258. brightness_range (float): 明亮度因子的范围。默认为0.9。
  259. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  260. contrast_range (float): 对比度因子的范围。默认为0.9。
  261. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  262. saturation_range (float): 饱和度因子的范围。默认为0.9。
  263. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  264. hue_range (int): 色调因子的范围。默认为18。
  265. hue_prob (float): 随机调整色调的概率。默认为0.5。
  266. """
  267. def __init__(self,
  268. brightness_range=0.9,
  269. brightness_prob=0.5,
  270. contrast_range=0.9,
  271. contrast_prob=0.5,
  272. saturation_range=0.9,
  273. saturation_prob=0.5,
  274. hue_range=18,
  275. hue_prob=0.5):
  276. self.brightness_range = brightness_range
  277. self.brightness_prob = brightness_prob
  278. self.contrast_range = contrast_range
  279. self.contrast_prob = contrast_prob
  280. self.saturation_range = saturation_range
  281. self.saturation_prob = saturation_prob
  282. self.hue_range = hue_range
  283. self.hue_prob = hue_prob
  284. def __call__(self, im, label=None):
  285. """
  286. Args:
  287. im (np.ndarray): 图像np.ndarray数据。
  288. label (int): 每张图像所对应的类别序号。
  289. Returns:
  290. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  291. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  292. """
  293. brightness_lower = 1 - self.brightness_range
  294. brightness_upper = 1 + self.brightness_range
  295. contrast_lower = 1 - self.contrast_range
  296. contrast_upper = 1 + self.contrast_range
  297. saturation_lower = 1 - self.saturation_range
  298. saturation_upper = 1 + self.saturation_range
  299. hue_lower = -self.hue_range
  300. hue_upper = self.hue_range
  301. ops = [brightness, contrast, saturation, hue]
  302. random.shuffle(ops)
  303. params_dict = {
  304. 'brightness': {
  305. 'brightness_lower': brightness_lower,
  306. 'brightness_upper': brightness_upper
  307. },
  308. 'contrast': {
  309. 'contrast_lower': contrast_lower,
  310. 'contrast_upper': contrast_upper
  311. },
  312. 'saturation': {
  313. 'saturation_lower': saturation_lower,
  314. 'saturation_upper': saturation_upper
  315. },
  316. 'hue': {
  317. 'hue_lower': hue_lower,
  318. 'hue_upper': hue_upper
  319. }
  320. }
  321. prob_dict = {
  322. 'brightness': self.brightness_prob,
  323. 'contrast': self.contrast_prob,
  324. 'saturation': self.saturation_prob,
  325. 'hue': self.hue_prob,
  326. }
  327. im = im.astype('uint8')
  328. im = Image.fromarray(im)
  329. for id in range(len(ops)):
  330. params = params_dict[ops[id].__name__]
  331. prob = prob_dict[ops[id].__name__]
  332. params['im'] = im
  333. if np.random.uniform(0, 1) < prob:
  334. im = ops[id](**params)
  335. im = np.asarray(im).astype('float32')
  336. if label is None:
  337. return (im, )
  338. else:
  339. return (im, label)
  340. class ArrangeClassifier:
  341. """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
  342. Args:
  343. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  344. Raises:
  345. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  346. """
  347. def __init__(self, mode=None):
  348. if mode not in ['train', 'eval', 'test', 'quant']:
  349. raise ValueError(
  350. "mode must be in ['train', 'eval', 'test', 'quant']!")
  351. self.mode = mode
  352. def __call__(self, im, label=None):
  353. """
  354. Args:
  355. im (np.ndarray): 图像np.ndarray数据。
  356. label (int): 每张图像所对应的类别序号。
  357. Returns:
  358. tuple: 当mode为'train'或'eval'时,返回(im, label),分别对应图像np.ndarray数据、
  359. 图像类别id;当mode为'test'或'quant'时,返回(im, ),对应图像np.ndarray数据。
  360. """
  361. im = permute(im, False)
  362. if self.mode == 'train' or self.mode == 'eval':
  363. outputs = (im, label)
  364. else:
  365. outputs = (im, )
  366. return outputs