cls_transforms.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .ops import *
  15. from .imgaug_support import execute_imgaug
  16. import random
  17. import os.path as osp
  18. import numpy as np
  19. from PIL import Image, ImageEnhance
  20. import paddlex.utils.logging as logging
  21. class ClsTransform:
  22. """分类Transform的基类
  23. """
  24. def __init__(self):
  25. pass
  26. class Compose(ClsTransform):
  27. """根据数据预处理/增强算子对输入数据进行操作。
  28. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  29. Args:
  30. transforms (list): 数据预处理/增强算子。
  31. Raises:
  32. TypeError: 形参数据类型不满足需求。
  33. ValueError: 数据长度不匹配。
  34. """
  35. def __init__(self, transforms):
  36. if not isinstance(transforms, list):
  37. raise TypeError('The transforms must be a list!')
  38. if len(transforms) < 1:
  39. raise ValueError('The length of transforms ' + \
  40. 'must be equal or larger than 1!')
  41. self.transforms = transforms
  42. self.batch_transforms = None
  43. self.data_type = np.uint8
  44. self.to_rgb = True
  45. # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
  46. for op in self.transforms:
  47. if not isinstance(op, ClsTransform):
  48. import imgaug.augmenters as iaa
  49. if not isinstance(op, iaa.Augmenter):
  50. raise Exception(
  51. "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
  52. )
  53. def __call__(self, im_file, label=None):
  54. """
  55. Args:
  56. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  57. label (int): 每张图像所对应的类别序号。
  58. Returns:
  59. tuple: 根据网络所需字段所组成的tuple;
  60. 字段由transforms中的最后一个数据预处理操作决定。
  61. """
  62. input_channel = getattr(self, 'input_channel', 3)
  63. if isinstance(im_file, np.ndarray):
  64. if len(im_file.shape) != 3:
  65. raise Exception(
  66. "im should be 3-dimension, but now is {}-dimensions".
  67. format(len(im_file.shape)))
  68. else:
  69. try:
  70. if input_channel == 3:
  71. im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
  72. cv2.IMREAD_ANYCOLOR)
  73. else:
  74. im = cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
  75. if im.ndim < 3:
  76. im = np.expand_dims(im, axis=-1)
  77. except:
  78. raise TypeError('Can\'t read The image file {}!'.format(
  79. im_file))
  80. self.data_type = im.dtype
  81. im = im.astype('float32')
  82. if input_channel == 3 and self.to_rgb:
  83. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  84. for op in self.transforms:
  85. if isinstance(op, ClsTransform):
  86. if op.__class__.__name__ == 'RandomDistort':
  87. op.to_rgb = self.to_rgb
  88. op.data_type = self.data_type
  89. outputs = op(im, label)
  90. im = outputs[0]
  91. if len(outputs) == 2:
  92. label = outputs[1]
  93. else:
  94. import imgaug.augmenters as iaa
  95. if im.shape[-1] != 3:
  96. raise Exception(
  97. "Only the 3-channel RGB image is supported in the imgaug operator, but recieved image channel is {}".
  98. format(im.shape[-1]))
  99. if isinstance(op, iaa.Augmenter):
  100. im = execute_imgaug(op, im)
  101. outputs = (im, )
  102. if label is not None:
  103. outputs = (im, label)
  104. return outputs
  105. def add_augmenters(self, augmenters):
  106. if not isinstance(augmenters, list):
  107. raise Exception(
  108. "augmenters should be list type in func add_augmenters()")
  109. transform_names = [type(x).__name__ for x in self.transforms]
  110. for aug in augmenters:
  111. if type(aug).__name__ in transform_names:
  112. logging.error(
  113. "{} is already in ComposedTransforms, need to remove it from add_augmenters().".
  114. format(type(aug).__name__))
  115. self.transforms = augmenters + self.transforms
  116. class RandomCrop(ClsTransform):
  117. """对图像进行随机剪裁,模型训练时的数据增强操作。
  118. 1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
  119. 2. 根据随机剪裁的高、宽随机选取剪裁的起始点。
  120. 3. 剪裁图像。
  121. 4. 调整剪裁后的图像的大小到crop_size*crop_size。
  122. Args:
  123. crop_size (int): 随机裁剪后重新调整的目标边长。默认为224。
  124. lower_scale (float): 裁剪面积相对原面积比例的最小限制。默认为0.08。
  125. lower_ratio (float): 宽变换比例的最小限制。默认为3. / 4。
  126. upper_ratio (float): 宽变换比例的最大限制。默认为4. / 3。
  127. """
  128. def __init__(self,
  129. crop_size=224,
  130. lower_scale=0.08,
  131. lower_ratio=3. / 4,
  132. upper_ratio=4. / 3):
  133. self.crop_size = crop_size
  134. self.lower_scale = lower_scale
  135. self.lower_ratio = lower_ratio
  136. self.upper_ratio = upper_ratio
  137. def __call__(self, im, label=None):
  138. """
  139. Args:
  140. im (np.ndarray): 图像np.ndarray数据。
  141. label (int): 每张图像所对应的类别序号。
  142. Returns:
  143. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  144. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  145. """
  146. im = random_crop(im, self.crop_size, self.lower_scale,
  147. self.lower_ratio, self.upper_ratio)
  148. if label is None:
  149. return (im, )
  150. else:
  151. return (im, label)
  152. class RandomHorizontalFlip(ClsTransform):
  153. """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
  154. Args:
  155. prob (float): 随机水平翻转的概率。默认为0.5。
  156. """
  157. def __init__(self, prob=0.5):
  158. self.prob = prob
  159. def __call__(self, im, label=None):
  160. """
  161. Args:
  162. im (np.ndarray): 图像np.ndarray数据。
  163. label (int): 每张图像所对应的类别序号。
  164. Returns:
  165. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  166. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  167. """
  168. if random.random() < self.prob:
  169. im = horizontal_flip(im)
  170. if label is None:
  171. return (im, )
  172. else:
  173. return (im, label)
  174. class RandomVerticalFlip(ClsTransform):
  175. """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
  176. Args:
  177. prob (float): 随机垂直翻转的概率。默认为0.5。
  178. """
  179. def __init__(self, prob=0.5):
  180. self.prob = prob
  181. def __call__(self, im, label=None):
  182. """
  183. Args:
  184. im (np.ndarray): 图像np.ndarray数据。
  185. label (int): 每张图像所对应的类别序号。
  186. Returns:
  187. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  188. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  189. """
  190. if random.random() < self.prob:
  191. im = vertical_flip(im)
  192. if label is None:
  193. return (im, )
  194. else:
  195. return (im, label)
  196. class Normalize(ClsTransform):
  197. """对图像进行标准化。
  198. 1.像素值减去min_val
  199. 2.像素值除以(max_val-min_val)
  200. 3.对图像进行减均值除以标准差操作。
  201. Args:
  202. mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
  203. std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
  204. min_val (list): 图像数据集的最小值。默认值[0, 0, 0]。
  205. max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。
  206. Raises:
  207. ValueError: mean或std不是list对象。std包含0。
  208. """
  209. def __init__(self,
  210. mean=[0.485, 0.456, 0.406],
  211. std=[0.229, 0.224, 0.225],
  212. min_val=[0, 0, 0],
  213. max_val=[255.0, 255.0, 255.0]):
  214. self.mean = mean
  215. self.std = std
  216. self.min_val = min_val
  217. self.max_val = max_val
  218. if not (isinstance(self.mean, list) and isinstance(self.std, list)):
  219. raise ValueError("{}: input type is invalid.".format(self))
  220. if not (isinstance(self.min_val, list) and isinstance(self.max_val,
  221. list)):
  222. raise ValueError("{}: input type is invalid.".format(self))
  223. from functools import reduce
  224. if reduce(lambda x, y: x * y, self.std) == 0:
  225. raise ValueError('{}: std is invalid!'.format(self))
  226. def __call__(self, im, label=None):
  227. """
  228. Args:
  229. im (np.ndarray): 图像np.ndarray数据。
  230. label (int): 每张图像所对应的类别序号。
  231. Returns:
  232. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  233. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  234. """
  235. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  236. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  237. im = normalize(im, mean, std, self.min_val, self.max_val)
  238. if label is None:
  239. return (im, )
  240. else:
  241. return (im, label)
  242. class ResizeByShort(ClsTransform):
  243. """根据图像短边对图像重新调整大小(resize)。
  244. 1. 获取图像的长边和短边长度。
  245. 2. 根据短边与short_size的比例,计算长边的目标长度,
  246. 此时高、宽的resize比例为short_size/原图短边长度。
  247. 3. 如果max_size>0,调整resize比例:
  248. 如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度;
  249. 4. 根据调整大小的比例对图像进行resize。
  250. Args:
  251. short_size (int): 调整大小后的图像目标短边长度。默认为256。
  252. max_size (int): 长边目标长度的最大限制。默认为-1。
  253. """
  254. def __init__(self, short_size=256, max_size=-1):
  255. self.short_size = short_size
  256. self.max_size = max_size
  257. def __call__(self, im, label=None):
  258. """
  259. Args:
  260. im (np.ndarray): 图像np.ndarray数据。
  261. label (int): 每张图像所对应的类别序号。
  262. Returns:
  263. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  264. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  265. """
  266. im_short_size = min(im.shape[0], im.shape[1])
  267. im_long_size = max(im.shape[0], im.shape[1])
  268. scale = float(self.short_size) / im_short_size
  269. if self.max_size > 0 and np.round(scale *
  270. im_long_size) > self.max_size:
  271. scale = float(self.max_size) / float(im_long_size)
  272. resized_width = int(round(im.shape[1] * scale))
  273. resized_height = int(round(im.shape[0] * scale))
  274. im = cv2.resize(
  275. im, (resized_width, resized_height),
  276. interpolation=cv2.INTER_LINEAR)
  277. if label is None:
  278. return (im, )
  279. else:
  280. return (im, label)
  281. class CenterCrop(ClsTransform):
  282. """以图像中心点扩散裁剪长宽为`crop_size`的正方形
  283. 1. 计算剪裁的起始点。
  284. 2. 剪裁图像。
  285. Args:
  286. crop_size (int): 裁剪的目标边长。默认为224。
  287. """
  288. def __init__(self, crop_size=224):
  289. self.crop_size = crop_size
  290. def __call__(self, im, label=None):
  291. """
  292. Args:
  293. im (np.ndarray): 图像np.ndarray数据。
  294. label (int): 每张图像所对应的类别序号。
  295. Returns:
  296. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  297. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  298. """
  299. im = center_crop(im, self.crop_size)
  300. if label is None:
  301. return (im, )
  302. else:
  303. return (im, label)
  304. class RandomRotate(ClsTransform):
  305. def __init__(self, rotate_range=30, prob=0.5):
  306. """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
  307. Args:
  308. rotate_range (int): 旋转度数的范围。默认为30。
  309. prob (float): 随机旋转的概率。默认为0.5。
  310. """
  311. self.rotate_range = rotate_range
  312. self.prob = prob
  313. def __call__(self, im, label=None):
  314. """
  315. Args:
  316. im (np.ndarray): 图像np.ndarray数据。
  317. label (int): 每张图像所对应的类别序号。
  318. Returns:
  319. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  320. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  321. """
  322. rotate_lower = -self.rotate_range
  323. rotate_upper = self.rotate_range
  324. im = im.astype('uint8')
  325. im = Image.fromarray(im)
  326. if np.random.uniform(0, 1) < self.prob:
  327. im = rotate(im, rotate_lower, rotate_upper)
  328. im = np.asarray(im).astype('float32')
  329. if label is None:
  330. return (im, )
  331. else:
  332. return (im, label)
  333. class RandomDistort(ClsTransform):
  334. """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
  335. 1. 对变换的操作顺序进行随机化操作。
  336. 2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
  337. 【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。
  338. Args:
  339. brightness_range (float): 明亮度的缩放系数范围。
  340. 从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,
  341. 按照公式`image = image * scale`调整图像明亮度。默认值为0.9。
  342. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  343. contrast_range (float): 对比度的缩放系数范围。
  344. 从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,
  345. 按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.9。
  346. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  347. saturation_range (float): 饱和度的缩放系数范围。
  348. 从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,
  349. 按照公式`image = gray * (1 - scale) + image * scale`,
  350. 其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.9。
  351. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  352. hue_range (int): 调整色相角度的差值取值范围。
  353. 从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,
  354. 按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]。
  355. hue_prob (float): 随机调整色调的概率。默认为0.5。
  356. """
  357. def __init__(self,
  358. brightness_range=0.9,
  359. brightness_prob=0.5,
  360. contrast_range=0.9,
  361. contrast_prob=0.5,
  362. saturation_range=0.9,
  363. saturation_prob=0.5,
  364. hue_range=18,
  365. hue_prob=0.5):
  366. self.brightness_range = brightness_range
  367. self.brightness_prob = brightness_prob
  368. self.contrast_range = contrast_range
  369. self.contrast_prob = contrast_prob
  370. self.saturation_range = saturation_range
  371. self.saturation_prob = saturation_prob
  372. self.hue_range = hue_range
  373. self.hue_prob = hue_prob
  374. def __call__(self, im, label=None):
  375. """
  376. Args:
  377. im (np.ndarray): 图像np.ndarray数据。
  378. label (int): 每张图像所对应的类别序号。
  379. Returns:
  380. tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
  381. 当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
  382. """
  383. if im.shape[-1] != 3:
  384. raise Exception(
  385. "Only the 3-channel RGB image is supported in the RandomDistort operator, but recieved image channel is {}".
  386. format(im.shape[-1]))
  387. if self.data_type not in [np.uint8, np.uint16, np.float32]:
  388. raise Exception(
  389. "Only the uint8/uint16/float32 RGB image is supported in the RandomDistort operator, but recieved image data type is {}".
  390. format(self.data_type))
  391. brightness_lower = 1 - self.brightness_range
  392. brightness_upper = 1 + self.brightness_range
  393. contrast_lower = 1 - self.contrast_range
  394. contrast_upper = 1 + self.contrast_range
  395. saturation_lower = 1 - self.saturation_range
  396. saturation_upper = 1 + self.saturation_range
  397. hue_lower = -self.hue_range
  398. hue_upper = self.hue_range
  399. ops = [brightness, contrast, saturation, hue]
  400. random.shuffle(ops)
  401. params_dict = {
  402. 'brightness': {
  403. 'brightness_lower': brightness_lower,
  404. 'brightness_upper': brightness_upper,
  405. 'dtype': self.data_type
  406. },
  407. 'contrast': {
  408. 'contrast_lower': contrast_lower,
  409. 'contrast_upper': contrast_upper,
  410. 'dtype': self.data_type
  411. },
  412. 'saturation': {
  413. 'saturation_lower': saturation_lower,
  414. 'saturation_upper': saturation_upper,
  415. 'is_rgb': self.to_rgb,
  416. 'dtype': self.data_type
  417. },
  418. 'hue': {
  419. 'hue_lower': hue_lower,
  420. 'hue_upper': hue_upper,
  421. 'is_rgb': self.to_rgb,
  422. 'dtype': self.data_type
  423. }
  424. }
  425. prob_dict = {
  426. 'brightness': self.brightness_prob,
  427. 'contrast': self.contrast_prob,
  428. 'saturation': self.saturation_prob,
  429. 'hue': self.hue_prob,
  430. }
  431. for id in range(len(ops)):
  432. params = params_dict[ops[id].__name__]
  433. prob = prob_dict[ops[id].__name__]
  434. params['im'] = im
  435. if np.random.uniform(0, 1) < prob:
  436. im = ops[id](**params)
  437. im = im.astype('float32')
  438. if label is None:
  439. return (im, )
  440. else:
  441. return (im, label)
  442. class ArrangeClassifier(ClsTransform):
  443. """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
  444. Args:
  445. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  446. Raises:
  447. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  448. """
  449. def __init__(self, mode=None):
  450. if mode not in ['train', 'eval', 'test', 'quant']:
  451. raise ValueError(
  452. "mode must be in ['train', 'eval', 'test', 'quant']!")
  453. self.mode = mode
  454. def __call__(self, im, label=None):
  455. """
  456. Args:
  457. im (np.ndarray): 图像np.ndarray数据。
  458. label (int): 每张图像所对应的类别序号。
  459. Returns:
  460. tuple: 当mode为'train'或'eval'时,返回(im, label),分别对应图像np.ndarray数据、
  461. 图像类别id;当mode为'test'或'quant'时,返回(im, ),对应图像np.ndarray数据。
  462. """
  463. im = permute(im, False).astype('float32')
  464. if self.mode == 'train' or self.mode == 'eval':
  465. outputs = (im, label)
  466. else:
  467. outputs = (im, )
  468. return outputs
  469. class ComposedClsTransforms(Compose):
  470. """ 分类模型的基础Transforms流程,具体如下
  471. 训练阶段:
  472. 1. 随机从图像中crop一块子图,并resize成crop_size大小
  473. 2. 将1的输出按0.5的概率随机进行水平翻转
  474. 3. 将图像进行归一化
  475. 验证/预测阶段:
  476. 1. 将图像按比例Resize,使得最小边长度为crop_size[0] * 1.14
  477. 2. 从图像中心crop出一个大小为crop_size的图像
  478. 3. 将图像进行归一化
  479. Args:
  480. mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
  481. crop_size(int|list): 输入模型里的图像大小
  482. mean(list): 图像均值
  483. std(list): 图像方差
  484. random_horizontal_flip(bool): 是否以0.5的概率使用随机水平翻转增强,该仅在mode为`train`时生效,默认为True
  485. """
  486. def __init__(self,
  487. mode,
  488. crop_size=[224, 224],
  489. mean=[0.485, 0.456, 0.406],
  490. std=[0.229, 0.224, 0.225],
  491. random_horizontal_flip=True):
  492. width = crop_size
  493. if isinstance(crop_size, list):
  494. if crop_size[0] != crop_size[1]:
  495. raise Exception(
  496. "In classifier model, width and height should be equal, please modify your parameter `crop_size`"
  497. )
  498. width = crop_size[0]
  499. if width % 32 != 0:
  500. raise Exception(
  501. "In classifier model, width and height should be multiple of 32, e.g 224、256、320...., please modify your parameter `crop_size`"
  502. )
  503. if mode == 'train':
  504. # 训练时的transforms,包含数据增强
  505. transforms = [
  506. RandomCrop(crop_size=width), Normalize(
  507. mean=mean, std=std)
  508. ]
  509. if random_horizontal_flip:
  510. transforms.insert(0, RandomHorizontalFlip())
  511. else:
  512. # 验证/预测时的transforms
  513. transforms = [
  514. ResizeByShort(short_size=int(width * 1.14)),
  515. CenterCrop(crop_size=width), Normalize(
  516. mean=mean, std=std)
  517. ]
  518. super(ComposedClsTransforms, self).__init__(transforms)