seg_transforms.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. function:
  16. transforms for segmentation in PaddleX<2.0
  17. """
  18. import numpy as np
  19. import cv2
  20. import copy
  21. from .operators import Transform, Compose, RandomHorizontalFlip, RandomVerticalFlip, Resize, \
  22. ResizeByShort, Normalize, RandomDistort, ArrangeSegmenter
  23. from .operators import Padding as dy_Padding
  24. class ResizeByLong(Transform):
  25. """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  26. Args:
  27. long_size (int): resize后图像的长边大小。
  28. """
  29. def __init__(self, long_size=256):
  30. super(ResizeByLong, self).__init__()
  31. self.long_size = long_size
  32. def apply_im(self, image):
  33. image = _resize_long(image, long_size=self.long_size)
  34. return image
  35. def apply_mask(self, mask):
  36. mask = _resize_long(
  37. mask, long_size=self.long_size, interpolation=cv2.INTER_NEAREST)
  38. return mask
  39. def apply(self, sample):
  40. sample['image'] = self.apply_im(sample['image'])
  41. if 'mask' in sample:
  42. sample['mask'] = self.apply_mask(sample['mask'])
  43. return sample
  44. class ResizeRangeScaling(Transform):
  45. """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  46. Args:
  47. min_value (int): 图像长边resize后的最小值。默认值400。
  48. max_value (int): 图像长边resize后的最大值。默认值600。
  49. Raises:
  50. ValueError: min_value大于max_value
  51. """
  52. def __init__(self, min_value=400, max_value=600):
  53. super(ResizeRangeScaling, self).__init__()
  54. if min_value > max_value:
  55. raise ValueError('min_value must be less than max_value, '
  56. 'but they are {} and {}.'.format(min_value,
  57. max_value))
  58. self.min_value = min_value
  59. self.max_value = max_value
  60. def apply_im(self, image, random_size):
  61. image = _resize_long(image, long_size=random_size)
  62. return image
  63. def apply_mask(self, mask, random_size):
  64. mask = _resize_long(
  65. mask, long_size=random_size, interpolation=cv2.INTER_NEAREST)
  66. return mask
  67. def apply(self, sample):
  68. if self.min_value == self.max_value:
  69. random_size = self.max_value
  70. else:
  71. random_size = int(
  72. np.random.uniform(self.min_value, self.max_value) + 0.5)
  73. sample['image'] = self.apply_im(sample['image'], random_size)
  74. if 'mask' in sample:
  75. sample['mask'] = self.apply_mask(sample['mask'], random_size)
  76. return sample
  77. class ResizeStepScaling(Transform):
  78. """对图像按照某一个比例resize,这个比例以scale_step_size为步长
  79. 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
  80. Args:
  81. min_scale_factor(float), resize最小尺度。默认值0.75。
  82. max_scale_factor (float), resize最大尺度。默认值1.25。
  83. scale_step_size (float), resize尺度范围间隔。默认值0.25。
  84. Raises:
  85. ValueError: min_scale_factor大于max_scale_factor
  86. """
  87. def __init__(self,
  88. min_scale_factor=0.75,
  89. max_scale_factor=1.25,
  90. scale_step_size=0.25):
  91. if min_scale_factor > max_scale_factor:
  92. raise ValueError(
  93. 'min_scale_factor must be less than max_scale_factor, '
  94. 'but they are {} and {}.'.format(min_scale_factor,
  95. max_scale_factor))
  96. super(ResizeStepScaling, self).__init__()
  97. self.min_scale_factor = min_scale_factor
  98. self.max_scale_factor = max_scale_factor
  99. self.scale_step_size = scale_step_size
  100. def apply_im(self, image, scale_factor):
  101. image = cv2.resize(
  102. image, (0, 0),
  103. fx=scale_factor,
  104. fy=scale_factor,
  105. interpolation=cv2.INTER_LINEAR)
  106. if image.ndim < 3:
  107. image = np.expand_dims(image, axis=-1)
  108. return image
  109. def apply_mask(self, mask, scale_factor):
  110. mask = cv2.resize(
  111. mask, (0, 0),
  112. fx=scale_factor,
  113. fy=scale_factor,
  114. interpolation=cv2.INTER_NEAREST)
  115. return mask
  116. def apply(self, sample):
  117. if self.min_scale_factor == self.max_scale_factor:
  118. scale_factor = self.min_scale_factor
  119. elif self.scale_step_size == 0:
  120. scale_factor = np.random.uniform(self.min_scale_factor,
  121. self.max_scale_factor)
  122. else:
  123. num_steps = int((self.max_scale_factor - self.min_scale_factor) /
  124. self.scale_step_size + 1)
  125. scale_factors = np.linspace(self.min_scale_factor,
  126. self.max_scale_factor,
  127. num_steps).tolist()
  128. np.random.shuffle(scale_factors)
  129. scale_factor = scale_factors[0]
  130. sample['image'] = self.apply_im(sample['image'], scale_factor)
  131. if 'mask' in sample:
  132. sample['mask'] = self.apply_mask(sample['mask'], scale_factor)
  133. return sample
  134. class Padding(dy_Padding):
  135. """对图像或标注图像进行padding,padding方向为右和下。
  136. 根据提供的值对图像或标注图像进行padding操作。
  137. Args:
  138. target_size (int|list|tuple): padding后图像的大小。
  139. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  140. label_padding_value (int): 标注图像padding的值。默认值为255。
  141. Raises:
  142. TypeError: target_size不是int|list|tuple。
  143. ValueError: target_size为list|tuple时元素个数不等于2。
  144. """
  145. def __init__(self,
  146. target_size,
  147. im_padding_value=[127.5, 127.5, 127.5],
  148. label_padding_value=255):
  149. super(Padding, self).__init__(
  150. target_size=target_size,
  151. pad_mode=0,
  152. offsets=None,
  153. im_padding_value=im_padding_value,
  154. label_padding_value=label_padding_value)
  155. class RandomPaddingCrop(Transform):
  156. """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
  157. Args:
  158. crop_size (int|list|tuple): 裁剪图像大小。默认为512。
  159. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  160. label_padding_value (int): 标注图像padding的值。默认值为255。
  161. Raises:
  162. TypeError: crop_size不是int/list/tuple。
  163. ValueError: target_size为list/tuple时元素个数不等于2。
  164. """
  165. def __init__(self,
  166. crop_size=512,
  167. im_padding_value=[127.5, 127.5, 127.5],
  168. label_padding_value=255):
  169. if isinstance(crop_size, list) or isinstance(crop_size, tuple):
  170. if len(crop_size) != 2:
  171. raise ValueError(
  172. 'when crop_size is list or tuple, it should include 2 elements, but it is {}'
  173. .format(crop_size))
  174. elif not isinstance(crop_size, int):
  175. raise TypeError(
  176. "Type of crop_size is invalid. Must be Integer or List or tuple, now is {}"
  177. .format(type(crop_size)))
  178. super(RandomPaddingCrop, self).__init__()
  179. self.crop_size = crop_size
  180. self.im_padding_value = im_padding_value
  181. self.label_padding_value = label_padding_value
  182. def apply_im(self, image, pad_h, pad_w):
  183. im_h, im_w, im_c = image.shape
  184. orig_im = copy.deepcopy(image)
  185. image = np.zeros(
  186. (im_h + pad_h, im_w + pad_w, im_c)).astype(orig_im.dtype)
  187. for i in range(im_c):
  188. image[:, :, i] = np.pad(orig_im[:, :, i],
  189. pad_width=((0, pad_h), (0, pad_w)),
  190. mode='constant',
  191. constant_values=(self.im_padding_value[i],
  192. self.im_padding_value[i]))
  193. return image
  194. def apply_mask(self, mask, pad_h, pad_w):
  195. mask = np.pad(mask,
  196. pad_width=((0, pad_h), (0, pad_w)),
  197. mode='constant',
  198. constant_values=(self.label_padding_value,
  199. self.label_padding_value))
  200. return mask
  201. def apply(self, sample):
  202. """
  203. Args:
  204. im (np.ndarray): 图像np.ndarray数据。
  205. im_info (list): 存储图像reisze或padding前的shape信息,如
  206. [('resize', [200, 300]), ('padding', [400, 600])]表示
  207. 图像在过resize前shape为(200, 300), 过padding前shape为
  208. (400, 600)
  209. label (np.ndarray): 标注图像np.ndarray数据。
  210. Returns:
  211. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  212. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  213. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  214. """
  215. if isinstance(self.crop_size, int):
  216. crop_width = self.crop_size
  217. crop_height = self.crop_size
  218. else:
  219. crop_width = self.crop_size[0]
  220. crop_height = self.crop_size[1]
  221. im_h, im_w, im_c = sample['image'].shape
  222. if im_h == crop_height and im_w == crop_width:
  223. return sample
  224. else:
  225. pad_height = max(crop_height - im_h, 0)
  226. pad_width = max(crop_width - im_w, 0)
  227. if pad_height > 0 or pad_width > 0:
  228. sample['image'] = self.apply_im(sample['image'], pad_height,
  229. pad_width)
  230. if 'mask' in sample:
  231. sample['mask'] = self.apply_mask(sample['mask'],
  232. pad_height, pad_width)
  233. im_h = sample['image'].shape[0]
  234. im_w = sample['image'].shape[1]
  235. if crop_height > 0 and crop_width > 0:
  236. h_off = np.random.randint(im_h - crop_height + 1)
  237. w_off = np.random.randint(im_w - crop_width + 1)
  238. sample['image'] = sample['image'][h_off:(
  239. crop_height + h_off), w_off:(w_off + crop_width), :]
  240. if 'mask' in sample:
  241. sample['mask'] = sample['mask'][h_off:(
  242. crop_height + h_off), w_off:(w_off + crop_width)]
  243. return sample
  244. class RandomBlur(Transform):
  245. """以一定的概率对图像进行高斯模糊。
  246. Args:
  247. prob (float): 图像模糊概率。默认为0.1。
  248. """
  249. def __init__(self, prob=0.1):
  250. super(RandomBlur, self).__init__()
  251. self.prob = prob
  252. def apply_im(self, image, radius):
  253. image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
  254. return image
  255. def apply(self, sample):
  256. if self.prob <= 0:
  257. n = 0
  258. elif self.prob >= 1:
  259. n = 1
  260. else:
  261. n = int(1.0 / self.prob)
  262. if n > 0:
  263. if np.random.randint(0, n) == 0:
  264. radius = np.random.randint(3, 10)
  265. if radius % 2 != 1:
  266. radius = radius + 1
  267. if radius > 9:
  268. radius = 9
  269. sample['image'] = self.apply_im(sample['image'], radius)
  270. return sample
  271. class RandomRotate(Transform):
  272. """对图像进行随机旋转, 模型训练时的数据增强操作。
  273. 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
  274. 并对旋转后的图像和标注图像进行相应的padding。
  275. Args:
  276. rotate_range (float): 最大旋转角度。默认为15度。
  277. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  278. label_padding_value (int): 标注图像padding的值。默认为255。
  279. """
  280. def __init__(self,
  281. rotate_range=15,
  282. im_padding_value=[127.5, 127.5, 127.5],
  283. label_padding_value=255):
  284. super(RandomRotate, self).__init__()
  285. self.rotate_range = rotate_range
  286. self.im_padding_value = im_padding_value
  287. self.label_padding_value = label_padding_value
  288. def apply(self, sample):
  289. if self.rotate_range > 0:
  290. h, w, c = sample['image'].shape
  291. do_rotation = np.random.uniform(-self.rotate_range,
  292. self.rotate_range)
  293. pc = (w // 2, h // 2)
  294. r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
  295. cos = np.abs(r[0, 0])
  296. sin = np.abs(r[0, 1])
  297. nw = int((h * sin) + (w * cos))
  298. nh = int((h * cos) + (w * sin))
  299. (cx, cy) = pc
  300. r[0, 2] += (nw / 2) - cx
  301. r[1, 2] += (nh / 2) - cy
  302. dsize = (nw, nh)
  303. rot_ims = list()
  304. for i in range(0, c, 3):
  305. ori_im = sample['image'][:, :, i:i + 3]
  306. rot_im = cv2.warpAffine(
  307. ori_im,
  308. r,
  309. dsize=dsize,
  310. flags=cv2.INTER_LINEAR,
  311. borderMode=cv2.BORDER_CONSTANT,
  312. borderValue=self.im_padding_value[i:i + 3])
  313. rot_ims.append(rot_im)
  314. sample['image'] = np.concatenate(rot_ims, axis=-1)
  315. if 'mask' in sample:
  316. sample['mask'] = cv2.warpAffine(
  317. sample['mask'],
  318. r,
  319. dsize=dsize,
  320. flags=cv2.INTER_NEAREST,
  321. borderMode=cv2.BORDER_CONSTANT,
  322. borderValue=self.label_padding_value)
  323. return sample
  324. class RandomScaleAspect(Transform):
  325. """裁剪并resize回原始尺寸的图像和标注图像。
  326. 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
  327. Args:
  328. min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
  329. aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
  330. """
  331. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  332. super(RandomScaleAspect, self).__init__()
  333. self.min_scale = min_scale
  334. self.aspect_ratio = aspect_ratio
  335. def apply(self, sample):
  336. if self.min_scale != 0 and self.aspect_ratio != 0:
  337. img_height = sample['image'].shape[0]
  338. img_width = sample['image'].shape[1]
  339. for i in range(0, 10):
  340. area = img_height * img_width
  341. target_area = area * np.random.uniform(self.min_scale, 1.0)
  342. aspectRatio = np.random.uniform(self.aspect_ratio,
  343. 1.0 / self.aspect_ratio)
  344. dw = int(np.sqrt(target_area * 1.0 * aspectRatio))
  345. dh = int(np.sqrt(target_area * 1.0 / aspectRatio))
  346. if (np.random.randint(10) < 5):
  347. tmp = dw
  348. dw = dh
  349. dh = tmp
  350. if (dh < img_height and dw < img_width):
  351. h1 = np.random.randint(0, img_height - dh)
  352. w1 = np.random.randint(0, img_width - dw)
  353. sample['image'] = sample['image'][h1:(h1 + dh), w1:(w1 + dw
  354. ), :]
  355. sample['image'] = cv2.resize(
  356. sample['image'], (img_width, img_height),
  357. interpolation=cv2.INTER_LINEAR)
  358. if sample['image'].ndim < 3:
  359. sample['image'] = np.expand_dims(
  360. sample['image'], axis=-1)
  361. if 'mask' in sample:
  362. sample['mask'] = sample['mask'][h1:(h1 + dh), w1:(w1 +
  363. dw)]
  364. sample['mask'] = cv2.resize(
  365. sample['mask'], (img_width, img_height),
  366. interpolation=cv2.INTER_NEAREST)
  367. break
  368. return sample
  369. class Clip(Transform):
  370. """
  371. 对图像上超出一定范围的数据进行截断。
  372. Args:
  373. min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
  374. max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
  375. """
  376. def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
  377. if not (isinstance(min_val, list) and isinstance(max_val, list)):
  378. raise ValueError("{}: input type is invalid.".format(self))
  379. super(Clip, self).__init__()
  380. self.min_val = min_val
  381. self.max_val = max_val
  382. def apply_im(self, image):
  383. for k in range(image.shape[2]):
  384. np.clip(
  385. image[:, :, k],
  386. self.min_val[k],
  387. self.max_val[k],
  388. out=image[:, :, k])
  389. return image
  390. def apply(self, sample):
  391. sample['image'] = self.apply_im(sample['image'])
  392. return sample
  393. class ComposedSegTransforms(Compose):
  394. """ 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
  395. 训练阶段:
  396. 1. 随机对图像以0.5的概率水平翻转,若random_horizontal_flip为False,则跳过此步骤
  397. 2. 按不同的比例随机Resize原图, 处理方式参考[paddlex.seg.transforms.ResizeRangeScaling](#resizerangescaling)。若min_max_size为None,则跳过此步骤
  398. 3. 从原图中随机crop出大小为train_crop_size大小的子图,如若crop出来的图小于train_crop_size,则会将图padding到对应大小
  399. 4. 图像归一化
  400. 预测阶段:
  401. 1. 将图像的最长边resize至(min_max_size[0] + min_max_size[1])//2, 短边按比例resize。若min_max_size为None,则跳过此步骤
  402. 2. 图像归一化
  403. Args:
  404. mode(str): Transforms所处的阶段,包括`train', 'eval'或'test'
  405. min_max_size(list): 用于对图像进行resize,具体作用参见上述步骤。
  406. train_crop_size(list): 训练过程中随机裁剪原图用于训练,具体作用参见上述步骤。此参数仅在mode为`train`时生效。
  407. mean(list): 图像均值, 默认为[0.485, 0.456, 0.406]。
  408. std(list): 图像方差,默认为[0.229, 0.224, 0.225]。
  409. random_horizontal_flip(bool): 数据增强,是否随机水平翻转图像,此参数仅在mode为`train`时生效。
  410. """
  411. def __init__(self,
  412. mode,
  413. min_max_size=[400, 600],
  414. train_crop_size=[512, 512],
  415. mean=[0.5, 0.5, 0.5],
  416. std=[0.5, 0.5, 0.5],
  417. random_horizontal_flip=True):
  418. if mode == 'train':
  419. # 训练时的transforms,包含数据增强
  420. if min_max_size is None:
  421. transforms = [
  422. RandomPaddingCrop(crop_size=train_crop_size), Normalize(
  423. mean=mean, std=std)
  424. ]
  425. else:
  426. transforms = [
  427. ResizeRangeScaling(
  428. min_value=min(min_max_size),
  429. max_value=max(min_max_size)),
  430. RandomPaddingCrop(crop_size=train_crop_size), Normalize(
  431. mean=mean, std=std)
  432. ]
  433. if random_horizontal_flip:
  434. transforms.insert(0, RandomHorizontalFlip())
  435. else:
  436. # 验证/预测时的transforms
  437. if min_max_size is None:
  438. transforms = [Normalize(mean=mean, std=std)]
  439. else:
  440. long_size = (min(min_max_size) + max(min_max_size)) // 2
  441. transforms = [
  442. ResizeByLong(long_size=long_size), Normalize(
  443. mean=mean, std=std)
  444. ]
  445. super(ComposedSegTransforms, self).__init__(transforms)
  446. def _resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
  447. value = max(im.shape[0], im.shape[1])
  448. scale = float(long_size) / float(value)
  449. resized_width = int(round(im.shape[1] * scale))
  450. resized_height = int(round(im.shape[0] * scale))
  451. im_dims = im.ndim
  452. im = cv2.resize(
  453. im, (resized_width, resized_height), interpolation=interpolation)
  454. if im_dims >= 3 and im.ndim < 3:
  455. im = np.expand_dims(im, axis=-1)
  456. return im