cls_transforms.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .ops import *
  15. from .imgaug_support import execute_imgaug
  16. import random
  17. import os
  18. import os.path as osp
  19. import numpy as np
  20. from PIL import Image, ImageEnhance
  21. import paddlex.utils.logging as logging
  22. class ClsTransform:
  23. """分类Transform的基类
  24. """
  25. def __init__(self):
  26. pass
  27. class Compose(ClsTransform):
  28. """根据数据预处理/增强算子对输入数据进行操作。
  29. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  30. Args:
  31. transforms (list): 数据预处理/增强算子。
  32. Raises:
  33. TypeError: 形参数据类型不满足需求。
  34. ValueError: 数据长度不匹配。
  35. """
  36. def __init__(self, transforms):
  37. if not isinstance(transforms, list):
  38. raise TypeError('The transforms must be a list!')
  39. if len(transforms) < 1:
  40. raise ValueError('The length of transforms ' + \
  41. 'must be equal or larger than 1!')
  42. self.transforms = transforms
  43. # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
  44. for op in self.transforms:
  45. if not isinstance(op, ClsTransform):
  46. import imgaug.augmenters as iaa
  47. if not isinstance(op, iaa.Augmenter):
  48. raise Exception(
  49. "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
  50. )
  51. def __call__(self, im, label=None, images_writer=None, step=0):
  52. """
  53. Args:
  54. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  55. label (int): 每张图像所对应的类别序号。
  56. Returns:
  57. tuple: 根据网络所需字段所组成的tuple;
  58. 字段由transforms中的最后一个数据预处理操作决定。
  59. """
  60. if isinstance(im, np.ndarray):
  61. if len(im.shape) != 3:
  62. raise Exception(
  63. "im should be 3-dimension, but now is {}-dimensions".
  64. format(len(im.shape)))
  65. else:
  66. try:
  67. im = cv2.imread(im).astype('float32')
  68. except:
  69. raise TypeError('Can\'t read The image file {}!'.format(im))
  70. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  71. if images_writer is not None:
  72. images_writer.add_image(tag='0. origin image',
  73. img=im,
  74. step=step)
  75. op_id = 1
  76. for op in self.transforms:
  77. if isinstance(op, ClsTransform):
  78. outputs = op(im, label)
  79. im = outputs[0]
  80. if len(outputs) == 2:
  81. label = outputs[1]
  82. else:
  83. import imgaug.augmenters as iaa
  84. if isinstance(op, iaa.Augmenter):
  85. im = execute_imgaug(op, im)
  86. outputs = (im, )
  87. if label is not None:
  88. outputs = (im, label)
  89. if images_writer is not None:
  90. tag = str(op_id) + '. ' + op.__class__.__name__
  91. images_writer.add_image(tag=tag,
  92. img=im,
  93. step=step)
  94. op_id += 1
  95. return outputs
  96. def add_augmenters(self, augmenters):
  97. if not isinstance(augmenters, list):
  98. raise Exception(
  99. "augmenters should be list type in func add_augmenters()")
  100. transform_names = [type(x).__name__ for x in self.transforms]
  101. for aug in augmenters:
  102. if type(aug).__name__ in transform_names:
  103. logging.error("{} is already in ComposedTransforms, need to remove it from add_augmenters().".format(type(aug).__name__))
  104. self.transforms = augmenters + self.transforms
  105. class RandomCrop(ClsTransform):
  106. """对图像进行随机剪裁,模型训练时的数据增强操作。
  107. 1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
  108. 2. 根据随机剪裁的高、宽随机选取剪裁的起始点。
  109. 3. 剪裁图像。
  110. 4. 调整剪裁后的图像的大小到crop_size*crop_size。
  111. Args:
  112. crop_size (int): 随机裁剪后重新调整的目标边长。默认为224。
  113. lower_scale (float): 裁剪面积相对原面积比例的最小限制。默认为0.08。
  114. lower_ratio (float): 宽变换比例的最小限制。默认为3. / 4。
  115. upper_ratio (float): 宽变换比例的最大限制。默认为4. / 3。
  116. """
  117. def __init__(self,
  118. crop_size=224,
  119. lower_scale=0.08,
  120. lower_ratio=3. / 4,
  121. upper_ratio=4. / 3):
  122. self.crop_size = crop_size
  123. self.lower_scale = lower_scale
  124. self.lower_ratio = lower_ratio
  125. self.upper_ratio = upper_ratio
  126. def __call__(self, im, label=None):
  127. """
  128. Args:
  129. im (np.ndarray): 图像np.ndarray数据。
  130. label (int): 每张图像所对应的类别序号。
  131. Returns:
  132. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  133. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  134. """
  135. im = random_crop(im, self.crop_size, self.lower_scale,
  136. self.lower_ratio, self.upper_ratio)
  137. if label is None:
  138. return (im, )
  139. else:
  140. return (im, label)
  141. class RandomHorizontalFlip(ClsTransform):
  142. """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
  143. Args:
  144. prob (float): 随机水平翻转的概率。默认为0.5。
  145. """
  146. def __init__(self, prob=0.5):
  147. self.prob = prob
  148. def __call__(self, im, label=None):
  149. """
  150. Args:
  151. im (np.ndarray): 图像np.ndarray数据。
  152. label (int): 每张图像所对应的类别序号。
  153. Returns:
  154. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  155. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  156. """
  157. if random.random() < self.prob:
  158. im = horizontal_flip(im)
  159. if label is None:
  160. return (im, )
  161. else:
  162. return (im, label)
  163. class RandomVerticalFlip(ClsTransform):
  164. """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
  165. Args:
  166. prob (float): 随机垂直翻转的概率。默认为0.5。
  167. """
  168. def __init__(self, prob=0.5):
  169. self.prob = prob
  170. def __call__(self, im, label=None):
  171. """
  172. Args:
  173. im (np.ndarray): 图像np.ndarray数据。
  174. label (int): 每张图像所对应的类别序号。
  175. Returns:
  176. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  177. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  178. """
  179. if random.random() < self.prob:
  180. im = vertical_flip(im)
  181. if label is None:
  182. return (im, )
  183. else:
  184. return (im, label)
  185. class Normalize(ClsTransform):
  186. """对图像进行标准化。
  187. 1. 对图像进行归一化到区间[0.0, 1.0]。
  188. 2. 对图像进行减均值除以标准差操作。
  189. Args:
  190. mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
  191. std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
  192. """
  193. def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
  194. self.mean = mean
  195. self.std = std
  196. def __call__(self, im, label=None):
  197. """
  198. Args:
  199. im (np.ndarray): 图像np.ndarray数据。
  200. label (int): 每张图像所对应的类别序号。
  201. Returns:
  202. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  203. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  204. """
  205. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  206. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  207. im = normalize(im, mean, std)
  208. if label is None:
  209. return (im, )
  210. else:
  211. return (im, label)
  212. class ResizeByShort(ClsTransform):
  213. """根据图像短边对图像重新调整大小(resize)。
  214. 1. 获取图像的长边和短边长度。
  215. 2. 根据短边与short_size的比例,计算长边的目标长度,
  216. 此时高、宽的resize比例为short_size/原图短边长度。
  217. 3. 如果max_size>0,调整resize比例:
  218. 如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度;
  219. 4. 根据调整大小的比例对图像进行resize。
  220. Args:
  221. short_size (int): 调整大小后的图像目标短边长度。默认为256。
  222. max_size (int): 长边目标长度的最大限制。默认为-1。
  223. """
  224. def __init__(self, short_size=256, max_size=-1):
  225. self.short_size = short_size
  226. self.max_size = max_size
  227. def __call__(self, im, label=None):
  228. """
  229. Args:
  230. im (np.ndarray): 图像np.ndarray数据。
  231. label (int): 每张图像所对应的类别序号。
  232. Returns:
  233. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  234. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  235. """
  236. im_short_size = min(im.shape[0], im.shape[1])
  237. im_long_size = max(im.shape[0], im.shape[1])
  238. scale = float(self.short_size) / im_short_size
  239. if self.max_size > 0 and np.round(scale *
  240. im_long_size) > self.max_size:
  241. scale = float(self.max_size) / float(im_long_size)
  242. resized_width = int(round(im.shape[1] * scale))
  243. resized_height = int(round(im.shape[0] * scale))
  244. im = cv2.resize(
  245. im, (resized_width, resized_height),
  246. interpolation=cv2.INTER_LINEAR)
  247. if label is None:
  248. return (im, )
  249. else:
  250. return (im, label)
  251. class CenterCrop(ClsTransform):
  252. """以图像中心点扩散裁剪长宽为`crop_size`的正方形
  253. 1. 计算剪裁的起始点。
  254. 2. 剪裁图像。
  255. Args:
  256. crop_size (int): 裁剪的目标边长。默认为224。
  257. """
  258. def __init__(self, crop_size=224):
  259. self.crop_size = crop_size
  260. def __call__(self, im, label=None):
  261. """
  262. Args:
  263. im (np.ndarray): 图像np.ndarray数据。
  264. label (int): 每张图像所对应的类别序号。
  265. Returns:
  266. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  267. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  268. """
  269. im = center_crop(im, self.crop_size)
  270. if label is None:
  271. return (im, )
  272. else:
  273. return (im, label)
  274. class RandomRotate(ClsTransform):
  275. def __init__(self, rotate_range=30, prob=0.5):
  276. """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
  277. Args:
  278. rotate_range (int): 旋转度数的范围。默认为30。
  279. prob (float): 随机旋转的概率。默认为0.5。
  280. """
  281. self.rotate_range = rotate_range
  282. self.prob = prob
  283. def __call__(self, im, label=None):
  284. """
  285. Args:
  286. im (np.ndarray): 图像np.ndarray数据。
  287. label (int): 每张图像所对应的类别序号。
  288. Returns:
  289. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  290. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  291. """
  292. rotate_lower = -self.rotate_range
  293. rotate_upper = self.rotate_range
  294. im = im.astype('uint8')
  295. im = Image.fromarray(im)
  296. if np.random.uniform(0, 1) < self.prob:
  297. im = rotate(im, rotate_lower, rotate_upper)
  298. im = np.asarray(im).astype('float32')
  299. if label is None:
  300. return (im, )
  301. else:
  302. return (im, label)
  303. class RandomDistort(ClsTransform):
  304. """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
  305. 1. 对变换的操作顺序进行随机化操作。
  306. 2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。
  307. Args:
  308. brightness_range (float): 明亮度因子的范围。默认为0.9。
  309. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  310. contrast_range (float): 对比度因子的范围。默认为0.9。
  311. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  312. saturation_range (float): 饱和度因子的范围。默认为0.9。
  313. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  314. hue_range (int): 色调因子的范围。默认为18。
  315. hue_prob (float): 随机调整色调的概率。默认为0.5。
  316. """
  317. def __init__(self,
  318. brightness_range=0.9,
  319. brightness_prob=0.5,
  320. contrast_range=0.9,
  321. contrast_prob=0.5,
  322. saturation_range=0.9,
  323. saturation_prob=0.5,
  324. hue_range=18,
  325. hue_prob=0.5):
  326. self.brightness_range = brightness_range
  327. self.brightness_prob = brightness_prob
  328. self.contrast_range = contrast_range
  329. self.contrast_prob = contrast_prob
  330. self.saturation_range = saturation_range
  331. self.saturation_prob = saturation_prob
  332. self.hue_range = hue_range
  333. self.hue_prob = hue_prob
  334. def __call__(self, im, label=None):
  335. """
  336. Args:
  337. im (np.ndarray): 图像np.ndarray数据。
  338. label (int): 每张图像所对应的类别序号。
  339. Returns:
  340. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  341. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  342. """
  343. brightness_lower = 1 - self.brightness_range
  344. brightness_upper = 1 + self.brightness_range
  345. contrast_lower = 1 - self.contrast_range
  346. contrast_upper = 1 + self.contrast_range
  347. saturation_lower = 1 - self.saturation_range
  348. saturation_upper = 1 + self.saturation_range
  349. hue_lower = -self.hue_range
  350. hue_upper = self.hue_range
  351. ops = [brightness, contrast, saturation, hue]
  352. random.shuffle(ops)
  353. params_dict = {
  354. 'brightness': {
  355. 'brightness_lower': brightness_lower,
  356. 'brightness_upper': brightness_upper
  357. },
  358. 'contrast': {
  359. 'contrast_lower': contrast_lower,
  360. 'contrast_upper': contrast_upper
  361. },
  362. 'saturation': {
  363. 'saturation_lower': saturation_lower,
  364. 'saturation_upper': saturation_upper
  365. },
  366. 'hue': {
  367. 'hue_lower': hue_lower,
  368. 'hue_upper': hue_upper
  369. }
  370. }
  371. prob_dict = {
  372. 'brightness': self.brightness_prob,
  373. 'contrast': self.contrast_prob,
  374. 'saturation': self.saturation_prob,
  375. 'hue': self.hue_prob,
  376. }
  377. for id in range(len(ops)):
  378. params = params_dict[ops[id].__name__]
  379. prob = prob_dict[ops[id].__name__]
  380. params['im'] = im
  381. if np.random.uniform(0, 1) < prob:
  382. im = ops[id](**params)
  383. im = im.astype('float32')
  384. if label is None:
  385. return (im, )
  386. else:
  387. return (im, label)
  388. class ArrangeClassifier(ClsTransform):
  389. """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
  390. Args:
  391. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  392. Raises:
  393. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  394. """
  395. def __init__(self, mode=None):
  396. if mode not in ['train', 'eval', 'test', 'quant']:
  397. raise ValueError(
  398. "mode must be in ['train', 'eval', 'test', 'quant']!")
  399. self.mode = mode
  400. def __call__(self, im, label=None):
  401. """
  402. Args:
  403. im (np.ndarray): 图像np.ndarray数据。
  404. label (int): 每张图像所对应的类别序号。
  405. Returns:
  406. tuple: 当mode为'train'或'eval'时,返回(im, label),分别对应图像np.ndarray数据、
  407. 图像类别id;当mode为'test'或'quant'时,返回(im, ),对应图像np.ndarray数据。
  408. """
  409. im = permute(im, False).astype('float32')
  410. if self.mode == 'train' or self.mode == 'eval':
  411. outputs = (im, label)
  412. else:
  413. outputs = (im, )
  414. return outputs
  415. class ComposedClsTransforms(Compose):
  416. """ 分类模型的基础Transforms流程,具体如下
  417. 训练阶段:
  418. 1. 随机从图像中crop一块子图,并resize成crop_size大小
  419. 2. 将1的输出按0.5的概率随机进行水平翻转
  420. 3. 将图像进行归一化
  421. 验证/预测阶段:
  422. 1. 将图像按比例Resize,使得最小边长度为crop_size[0] * 1.14
  423. 2. 从图像中心crop出一个大小为crop_size的图像
  424. 3. 将图像进行归一化
  425. Args:
  426. mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
  427. crop_size(int|list): 输入模型里的图像大小
  428. mean(list): 图像均值
  429. std(list): 图像方差
  430. """
  431. def __init__(self,
  432. mode,
  433. crop_size=[224, 224],
  434. mean=[0.485, 0.456, 0.406],
  435. std=[0.229, 0.224, 0.225]):
  436. width = crop_size
  437. if isinstance(crop_size, list):
  438. if crop_size[0] != crop_size[1]:
  439. raise Exception(
  440. "In classifier model, width and height should be equal, please modify your parameter `crop_size`"
  441. )
  442. width = crop_size[0]
  443. if width % 32 != 0:
  444. raise Exception(
  445. "In classifier model, width and height should be multiple of 32, e.g 224、256、320...., please modify your parameter `crop_size`"
  446. )
  447. if mode == 'train':
  448. # 训练时的transforms,包含数据增强
  449. transforms = [
  450. RandomCrop(crop_size=width), RandomHorizontalFlip(prob=0.5),
  451. Normalize(
  452. mean=mean, std=std)
  453. ]
  454. else:
  455. # 验证/预测时的transforms
  456. transforms = [
  457. ResizeByShort(short_size=int(width * 1.14)),
  458. CenterCrop(crop_size=width), Normalize(
  459. mean=mean, std=std)
  460. ]
  461. super(ComposedClsTransforms, self).__init__(transforms)