seg_transforms.py 49 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186
  1. # coding: utf8
  2. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from .ops import *
  16. from .imgaug_support import execute_imgaug
  17. import random
  18. import os.path as osp
  19. import numpy as np
  20. from PIL import Image
  21. import cv2
  22. from collections import OrderedDict
  23. import paddlex.utils.logging as logging
  24. class SegTransform:
  25. """ 分割transform基类
  26. """
  27. def __init__(self):
  28. pass
  29. class Compose(SegTransform):
  30. """根据数据预处理/增强算子对输入数据进行操作。
  31. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  32. Args:
  33. transforms (list): 数据预处理/增强算子。
  34. Raises:
  35. TypeError: transforms不是list对象
  36. ValueError: transforms元素个数小于1。
  37. """
  38. def __init__(self, transforms):
  39. if not isinstance(transforms, list):
  40. raise TypeError('The transforms must be a list!')
  41. if len(transforms) < 1:
  42. raise ValueError('The length of transforms ' + \
  43. 'must be equal or larger than 1!')
  44. self.transforms = transforms
  45. self.batch_transforms = None
  46. self.to_rgb = False
  47. # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
  48. for op in self.transforms:
  49. if not isinstance(op, SegTransform):
  50. import imgaug.augmenters as iaa
  51. if not isinstance(op, iaa.Augmenter):
  52. raise Exception(
  53. "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
  54. )
  55. @staticmethod
  56. def decode_image(im, label):
  57. if isinstance(im, np.ndarray):
  58. if len(im.shape) != 3:
  59. raise Exception(
  60. "im should be 3-dimensions, but now is {}-dimensions".
  61. format(len(im.shape)))
  62. else:
  63. try:
  64. im = cv2.imread(im)
  65. except:
  66. raise ValueError('Can\'t read The image file {}!'.format(im))
  67. im = im.astype('float32')
  68. if label is not None:
  69. if not isinstance(label, np.ndarray):
  70. label = np.asarray(Image.open(label))
  71. return (im, label)
  72. def __call__(self, im, im_info=None, label=None):
  73. """
  74. Args:
  75. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  76. im_info (list): 存储图像reisze或padding前的shape信息,如
  77. [('resize', [200, 300]), ('padding', [400, 600])]表示
  78. 图像在过resize前shape为(200, 300), 过padding前shape为
  79. (400, 600)
  80. label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
  81. Returns:
  82. tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
  83. """
  84. im, label = self.decode_image(im, label)
  85. if self.to_rgb:
  86. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  87. if im_info is None:
  88. im_info = [('origin_shape', im.shape[0:2])]
  89. if label is not None:
  90. origin_label = label.copy()
  91. for op in self.transforms:
  92. if isinstance(op, SegTransform):
  93. outputs = op(im, im_info, label)
  94. im = outputs[0]
  95. if len(outputs) >= 2:
  96. im_info = outputs[1]
  97. if len(outputs) == 3:
  98. label = outputs[2]
  99. else:
  100. im = execute_imgaug(op, im)
  101. if label is not None:
  102. outputs = (im, im_info, label)
  103. else:
  104. outputs = (im, im_info)
  105. if self.transforms[-1].__class__.__name__ == 'ArrangeSegmenter':
  106. if self.transforms[-1].mode == 'eval':
  107. if label is not None:
  108. outputs = (im, im_info, origin_label)
  109. return outputs
  110. def add_augmenters(self, augmenters):
  111. if not isinstance(augmenters, list):
  112. raise Exception(
  113. "augmenters should be list type in func add_augmenters()")
  114. transform_names = [type(x).__name__ for x in self.transforms]
  115. for aug in augmenters:
  116. if type(aug).__name__ in transform_names:
  117. logging.error(
  118. "{} is already in ComposedTransforms, need to remove it from add_augmenters().".
  119. format(type(aug).__name__))
  120. self.transforms = augmenters + self.transforms
  121. class RandomHorizontalFlip(SegTransform):
  122. """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
  123. Args:
  124. prob (float): 随机水平翻转的概率。默认值为0.5。
  125. """
  126. def __init__(self, prob=0.5):
  127. self.prob = prob
  128. def __call__(self, im, im_info=None, label=None):
  129. """
  130. Args:
  131. im (np.ndarray): 图像np.ndarray数据。
  132. im_info (list): 存储图像reisze或padding前的shape信息,如
  133. [('resize', [200, 300]), ('padding', [400, 600])]表示
  134. 图像在过resize前shape为(200, 300), 过padding前shape为
  135. (400, 600)
  136. label (np.ndarray): 标注图像np.ndarray数据。
  137. Returns:
  138. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  139. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  140. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  141. """
  142. if random.random() < self.prob:
  143. im = horizontal_flip(im)
  144. if label is not None:
  145. label = horizontal_flip(label)
  146. if label is None:
  147. return (im, im_info)
  148. else:
  149. return (im, im_info, label)
  150. class RandomVerticalFlip(SegTransform):
  151. """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
  152. Args:
  153. prob (float): 随机垂直翻转的概率。默认值为0.1。
  154. """
  155. def __init__(self, prob=0.1):
  156. self.prob = prob
  157. def __call__(self, im, im_info=None, label=None):
  158. """
  159. Args:
  160. im (np.ndarray): 图像np.ndarray数据。
  161. im_info (list): 存储图像reisze或padding前的shape信息,如
  162. [('resize', [200, 300]), ('padding', [400, 600])]表示
  163. 图像在过resize前shape为(200, 300), 过padding前shape为
  164. (400, 600)
  165. label (np.ndarray): 标注图像np.ndarray数据。
  166. Returns:
  167. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  168. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  169. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  170. """
  171. if random.random() < self.prob:
  172. im = vertical_flip(im)
  173. if label is not None:
  174. label = vertical_flip(label)
  175. if label is None:
  176. return (im, im_info)
  177. else:
  178. return (im, im_info, label)
  179. class Resize(SegTransform):
  180. """调整图像大小(resize),当存在标注图像时,则同步进行处理。
  181. - 当目标大小(target_size)类型为int时,根据插值方式,
  182. 将图像resize为[target_size, target_size]。
  183. - 当目标大小(target_size)类型为list或tuple时,根据插值方式,
  184. 将图像resize为target_size, target_size的输入应为[w, h]或(w, h)。
  185. Args:
  186. target_size (int|list|tuple): 目标大小。
  187. interp (str): resize的插值方式,与opencv的插值方式对应,
  188. 可选的值为['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4'],默认为"LINEAR"。
  189. Raises:
  190. TypeError: target_size不是int/list/tuple。
  191. ValueError: target_size为list/tuple时元素个数不等于2。
  192. AssertionError: interp的取值不在['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4']之内。
  193. """
  194. # The interpolation mode
  195. interp_dict = {
  196. 'NEAREST': cv2.INTER_NEAREST,
  197. 'LINEAR': cv2.INTER_LINEAR,
  198. 'CUBIC': cv2.INTER_CUBIC,
  199. 'AREA': cv2.INTER_AREA,
  200. 'LANCZOS4': cv2.INTER_LANCZOS4
  201. }
  202. def __init__(self, target_size, interp='LINEAR'):
  203. self.interp = interp
  204. assert interp in self.interp_dict, "interp should be one of {}".format(
  205. interp_dict.keys())
  206. if isinstance(target_size, list) or isinstance(target_size, tuple):
  207. if len(target_size) != 2:
  208. raise ValueError(
  209. 'when target is list or tuple, it should include 2 elements, but it is {}'
  210. .format(target_size))
  211. elif not isinstance(target_size, int):
  212. raise TypeError(
  213. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  214. .format(type(target_size)))
  215. self.target_size = target_size
  216. def __call__(self, im, im_info=None, label=None):
  217. """
  218. Args:
  219. im (np.ndarray): 图像np.ndarray数据。
  220. im_info (list): 存储图像reisze或padding前的shape信息,如
  221. [('resize', [200, 300]), ('padding', [400, 600])]表示
  222. 图像在过resize前shape为(200, 300), 过padding前shape为
  223. (400, 600)
  224. label (np.ndarray): 标注图像np.ndarray数据。
  225. Returns:
  226. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  227. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  228. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  229. 其中,im_info跟新字段为:
  230. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  231. Raises:
  232. ZeroDivisionError: im的短边为0。
  233. TypeError: im不是np.ndarray数据。
  234. ValueError: im不是3维nd.ndarray。
  235. """
  236. if im_info is None:
  237. im_info = OrderedDict()
  238. im_info.append(('resize', im.shape[:2]))
  239. if not isinstance(im, np.ndarray):
  240. raise TypeError("ResizeImage: image type is not np.ndarray.")
  241. if len(im.shape) != 3:
  242. raise ValueError('ResizeImage: image is not 3-dimensional.')
  243. im_shape = im.shape
  244. im_size_min = np.min(im_shape[0:2])
  245. im_size_max = np.max(im_shape[0:2])
  246. if float(im_size_min) == 0:
  247. raise ZeroDivisionError('ResizeImage: min size of image is 0')
  248. if isinstance(self.target_size, int):
  249. resize_w = self.target_size
  250. resize_h = self.target_size
  251. else:
  252. resize_w = self.target_size[0]
  253. resize_h = self.target_size[1]
  254. im_scale_x = float(resize_w) / float(im_shape[1])
  255. im_scale_y = float(resize_h) / float(im_shape[0])
  256. im = cv2.resize(
  257. im,
  258. None,
  259. None,
  260. fx=im_scale_x,
  261. fy=im_scale_y,
  262. interpolation=self.interp_dict[self.interp])
  263. if label is not None:
  264. label = cv2.resize(
  265. label,
  266. None,
  267. None,
  268. fx=im_scale_x,
  269. fy=im_scale_y,
  270. interpolation=self.interp_dict['NEAREST'])
  271. if label is None:
  272. return (im, im_info)
  273. else:
  274. return (im, im_info, label)
  275. class ResizeByLong(SegTransform):
  276. """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  277. Args:
  278. long_size (int): resize后图像的长边大小。
  279. """
  280. def __init__(self, long_size):
  281. self.long_size = long_size
  282. def __call__(self, im, im_info=None, label=None):
  283. """
  284. Args:
  285. im (np.ndarray): 图像np.ndarray数据。
  286. im_info (list): 存储图像reisze或padding前的shape信息,如
  287. [('resize', [200, 300]), ('padding', [400, 600])]表示
  288. 图像在过resize前shape为(200, 300), 过padding前shape为
  289. (400, 600)
  290. label (np.ndarray): 标注图像np.ndarray数据。
  291. Returns:
  292. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  293. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  294. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  295. 其中,im_info新增字段为:
  296. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  297. """
  298. if im_info is None:
  299. im_info = OrderedDict()
  300. im_info.append(('resize', im.shape[:2]))
  301. im = resize_long(im, self.long_size)
  302. if label is not None:
  303. label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
  304. if label is None:
  305. return (im, im_info)
  306. else:
  307. return (im, im_info, label)
  308. class ResizeByShort(SegTransform):
  309. """根据图像的短边调整图像大小(resize)。
  310. 1. 获取图像的长边和短边长度。
  311. 2. 根据短边与short_size的比例,计算长边的目标长度,
  312. 此时高、宽的resize比例为short_size/原图短边长度。
  313. 3. 如果max_size>0,调整resize比例:
  314. 如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度。
  315. 4. 根据调整大小的比例对图像进行resize。
  316. Args:
  317. target_size (int): 短边目标长度。默认为800。
  318. max_size (int): 长边目标长度的最大限制。默认为1333。
  319. Raises:
  320. TypeError: 形参数据类型不满足需求。
  321. """
  322. def __init__(self, short_size=800, max_size=1333):
  323. self.max_size = int(max_size)
  324. if not isinstance(short_size, int):
  325. raise TypeError(
  326. "Type of short_size is invalid. Must be Integer, now is {}".
  327. format(type(short_size)))
  328. self.short_size = short_size
  329. if not (isinstance(self.max_size, int)):
  330. raise TypeError("max_size: input type is invalid.")
  331. def __call__(self, im, im_info=None, label=None):
  332. """
  333. Args:
  334. im (numnp.ndarraypy): 图像np.ndarray数据。
  335. im_info (list): 存储图像reisze或padding前的shape信息,如
  336. [('resize', [200, 300]), ('padding', [400, 600])]表示
  337. 图像在过resize前shape为(200, 300), 过padding前shape为
  338. (400, 600)
  339. label (np.ndarray): 标注图像np.ndarray数据。
  340. Returns:
  341. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  342. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  343. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  344. 其中,im_info更新字段为:
  345. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  346. Raises:
  347. TypeError: 形参数据类型不满足需求。
  348. ValueError: 数据长度不匹配。
  349. """
  350. if im_info is None:
  351. im_info = OrderedDict()
  352. if not isinstance(im, np.ndarray):
  353. raise TypeError("ResizeByShort: image type is not numpy.")
  354. if len(im.shape) != 3:
  355. raise ValueError('ResizeByShort: image is not 3-dimensional.')
  356. im_info.append(('resize', im.shape[:2]))
  357. im_short_size = min(im.shape[0], im.shape[1])
  358. im_long_size = max(im.shape[0], im.shape[1])
  359. scale = float(self.short_size) / im_short_size
  360. if self.max_size > 0 and np.round(scale *
  361. im_long_size) > self.max_size:
  362. scale = float(self.max_size) / float(im_long_size)
  363. resized_width = int(round(im.shape[1] * scale))
  364. resized_height = int(round(im.shape[0] * scale))
  365. im = cv2.resize(
  366. im, (resized_width, resized_height),
  367. interpolation=cv2.INTER_NEAREST)
  368. if label is not None:
  369. im = cv2.resize(
  370. label, (resized_width, resized_height),
  371. interpolation=cv2.INTER_NEAREST)
  372. if label is None:
  373. return (im, im_info)
  374. else:
  375. return (im, im_info, label)
  376. class ResizeRangeScaling(SegTransform):
  377. """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  378. Args:
  379. min_value (int): 图像长边resize后的最小值。默认值400。
  380. max_value (int): 图像长边resize后的最大值。默认值600。
  381. Raises:
  382. ValueError: min_value大于max_value
  383. """
  384. def __init__(self, min_value=400, max_value=600):
  385. if min_value > max_value:
  386. raise ValueError('min_value must be less than max_value, '
  387. 'but they are {} and {}.'.format(min_value,
  388. max_value))
  389. self.min_value = min_value
  390. self.max_value = max_value
  391. def __call__(self, im, im_info=None, label=None):
  392. """
  393. Args:
  394. im (np.ndarray): 图像np.ndarray数据。
  395. im_info (list): 存储图像reisze或padding前的shape信息,如
  396. [('resize', [200, 300]), ('padding', [400, 600])]表示
  397. 图像在过resize前shape为(200, 300), 过padding前shape为
  398. (400, 600)
  399. label (np.ndarray): 标注图像np.ndarray数据。
  400. Returns:
  401. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  402. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  403. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  404. """
  405. if self.min_value == self.max_value:
  406. random_size = self.max_value
  407. else:
  408. random_size = int(
  409. np.random.uniform(self.min_value, self.max_value) + 0.5)
  410. im = resize_long(im, random_size, cv2.INTER_LINEAR)
  411. if label is not None:
  412. label = resize_long(label, random_size, cv2.INTER_NEAREST)
  413. if label is None:
  414. return (im, im_info)
  415. else:
  416. return (im, im_info, label)
  417. class ResizeStepScaling(SegTransform):
  418. """对图像按照某一个比例resize,这个比例以scale_step_size为步长
  419. 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
  420. Args:
  421. min_scale_factor(float), resize最小尺度。默认值0.75。
  422. max_scale_factor (float), resize最大尺度。默认值1.25。
  423. scale_step_size (float), resize尺度范围间隔。默认值0.25。
  424. Raises:
  425. ValueError: min_scale_factor大于max_scale_factor
  426. """
  427. def __init__(self,
  428. min_scale_factor=0.75,
  429. max_scale_factor=1.25,
  430. scale_step_size=0.25):
  431. if min_scale_factor > max_scale_factor:
  432. raise ValueError(
  433. 'min_scale_factor must be less than max_scale_factor, '
  434. 'but they are {} and {}.'.format(min_scale_factor,
  435. max_scale_factor))
  436. self.min_scale_factor = min_scale_factor
  437. self.max_scale_factor = max_scale_factor
  438. self.scale_step_size = scale_step_size
  439. def __call__(self, im, im_info=None, label=None):
  440. """
  441. Args:
  442. im (np.ndarray): 图像np.ndarray数据。
  443. im_info (list): 存储图像reisze或padding前的shape信息,如
  444. [('resize', [200, 300]), ('padding', [400, 600])]表示
  445. 图像在过resize前shape为(200, 300), 过padding前shape为
  446. (400, 600)
  447. label (np.ndarray): 标注图像np.ndarray数据。
  448. Returns:
  449. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  450. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  451. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  452. """
  453. if self.min_scale_factor == self.max_scale_factor:
  454. scale_factor = self.min_scale_factor
  455. elif self.scale_step_size == 0:
  456. scale_factor = np.random.uniform(self.min_scale_factor,
  457. self.max_scale_factor)
  458. else:
  459. num_steps = int((self.max_scale_factor - self.min_scale_factor) /
  460. self.scale_step_size + 1)
  461. scale_factors = np.linspace(self.min_scale_factor,
  462. self.max_scale_factor,
  463. num_steps).tolist()
  464. np.random.shuffle(scale_factors)
  465. scale_factor = scale_factors[0]
  466. im = cv2.resize(
  467. im, (0, 0),
  468. fx=scale_factor,
  469. fy=scale_factor,
  470. interpolation=cv2.INTER_LINEAR)
  471. if label is not None:
  472. label = cv2.resize(
  473. label, (0, 0),
  474. fx=scale_factor,
  475. fy=scale_factor,
  476. interpolation=cv2.INTER_NEAREST)
  477. if label is None:
  478. return (im, im_info)
  479. else:
  480. return (im, im_info, label)
  481. class Normalize(SegTransform):
  482. """对图像进行标准化。
  483. 1.尺度缩放到 [0,1]。
  484. 2.对图像进行减均值除以标准差操作。
  485. Args:
  486. mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
  487. std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
  488. Raises:
  489. ValueError: mean或std不是list对象。std包含0。
  490. """
  491. def __init__(self,
  492. mean=[0.5, 0.5, 0.5],
  493. std=[0.5, 0.5, 0.5],
  494. min_val=[0, 0, 0],
  495. max_val=[255.0, 255.0, 255.0]):
  496. self.min_val = min_val
  497. self.max_val = max_val
  498. self.mean = mean
  499. self.std = std
  500. if not (isinstance(self.mean, list) and isinstance(self.std, list)):
  501. raise ValueError("{}: input type is invalid.".format(self))
  502. if not (isinstance(self.min_val, list) and isinstance(self.max_val,
  503. list)):
  504. raise ValueError("{}: input type is invalid.".format(self))
  505. from functools import reduce
  506. if reduce(lambda x, y: x * y, self.std) == 0:
  507. raise ValueError('{}: std is invalid!'.format(self))
  508. def __call__(self, im, im_info=None, label=None):
  509. """
  510. Args:
  511. im (np.ndarray): 图像np.ndarray数据。
  512. im_info (list): 存储图像reisze或padding前的shape信息,如
  513. [('resize', [200, 300]), ('padding', [400, 600])]表示
  514. 图像在过resize前shape为(200, 300), 过padding前shape为
  515. (400, 600)
  516. label (np.ndarray): 标注图像np.ndarray数据。
  517. Returns:
  518. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  519. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  520. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  521. """
  522. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  523. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  524. im = normalize(im, mean, std, self.min_val, self.max_val)
  525. if label is None:
  526. return (im, im_info)
  527. else:
  528. return (im, im_info, label)
  529. class Padding(SegTransform):
  530. """对图像或标注图像进行padding,padding方向为右和下。
  531. 根据提供的值对图像或标注图像进行padding操作。
  532. Args:
  533. target_size (int|list|tuple): padding后图像的大小。
  534. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  535. label_padding_value (int): 标注图像padding的值。默认值为255。
  536. Raises:
  537. TypeError: target_size不是int|list|tuple。
  538. ValueError: target_size为list|tuple时元素个数不等于2。
  539. """
  540. def __init__(self,
  541. target_size,
  542. im_padding_value=[127.5, 127.5, 127.5],
  543. label_padding_value=255):
  544. if isinstance(target_size, list) or isinstance(target_size, tuple):
  545. if len(target_size) != 2:
  546. raise ValueError(
  547. 'when target is list or tuple, it should include 2 elements, but it is {}'
  548. .format(target_size))
  549. elif not isinstance(target_size, int):
  550. raise TypeError(
  551. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  552. .format(type(target_size)))
  553. self.target_size = target_size
  554. self.im_padding_value = im_padding_value
  555. self.label_padding_value = label_padding_value
  556. def __call__(self, im, im_info=None, label=None):
  557. """
  558. Args:
  559. im (np.ndarray): 图像np.ndarray数据。
  560. im_info (list): 存储图像reisze或padding前的shape信息,如
  561. [('resize', [200, 300]), ('padding', [400, 600])]表示
  562. 图像在过resize前shape为(200, 300), 过padding前shape为
  563. (400, 600)
  564. label (np.ndarray): 标注图像np.ndarray数据。
  565. Returns:
  566. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  567. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  568. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  569. 其中,im_info新增字段为:
  570. -shape_before_padding (tuple): 保存padding之前图像的形状(h, w)。
  571. Raises:
  572. ValueError: 输入图像im或label的形状大于目标值
  573. """
  574. if im_info is None:
  575. im_info = OrderedDict()
  576. im_info.append(('padding', im.shape[:2]))
  577. im_height, im_width = im.shape[0], im.shape[1]
  578. if isinstance(self.target_size, int):
  579. target_height = self.target_size
  580. target_width = self.target_size
  581. else:
  582. target_height = self.target_size[1]
  583. target_width = self.target_size[0]
  584. pad_height = target_height - im_height
  585. pad_width = target_width - im_width
  586. if pad_height < 0 or pad_width < 0:
  587. raise ValueError(
  588. 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
  589. .format(im_width, im_height, target_width, target_height))
  590. else:
  591. im = cv2.copyMakeBorder(
  592. im,
  593. 0,
  594. pad_height,
  595. 0,
  596. pad_width,
  597. cv2.BORDER_CONSTANT,
  598. value=self.im_padding_value)
  599. if label is not None:
  600. label = cv2.copyMakeBorder(
  601. label,
  602. 0,
  603. pad_height,
  604. 0,
  605. pad_width,
  606. cv2.BORDER_CONSTANT,
  607. value=self.label_padding_value)
  608. if label is None:
  609. return (im, im_info)
  610. else:
  611. return (im, im_info, label)
  612. class RandomPaddingCrop(SegTransform):
  613. """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
  614. Args:
  615. crop_size (int|list|tuple): 裁剪图像大小。默认为512。
  616. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  617. label_padding_value (int): 标注图像padding的值。默认值为255。
  618. Raises:
  619. TypeError: crop_size不是int/list/tuple。
  620. ValueError: target_size为list/tuple时元素个数不等于2。
  621. """
  622. def __init__(self,
  623. crop_size=512,
  624. im_padding_value=[127.5, 127.5, 127.5],
  625. label_padding_value=255):
  626. if isinstance(crop_size, list) or isinstance(crop_size, tuple):
  627. if len(crop_size) != 2:
  628. raise ValueError(
  629. 'when crop_size is list or tuple, it should include 2 elements, but it is {}'
  630. .format(crop_size))
  631. elif not isinstance(crop_size, int):
  632. raise TypeError(
  633. "Type of crop_size is invalid. Must be Integer or List or tuple, now is {}"
  634. .format(type(crop_size)))
  635. self.crop_size = crop_size
  636. self.im_padding_value = im_padding_value
  637. self.label_padding_value = label_padding_value
  638. def __call__(self, im, im_info=None, label=None):
  639. """
  640. Args:
  641. im (np.ndarray): 图像np.ndarray数据。
  642. im_info (list): 存储图像reisze或padding前的shape信息,如
  643. [('resize', [200, 300]), ('padding', [400, 600])]表示
  644. 图像在过resize前shape为(200, 300), 过padding前shape为
  645. (400, 600)
  646. label (np.ndarray): 标注图像np.ndarray数据。
  647. Returns:
  648. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  649. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  650. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  651. """
  652. if isinstance(self.crop_size, int):
  653. crop_width = self.crop_size
  654. crop_height = self.crop_size
  655. else:
  656. crop_width = self.crop_size[0]
  657. crop_height = self.crop_size[1]
  658. img_height = im.shape[0]
  659. img_width = im.shape[1]
  660. if img_height == crop_height and img_width == crop_width:
  661. if label is None:
  662. return (im, im_info)
  663. else:
  664. return (im, im_info, label)
  665. else:
  666. pad_height = max(crop_height - img_height, 0)
  667. pad_width = max(crop_width - img_width, 0)
  668. if (pad_height > 0 or pad_width > 0):
  669. img_channel = im.shape[2]
  670. import copy
  671. orig_im = copy.deepcopy(im)
  672. im = np.zeros((img_height + pad_height, img_width + pad_width,
  673. img_channel)).astype(orig_im.dtype)
  674. for i in range(img_channel):
  675. im[:, :, i] = np.pad(
  676. orig_im[:, :, i],
  677. pad_width=((0, pad_height), (0, pad_width)),
  678. mode='constant',
  679. constant_values=(self.im_padding_value[i],
  680. self.im_padding_value[i]))
  681. if label is not None:
  682. label = np.pad(label,
  683. pad_width=((0, pad_height), (0, pad_width)),
  684. mode='constant',
  685. constant_values=(self.label_padding_value,
  686. self.label_padding_value))
  687. img_height = im.shape[0]
  688. img_width = im.shape[1]
  689. if crop_height > 0 and crop_width > 0:
  690. h_off = np.random.randint(img_height - crop_height + 1)
  691. w_off = np.random.randint(img_width - crop_width + 1)
  692. im = im[h_off:(crop_height + h_off), w_off:(w_off + crop_width
  693. ), :]
  694. if label is not None:
  695. label = label[h_off:(crop_height + h_off), w_off:(
  696. w_off + crop_width)]
  697. if label is None:
  698. return (im, im_info)
  699. else:
  700. return (im, im_info, label)
  701. class RandomBlur(SegTransform):
  702. """以一定的概率对图像进行高斯模糊。
  703. Args:
  704. prob (float): 图像模糊概率。默认为0.1。
  705. """
  706. def __init__(self, prob=0.1):
  707. self.prob = prob
  708. def __call__(self, im, im_info=None, label=None):
  709. """
  710. Args:
  711. im (np.ndarray): 图像np.ndarray数据。
  712. im_info (list): 存储图像reisze或padding前的shape信息,如
  713. [('resize', [200, 300]), ('padding', [400, 600])]表示
  714. 图像在过resize前shape为(200, 300), 过padding前shape为
  715. (400, 600)
  716. label (np.ndarray): 标注图像np.ndarray数据。
  717. Returns:
  718. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  719. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  720. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  721. """
  722. if self.prob <= 0:
  723. n = 0
  724. elif self.prob >= 1:
  725. n = 1
  726. else:
  727. n = int(1.0 / self.prob)
  728. if n > 0:
  729. if np.random.randint(0, n) == 0:
  730. radius = np.random.randint(3, 10)
  731. if radius % 2 != 1:
  732. radius = radius + 1
  733. if radius > 9:
  734. radius = 9
  735. im = cv2.GaussianBlur(im, (radius, radius), 0, 0)
  736. if label is None:
  737. return (im, im_info)
  738. else:
  739. return (im, im_info, label)
  740. class RandomRotate(SegTransform):
  741. """对图像进行随机旋转, 模型训练时的数据增强操作。
  742. 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
  743. 并对旋转后的图像和标注图像进行相应的padding。
  744. Args:
  745. rotate_range (float): 最大旋转角度。默认为15度。
  746. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  747. label_padding_value (int): 标注图像padding的值。默认为255。
  748. """
  749. def __init__(self,
  750. rotate_range=15,
  751. im_padding_value=[127.5, 127.5, 127.5],
  752. label_padding_value=255):
  753. self.rotate_range = rotate_range
  754. self.im_padding_value = im_padding_value
  755. self.label_padding_value = label_padding_value
  756. def __call__(self, im, im_info=None, label=None):
  757. """
  758. Args:
  759. im (np.ndarray): 图像np.ndarray数据。
  760. im_info (list): 存储图像reisze或padding前的shape信息,如
  761. [('resize', [200, 300]), ('padding', [400, 600])]表示
  762. 图像在过resize前shape为(200, 300), 过padding前shape为
  763. (400, 600)
  764. label (np.ndarray): 标注图像np.ndarray数据。
  765. Returns:
  766. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  767. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  768. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  769. """
  770. if self.rotate_range > 0:
  771. (h, w) = im.shape[:2]
  772. do_rotation = np.random.uniform(-self.rotate_range,
  773. self.rotate_range)
  774. pc = (w // 2, h // 2)
  775. r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
  776. cos = np.abs(r[0, 0])
  777. sin = np.abs(r[0, 1])
  778. nw = int((h * sin) + (w * cos))
  779. nh = int((h * cos) + (w * sin))
  780. (cx, cy) = pc
  781. r[0, 2] += (nw / 2) - cx
  782. r[1, 2] += (nh / 2) - cy
  783. dsize = (nw, nh)
  784. im = cv2.warpAffine(
  785. im,
  786. r,
  787. dsize=dsize,
  788. flags=cv2.INTER_LINEAR,
  789. borderMode=cv2.BORDER_CONSTANT,
  790. borderValue=self.im_padding_value)
  791. label = cv2.warpAffine(
  792. label,
  793. r,
  794. dsize=dsize,
  795. flags=cv2.INTER_NEAREST,
  796. borderMode=cv2.BORDER_CONSTANT,
  797. borderValue=self.label_padding_value)
  798. if label is None:
  799. return (im, im_info)
  800. else:
  801. return (im, im_info, label)
  802. class RandomScaleAspect(SegTransform):
  803. """裁剪并resize回原始尺寸的图像和标注图像。
  804. 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
  805. Args:
  806. min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
  807. aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
  808. """
  809. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  810. self.min_scale = min_scale
  811. self.aspect_ratio = aspect_ratio
  812. def __call__(self, im, im_info=None, label=None):
  813. """
  814. Args:
  815. im (np.ndarray): 图像np.ndarray数据。
  816. im_info (list): 存储图像reisze或padding前的shape信息,如
  817. [('resize', [200, 300]), ('padding', [400, 600])]表示
  818. 图像在过resize前shape为(200, 300), 过padding前shape为
  819. (400, 600)
  820. label (np.ndarray): 标注图像np.ndarray数据。
  821. Returns:
  822. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  823. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  824. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  825. """
  826. if self.min_scale != 0 and self.aspect_ratio != 0:
  827. img_height = im.shape[0]
  828. img_width = im.shape[1]
  829. for i in range(0, 10):
  830. area = img_height * img_width
  831. target_area = area * np.random.uniform(self.min_scale, 1.0)
  832. aspectRatio = np.random.uniform(self.aspect_ratio,
  833. 1.0 / self.aspect_ratio)
  834. dw = int(np.sqrt(target_area * 1.0 * aspectRatio))
  835. dh = int(np.sqrt(target_area * 1.0 / aspectRatio))
  836. if (np.random.randint(10) < 5):
  837. tmp = dw
  838. dw = dh
  839. dh = tmp
  840. if (dh < img_height and dw < img_width):
  841. h1 = np.random.randint(0, img_height - dh)
  842. w1 = np.random.randint(0, img_width - dw)
  843. im = im[h1:(h1 + dh), w1:(w1 + dw), :]
  844. label = label[h1:(h1 + dh), w1:(w1 + dw)]
  845. im = cv2.resize(
  846. im, (img_width, img_height),
  847. interpolation=cv2.INTER_LINEAR)
  848. label = cv2.resize(
  849. label, (img_width, img_height),
  850. interpolation=cv2.INTER_NEAREST)
  851. break
  852. if label is None:
  853. return (im, im_info)
  854. else:
  855. return (im, im_info, label)
  856. class RandomDistort(SegTransform):
  857. """对图像进行随机失真。
  858. 1. 对变换的操作顺序进行随机化操作。
  859. 2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
  860. Args:
  861. brightness_range (float): 明亮度因子的范围。默认为0.5。
  862. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  863. contrast_range (float): 对比度因子的范围。默认为0.5。
  864. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  865. saturation_range (float): 饱和度因子的范围。默认为0.5。
  866. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  867. hue_range (int): 色调因子的范围。默认为18。
  868. hue_prob (float): 随机调整色调的概率。默认为0.5。
  869. """
  870. def __init__(self,
  871. brightness_range=0.5,
  872. brightness_prob=0.5,
  873. contrast_range=0.5,
  874. contrast_prob=0.5,
  875. saturation_range=0.5,
  876. saturation_prob=0.5,
  877. hue_range=18,
  878. hue_prob=0.5):
  879. self.brightness_range = brightness_range
  880. self.brightness_prob = brightness_prob
  881. self.contrast_range = contrast_range
  882. self.contrast_prob = contrast_prob
  883. self.saturation_range = saturation_range
  884. self.saturation_prob = saturation_prob
  885. self.hue_range = hue_range
  886. self.hue_prob = hue_prob
  887. def __call__(self, im, im_info=None, label=None):
  888. """
  889. Args:
  890. im (np.ndarray): 图像np.ndarray数据。
  891. im_info (list): 存储图像reisze或padding前的shape信息,如
  892. [('resize', [200, 300]), ('padding', [400, 600])]表示
  893. 图像在过resize前shape为(200, 300), 过padding前shape为
  894. (400, 600)
  895. label (np.ndarray): 标注图像np.ndarray数据。
  896. Returns:
  897. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  898. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  899. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  900. """
  901. brightness_lower = 1 - self.brightness_range
  902. brightness_upper = 1 + self.brightness_range
  903. contrast_lower = 1 - self.contrast_range
  904. contrast_upper = 1 + self.contrast_range
  905. saturation_lower = 1 - self.saturation_range
  906. saturation_upper = 1 + self.saturation_range
  907. hue_lower = -self.hue_range
  908. hue_upper = self.hue_range
  909. ops = [brightness, contrast, saturation, hue]
  910. random.shuffle(ops)
  911. params_dict = {
  912. 'brightness': {
  913. 'brightness_lower': brightness_lower,
  914. 'brightness_upper': brightness_upper
  915. },
  916. 'contrast': {
  917. 'contrast_lower': contrast_lower,
  918. 'contrast_upper': contrast_upper
  919. },
  920. 'saturation': {
  921. 'saturation_lower': saturation_lower,
  922. 'saturation_upper': saturation_upper
  923. },
  924. 'hue': {
  925. 'hue_lower': hue_lower,
  926. 'hue_upper': hue_upper
  927. }
  928. }
  929. prob_dict = {
  930. 'brightness': self.brightness_prob,
  931. 'contrast': self.contrast_prob,
  932. 'saturation': self.saturation_prob,
  933. 'hue': self.hue_prob
  934. }
  935. for id in range(4):
  936. params = params_dict[ops[id].__name__]
  937. prob = prob_dict[ops[id].__name__]
  938. params['im'] = im
  939. if np.random.uniform(0, 1) < prob:
  940. im = ops[id](**params)
  941. im = im.astype('float32')
  942. if label is None:
  943. return (im, im_info)
  944. else:
  945. return (im, im_info, label)
  946. class ArrangeSegmenter(SegTransform):
  947. """获取训练/验证/预测所需的信息。
  948. Args:
  949. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  950. Raises:
  951. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内
  952. """
  953. def __init__(self, mode):
  954. if mode not in ['train', 'eval', 'test', 'quant']:
  955. raise ValueError(
  956. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  957. )
  958. self.mode = mode
  959. def __call__(self, im, im_info, label=None):
  960. """
  961. Args:
  962. im (np.ndarray): 图像np.ndarray数据。
  963. im_info (list): 存储图像reisze或padding前的shape信息,如
  964. [('resize', [200, 300]), ('padding', [400, 600])]表示
  965. 图像在过resize前shape为(200, 300), 过padding前shape为
  966. (400, 600)
  967. label (np.ndarray): 标注图像np.ndarray数据。
  968. Returns:
  969. tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  970. 当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为
  971. 'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
  972. """
  973. im = permute(im, False)
  974. if self.mode == 'train':
  975. label = label[np.newaxis, :, :]
  976. return (im, label)
  977. if self.mode == 'eval':
  978. label = label[np.newaxis, :, :]
  979. return (im, im_info, label)
  980. elif self.mode == 'test':
  981. return (im, im_info)
  982. else:
  983. return (im, )
  984. class ComposedSegTransforms(Compose):
  985. """ 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
  986. 训练阶段:
  987. 1. 随机对图像以0.5的概率水平翻转,若random_horizontal_flip为False,则跳过此步骤
  988. 2. 按不同的比例随机Resize原图, 处理方式参考[paddlex.seg.transforms.ResizeRangeScaling](#resizerangescaling)。若min_max_size为None,则跳过此步骤
  989. 3. 从原图中随机crop出大小为train_crop_size大小的子图,如若crop出来的图小于train_crop_size,则会将图padding到对应大小
  990. 4. 图像归一化
  991. 预测阶段:
  992. 1. 将图像的最长边resize至(min_max_size[0] + min_max_size[1])//2, 短边按比例resize。若min_max_size为None,则跳过此步骤
  993. 2. 图像归一化
  994. Args:
  995. mode(str): Transforms所处的阶段,包括`train', 'eval'或'test'
  996. min_max_size(list): 用于对图像进行resize,具体作用参见上述步骤。
  997. train_crop_size(list): 训练过程中随机裁剪原图用于训练,具体作用参见上述步骤。此参数仅在mode为`train`时生效。
  998. mean(list): 图像均值, 默认为[0.485, 0.456, 0.406]。
  999. std(list): 图像方差,默认为[0.229, 0.224, 0.225]。
  1000. random_horizontal_flip(bool): 数据增强,是否随机水平翻转图像,此参数仅在mode为`train`时生效。
  1001. """
  1002. def __init__(self,
  1003. mode,
  1004. min_max_size=[400, 600],
  1005. train_crop_size=[512, 512],
  1006. mean=[0.5, 0.5, 0.5],
  1007. std=[0.5, 0.5, 0.5],
  1008. random_horizontal_flip=True):
  1009. if mode == 'train':
  1010. # 训练时的transforms,包含数据增强
  1011. if min_max_size is None:
  1012. transforms = [
  1013. RandomPaddingCrop(crop_size=train_crop_size), Normalize(
  1014. mean=mean, std=std)
  1015. ]
  1016. else:
  1017. transforms = [
  1018. ResizeRangeScaling(
  1019. min_value=min(min_max_size),
  1020. max_value=max(min_max_size)),
  1021. RandomPaddingCrop(crop_size=train_crop_size), Normalize(
  1022. mean=mean, std=std)
  1023. ]
  1024. if random_horizontal_flip:
  1025. transforms.insert(0, RandomHorizontalFlip())
  1026. else:
  1027. # 验证/预测时的transforms
  1028. if min_max_size is None:
  1029. transforms = [Normalize(mean=mean, std=std)]
  1030. else:
  1031. long_size = (min(min_max_size) + max(min_max_size)) // 2
  1032. transforms = [
  1033. ResizeByLong(long_size=long_size), Normalize(
  1034. mean=mean, std=std)
  1035. ]
  1036. super(ComposedSegTransforms, self).__init__(transforms)