cls_transforms.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .ops import *
  15. from .imgaug_support import execute_imgaug
  16. import random
  17. import os.path as osp
  18. import numpy as np
  19. from PIL import Image, ImageEnhance
  20. class ClsTransform:
  21. """分类Transform的基类
  22. """
  23. def __init__(self):
  24. pass
  25. class Compose(ClsTransform):
  26. """根据数据预处理/增强算子对输入数据进行操作。
  27. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  28. Args:
  29. transforms (list): 数据预处理/增强算子。
  30. Raises:
  31. TypeError: 形参数据类型不满足需求。
  32. ValueError: 数据长度不匹配。
  33. """
  34. def __init__(self, transforms):
  35. if not isinstance(transforms, list):
  36. raise TypeError('The transforms must be a list!')
  37. if len(transforms) < 1:
  38. raise ValueError('The length of transforms ' + \
  39. 'must be equal or larger than 1!')
  40. self.transforms = transforms
  41. # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
  42. for op in self.transforms:
  43. if not isinstance(op, ClsTransform):
  44. import imgaug.augmenters as iaa
  45. if not isinstance(op, iaa.Augmenter):
  46. raise Exception(
  47. "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
  48. )
  49. def __call__(self, im, label=None):
  50. """
  51. Args:
  52. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  53. label (int): 每张图像所对应的类别序号。
  54. Returns:
  55. tuple: 根据网络所需字段所组成的tuple;
  56. 字段由transforms中的最后一个数据预处理操作决定。
  57. """
  58. if isinstance(im, np.ndarray):
  59. if len(im.shape) != 3:
  60. raise Exception(
  61. "im should be 3-dimension, but now is {}-dimensions".
  62. format(len(im.shape)))
  63. else:
  64. try:
  65. im = cv2.imread(im).astype('float32')
  66. except:
  67. raise TypeError('Can\'t read The image file {}!'.format(im))
  68. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  69. for op in self.transforms:
  70. if isinstance(op, ClsTransform):
  71. outputs = op(im, label)
  72. im = outputs[0]
  73. if len(outputs) == 2:
  74. label = outputs[1]
  75. else:
  76. import imgaug.augmenters as iaa
  77. if isinstance(op, iaa.Augmenter):
  78. im = execute_imgaug(op, im)
  79. outputs = (im, )
  80. if label is not None:
  81. outputs = (im, label)
  82. return outputs
  83. def add_augmenters(self, augmenters):
  84. if not isinstance(augmenters, list):
  85. raise Exception(
  86. "augmenters should be list type in func add_augmenters()")
  87. self.transforms = augmenters + self.transforms.transforms
  88. class RandomCrop(ClsTransform):
  89. """对图像进行随机剪裁,模型训练时的数据增强操作。
  90. 1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
  91. 2. 根据随机剪裁的高、宽随机选取剪裁的起始点。
  92. 3. 剪裁图像。
  93. 4. 调整剪裁后的图像的大小到crop_size*crop_size。
  94. Args:
  95. crop_size (int): 随机裁剪后重新调整的目标边长。默认为224。
  96. lower_scale (float): 裁剪面积相对原面积比例的最小限制。默认为0.08。
  97. lower_ratio (float): 宽变换比例的最小限制。默认为3. / 4。
  98. upper_ratio (float): 宽变换比例的最大限制。默认为4. / 3。
  99. """
  100. def __init__(self,
  101. crop_size=224,
  102. lower_scale=0.08,
  103. lower_ratio=3. / 4,
  104. upper_ratio=4. / 3):
  105. self.crop_size = crop_size
  106. self.lower_scale = lower_scale
  107. self.lower_ratio = lower_ratio
  108. self.upper_ratio = upper_ratio
  109. def __call__(self, im, label=None):
  110. """
  111. Args:
  112. im (np.ndarray): 图像np.ndarray数据。
  113. label (int): 每张图像所对应的类别序号。
  114. Returns:
  115. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  116. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  117. """
  118. im = random_crop(im, self.crop_size, self.lower_scale,
  119. self.lower_ratio, self.upper_ratio)
  120. if label is None:
  121. return (im, )
  122. else:
  123. return (im, label)
  124. class RandomHorizontalFlip(ClsTransform):
  125. """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
  126. Args:
  127. prob (float): 随机水平翻转的概率。默认为0.5。
  128. """
  129. def __init__(self, prob=0.5):
  130. self.prob = prob
  131. def __call__(self, im, label=None):
  132. """
  133. Args:
  134. im (np.ndarray): 图像np.ndarray数据。
  135. label (int): 每张图像所对应的类别序号。
  136. Returns:
  137. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  138. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  139. """
  140. if random.random() < self.prob:
  141. im = horizontal_flip(im)
  142. if label is None:
  143. return (im, )
  144. else:
  145. return (im, label)
  146. class RandomVerticalFlip(ClsTransform):
  147. """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
  148. Args:
  149. prob (float): 随机垂直翻转的概率。默认为0.5。
  150. """
  151. def __init__(self, prob=0.5):
  152. self.prob = prob
  153. def __call__(self, im, label=None):
  154. """
  155. Args:
  156. im (np.ndarray): 图像np.ndarray数据。
  157. label (int): 每张图像所对应的类别序号。
  158. Returns:
  159. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  160. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  161. """
  162. if random.random() < self.prob:
  163. im = vertical_flip(im)
  164. if label is None:
  165. return (im, )
  166. else:
  167. return (im, label)
  168. class Normalize(ClsTransform):
  169. """对图像进行标准化。
  170. 1. 对图像进行归一化到区间[0.0, 1.0]。
  171. 2. 对图像进行减均值除以标准差操作。
  172. Args:
  173. mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
  174. std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
  175. """
  176. def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
  177. self.mean = mean
  178. self.std = std
  179. def __call__(self, im, label=None):
  180. """
  181. Args:
  182. im (np.ndarray): 图像np.ndarray数据。
  183. label (int): 每张图像所对应的类别序号。
  184. Returns:
  185. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  186. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  187. """
  188. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  189. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  190. im = normalize(im, mean, std)
  191. if label is None:
  192. return (im, )
  193. else:
  194. return (im, label)
  195. class ResizeByShort(ClsTransform):
  196. """根据图像短边对图像重新调整大小(resize)。
  197. 1. 获取图像的长边和短边长度。
  198. 2. 根据短边与short_size的比例,计算长边的目标长度,
  199. 此时高、宽的resize比例为short_size/原图短边长度。
  200. 3. 如果max_size>0,调整resize比例:
  201. 如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度;
  202. 4. 根据调整大小的比例对图像进行resize。
  203. Args:
  204. short_size (int): 调整大小后的图像目标短边长度。默认为256。
  205. max_size (int): 长边目标长度的最大限制。默认为-1。
  206. """
  207. def __init__(self, short_size=256, max_size=-1):
  208. self.short_size = short_size
  209. self.max_size = max_size
  210. def __call__(self, im, label=None):
  211. """
  212. Args:
  213. im (np.ndarray): 图像np.ndarray数据。
  214. label (int): 每张图像所对应的类别序号。
  215. Returns:
  216. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  217. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  218. """
  219. im_short_size = min(im.shape[0], im.shape[1])
  220. im_long_size = max(im.shape[0], im.shape[1])
  221. scale = float(self.short_size) / im_short_size
  222. if self.max_size > 0 and np.round(scale *
  223. im_long_size) > self.max_size:
  224. scale = float(self.max_size) / float(im_long_size)
  225. resized_width = int(round(im.shape[1] * scale))
  226. resized_height = int(round(im.shape[0] * scale))
  227. im = cv2.resize(
  228. im, (resized_width, resized_height),
  229. interpolation=cv2.INTER_LINEAR)
  230. if label is None:
  231. return (im, )
  232. else:
  233. return (im, label)
  234. class CenterCrop(ClsTransform):
  235. """以图像中心点扩散裁剪长宽为`crop_size`的正方形
  236. 1. 计算剪裁的起始点。
  237. 2. 剪裁图像。
  238. Args:
  239. crop_size (int): 裁剪的目标边长。默认为224。
  240. """
  241. def __init__(self, crop_size=224):
  242. self.crop_size = crop_size
  243. def __call__(self, im, label=None):
  244. """
  245. Args:
  246. im (np.ndarray): 图像np.ndarray数据。
  247. label (int): 每张图像所对应的类别序号。
  248. Returns:
  249. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  250. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  251. """
  252. im = center_crop(im, self.crop_size)
  253. if label is None:
  254. return (im, )
  255. else:
  256. return (im, label)
  257. class RandomRotate(ClsTransform):
  258. def __init__(self, rotate_range=30, prob=0.5):
  259. """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
  260. Args:
  261. rotate_range (int): 旋转度数的范围。默认为30。
  262. prob (float): 随机旋转的概率。默认为0.5。
  263. """
  264. self.rotate_range = rotate_range
  265. self.prob = prob
  266. def __call__(self, im, label=None):
  267. """
  268. Args:
  269. im (np.ndarray): 图像np.ndarray数据。
  270. label (int): 每张图像所对应的类别序号。
  271. Returns:
  272. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  273. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  274. """
  275. rotate_lower = -self.rotate_range
  276. rotate_upper = self.rotate_range
  277. im = im.astype('uint8')
  278. im = Image.fromarray(im)
  279. if np.random.uniform(0, 1) < self.prob:
  280. im = rotate(im, rotate_lower, rotate_upper)
  281. im = np.asarray(im).astype('float32')
  282. if label is None:
  283. return (im, )
  284. else:
  285. return (im, label)
  286. class RandomDistort(ClsTransform):
  287. """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
  288. 1. 对变换的操作顺序进行随机化操作。
  289. 2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。
  290. Args:
  291. brightness_range (float): 明亮度因子的范围。默认为0.9。
  292. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  293. contrast_range (float): 对比度因子的范围。默认为0.9。
  294. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  295. saturation_range (float): 饱和度因子的范围。默认为0.9。
  296. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  297. hue_range (int): 色调因子的范围。默认为18。
  298. hue_prob (float): 随机调整色调的概率。默认为0.5。
  299. """
  300. def __init__(self,
  301. brightness_range=0.9,
  302. brightness_prob=0.5,
  303. contrast_range=0.9,
  304. contrast_prob=0.5,
  305. saturation_range=0.9,
  306. saturation_prob=0.5,
  307. hue_range=18,
  308. hue_prob=0.5):
  309. self.brightness_range = brightness_range
  310. self.brightness_prob = brightness_prob
  311. self.contrast_range = contrast_range
  312. self.contrast_prob = contrast_prob
  313. self.saturation_range = saturation_range
  314. self.saturation_prob = saturation_prob
  315. self.hue_range = hue_range
  316. self.hue_prob = hue_prob
  317. def __call__(self, im, label=None):
  318. """
  319. Args:
  320. im (np.ndarray): 图像np.ndarray数据。
  321. label (int): 每张图像所对应的类别序号。
  322. Returns:
  323. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  324. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  325. """
  326. brightness_lower = 1 - self.brightness_range
  327. brightness_upper = 1 + self.brightness_range
  328. contrast_lower = 1 - self.contrast_range
  329. contrast_upper = 1 + self.contrast_range
  330. saturation_lower = 1 - self.saturation_range
  331. saturation_upper = 1 + self.saturation_range
  332. hue_lower = -self.hue_range
  333. hue_upper = self.hue_range
  334. ops = [brightness, contrast, saturation, hue]
  335. random.shuffle(ops)
  336. params_dict = {
  337. 'brightness': {
  338. 'brightness_lower': brightness_lower,
  339. 'brightness_upper': brightness_upper
  340. },
  341. 'contrast': {
  342. 'contrast_lower': contrast_lower,
  343. 'contrast_upper': contrast_upper
  344. },
  345. 'saturation': {
  346. 'saturation_lower': saturation_lower,
  347. 'saturation_upper': saturation_upper
  348. },
  349. 'hue': {
  350. 'hue_lower': hue_lower,
  351. 'hue_upper': hue_upper
  352. }
  353. }
  354. prob_dict = {
  355. 'brightness': self.brightness_prob,
  356. 'contrast': self.contrast_prob,
  357. 'saturation': self.saturation_prob,
  358. 'hue': self.hue_prob,
  359. }
  360. for id in range(len(ops)):
  361. params = params_dict[ops[id].__name__]
  362. prob = prob_dict[ops[id].__name__]
  363. params['im'] = im
  364. if np.random.uniform(0, 1) < prob:
  365. im = ops[id](**params)
  366. if label is None:
  367. return (im, )
  368. else:
  369. return (im, label)
  370. class ArrangeClassifier(ClsTransform):
  371. """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
  372. Args:
  373. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  374. Raises:
  375. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  376. """
  377. def __init__(self, mode=None):
  378. if mode not in ['train', 'eval', 'test', 'quant']:
  379. raise ValueError(
  380. "mode must be in ['train', 'eval', 'test', 'quant']!")
  381. self.mode = mode
  382. def __call__(self, im, label=None):
  383. """
  384. Args:
  385. im (np.ndarray): 图像np.ndarray数据。
  386. label (int): 每张图像所对应的类别序号。
  387. Returns:
  388. tuple: 当mode为'train'或'eval'时,返回(im, label),分别对应图像np.ndarray数据、
  389. 图像类别id;当mode为'test'或'quant'时,返回(im, ),对应图像np.ndarray数据。
  390. """
  391. im = permute(im, False).astype('float32')
  392. if self.mode == 'train' or self.mode == 'eval':
  393. outputs = (im, label)
  394. else:
  395. outputs = (im, )
  396. return outputs
  397. class ComposedClsTransforms(Compose):
  398. """ 分类模型的基础Transforms流程,具体如下
  399. 训练阶段:
  400. 1. 随机从图像中crop一块子图,并resize成crop_size大小
  401. 2. 将1的输出按0.5的概率随机进行水平翻转
  402. 3. 将图像进行归一化
  403. 验证/预测阶段:
  404. 1. 将图像按比例Resize,使得最小边长度为crop_size[0] * 1.14
  405. 2. 从图像中心crop出一个大小为crop_size的图像
  406. 3. 将图像进行归一化
  407. Args:
  408. mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
  409. crop_size(int|list): 输入模型里的图像大小
  410. mean(list): 图像均值
  411. std(list): 图像方差
  412. """
  413. def __init__(self,
  414. mode,
  415. crop_size=[224, 224],
  416. mean=[0.485, 0.456, 0.406],
  417. std=[0.229, 0.224, 0.225]):
  418. width = crop_size
  419. if isinstance(crop_size, list):
  420. if crop_size[0] != crop_size[1]:
  421. raise Exception(
  422. "In classifier model, width and height should be equal, please modify your parameter `crop_size`"
  423. )
  424. width = crop_size[0]
  425. if width % 32 != 0:
  426. raise Exception(
  427. "In classifier model, width and height should be multiple of 32, e.g 224、256、320...., please modify your parameter `crop_size`"
  428. )
  429. if mode == 'train':
  430. # 训练时的transforms,包含数据增强
  431. transforms = [
  432. RandomCrop(crop_size=width), RandomHorizontalFlip(prob=0.5),
  433. Normalize(
  434. mean=mean, std=std)
  435. ]
  436. else:
  437. # 验证/预测时的transforms
  438. transforms = [
  439. ResizeByShort(short_size=int(width * 1.14)),
  440. CenterCrop(crop_size=width), Normalize(
  441. mean=mean, std=std)
  442. ]
  443. super(ComposedClsTransforms, self).__init__(transforms)