# seg_transforms.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
function:
    transforms for segmentation in PaddleX<2.0
"""
import numpy as np
import cv2
import copy

from .operators import Transform, Compose, RandomHorizontalFlip, RandomVerticalFlip, Resize, \
    ResizeByShort, Normalize, RandomDistort, ArrangeSegmenter
from .operators import Padding as dy_Padding

__all__ = [
    'Compose', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Resize',
    'ResizeByShort', 'Normalize', 'RandomDistort', 'ArrangeSegmenter',
    'ResizeByLong', 'ResizeRangeScaling', 'ResizeStepScaling', 'Padding',
    'RandomPaddingCrop', 'RandomBlur', 'RandomRotate', 'RandomScaleAspect',
    'Clip', 'ComposedSegTransforms'
]


class ResizeByLong(Transform):
    """Resize the long side of the image to a fixed size and scale the short side
    proportionally. If an annotation mask is present, it is processed in the same way.

    Args:
        long_size (int): Target size of the long side after resizing.
    """

    def __init__(self, long_size=256):
        super(ResizeByLong, self).__init__()
        self.long_size = long_size

    def apply_im(self, image):
        image = _resize_long(image, long_size=self.long_size)
        return image

    def apply_mask(self, mask):
        mask = _resize_long(
            mask, long_size=self.long_size, interpolation=cv2.INTER_NEAREST)
        return mask

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'])
        return sample
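

# Illustrative usage sketch (not part of the original module): `_example_resize_by_long`
# is a hypothetical helper showing the sample-dict interface shared by the transforms in
# this file, assuming an HWC image array and a 2-D label mask.
def _example_resize_by_long():
    image = np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8)
    mask = np.zeros((300, 500), dtype=np.uint8)
    sample = ResizeByLong(long_size=512).apply({'image': image, 'mask': mask})
    # The long side (500) is resized to 512 and the short side is scaled proportionally;
    # the mask uses nearest-neighbour interpolation so label values are preserved.
    assert max(sample['image'].shape[:2]) == 512
    assert sample['mask'].shape == sample['image'].shape[:2]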


class ResizeRangeScaling(Transform):
    """Resize the long side of the image to a random value within the given range and
    scale the short side proportionally. If an annotation mask is present, it is
    processed in the same way.

    Args:
        min_value (int): Minimum size of the long side after resizing. Defaults to 400.
        max_value (int): Maximum size of the long side after resizing. Defaults to 600.

    Raises:
        ValueError: If min_value is greater than max_value.
    """

    def __init__(self, min_value=400, max_value=600):
        super(ResizeRangeScaling, self).__init__()
        if min_value > max_value:
            raise ValueError('min_value must be less than max_value, '
                             'but they are {} and {}.'.format(min_value,
                                                              max_value))
        self.min_value = min_value
        self.max_value = max_value

    def apply_im(self, image, random_size):
        image = _resize_long(image, long_size=random_size)
        return image

    def apply_mask(self, mask, random_size):
        mask = _resize_long(
            mask, long_size=random_size, interpolation=cv2.INTER_NEAREST)
        return mask

    def apply(self, sample):
        if self.min_value == self.max_value:
            random_size = self.max_value
        else:
            random_size = int(
                np.random.uniform(self.min_value, self.max_value) + 0.5)
        sample['image'] = self.apply_im(sample['image'], random_size)
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], random_size)
        return sample
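

# Illustrative sketch (not part of the original module): `_example_resize_range_scaling`
# is a hypothetical helper showing that the target long-side length is drawn uniformly
# from [min_value, max_value] and rounded to the nearest integer.
def _example_resize_range_scaling():
    image = np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)
    sample = ResizeRangeScaling(min_value=400, max_value=600).apply({'image': image})
    # int(np.random.uniform(400, 600) + 0.5) always lands in [400, 600].
    assert 400 <= max(sample['image'].shape[:2]) <= 600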


class ResizeStepScaling(Transform):
    """Resize the image by a scale factor sampled from
    [min_scale_factor, max_scale_factor] with a step of scale_step_size. If an
    annotation mask is present, it is processed in the same way.

    Args:
        min_scale_factor (float): Minimum scale factor. Defaults to 0.75.
        max_scale_factor (float): Maximum scale factor. Defaults to 1.25.
        scale_step_size (float): Step size of the scale range. Defaults to 0.25.

    Raises:
        ValueError: If min_scale_factor is greater than max_scale_factor.
    """

    def __init__(self,
                 min_scale_factor=0.75,
                 max_scale_factor=1.25,
                 scale_step_size=0.25):
        if min_scale_factor > max_scale_factor:
            raise ValueError(
                'min_scale_factor must be less than max_scale_factor, '
                'but they are {} and {}.'.format(min_scale_factor,
                                                 max_scale_factor))
        super(ResizeStepScaling, self).__init__()
        self.min_scale_factor = min_scale_factor
        self.max_scale_factor = max_scale_factor
        self.scale_step_size = scale_step_size

    def apply_im(self, image, scale_factor):
        image = cv2.resize(
            image, (0, 0),
            fx=scale_factor,
            fy=scale_factor,
            interpolation=cv2.INTER_LINEAR)
        if image.ndim < 3:
            image = np.expand_dims(image, axis=-1)
        return image

    def apply_mask(self, mask, scale_factor):
        mask = cv2.resize(
            mask, (0, 0),
            fx=scale_factor,
            fy=scale_factor,
            interpolation=cv2.INTER_NEAREST)
        return mask

    def apply(self, sample):
        if self.min_scale_factor == self.max_scale_factor:
            scale_factor = self.min_scale_factor
        elif self.scale_step_size == 0:
            scale_factor = np.random.uniform(self.min_scale_factor,
                                             self.max_scale_factor)
        else:
            num_steps = int((self.max_scale_factor - self.min_scale_factor) /
                            self.scale_step_size + 1)
            scale_factors = np.linspace(self.min_scale_factor,
                                        self.max_scale_factor,
                                        num_steps).tolist()
            np.random.shuffle(scale_factors)
            scale_factor = scale_factors[0]
        sample['image'] = self.apply_im(sample['image'], scale_factor)
        if 'mask' in sample:
            sample['mask'] = self.apply_mask(sample['mask'], scale_factor)
        return sample
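

# Illustrative sketch (not part of the original module): `_example_resize_step_scaling`
# is a hypothetical helper showing how the candidate scale factors are derived. With the
# defaults, np.linspace(0.75, 1.25, 3) yields [0.75, 1.0, 1.25] and one factor is drawn
# per sample.
def _example_resize_step_scaling():
    transform = ResizeStepScaling(
        min_scale_factor=0.75, max_scale_factor=1.25, scale_step_size=0.25)
    num_steps = int((1.25 - 0.75) / 0.25 + 1)  # -> 3 candidate factors
    print(np.linspace(0.75, 1.25, num_steps).tolist())  # [0.75, 1.0, 1.25]
    sample = transform.apply(
        {'image': np.random.randint(0, 256, (120, 160, 3), dtype=np.uint8)})
    print(sample['image'].shape)  # one of (90, 120, 3), (120, 160, 3), (150, 200, 3)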


class Padding(dy_Padding):
    """Pad the image (and the annotation mask, if present) on the right and bottom to
    the given target size, using the provided padding values.

    Args:
        target_size (int|list|tuple): Image size after padding.
        im_padding_value (list): Padding value for the image. Defaults to [127.5, 127.5, 127.5].
        label_padding_value (int): Padding value for the annotation mask. Defaults to 255.

    Raises:
        TypeError: If target_size is not an int, list, or tuple.
        ValueError: If target_size is a list or tuple whose length is not 2.
    """

    def __init__(self,
                 target_size,
                 im_padding_value=[127.5, 127.5, 127.5],
                 label_padding_value=255):
        super(Padding, self).__init__(
            target_size=target_size,
            pad_mode=0,
            offsets=None,
            im_padding_value=im_padding_value,
            label_padding_value=label_padding_value)


class RandomPaddingCrop(Transform):
    """Randomly crop the image and the annotation mask. If the required crop size is
    larger than the original image, the image is padded first.

    Args:
        crop_size (int|list|tuple): Crop size. Defaults to 512.
        im_padding_value (list): Padding value for the image. Defaults to [127.5, 127.5, 127.5].
        label_padding_value (int): Padding value for the annotation mask. Defaults to 255.

    Raises:
        TypeError: If crop_size is not an int, list, or tuple.
        ValueError: If crop_size is a list or tuple whose length is not 2.
    """

    def __init__(self,
                 crop_size=512,
                 im_padding_value=[127.5, 127.5, 127.5],
                 label_padding_value=255):
        if isinstance(crop_size, list) or isinstance(crop_size, tuple):
            if len(crop_size) != 2:
                raise ValueError(
                    'when crop_size is list or tuple, it should include 2 elements, but it is {}'
                    .format(crop_size))
        elif not isinstance(crop_size, int):
            raise TypeError(
                "Type of crop_size is invalid. Must be Integer or List or tuple, now is {}"
                .format(type(crop_size)))
        super(RandomPaddingCrop, self).__init__()
        self.crop_size = crop_size
        self.im_padding_value = im_padding_value
        self.label_padding_value = label_padding_value

    def apply_im(self, image, pad_h, pad_w):
        im_h, im_w, im_c = image.shape
        orig_im = copy.deepcopy(image)
        image = np.zeros(
            (im_h + pad_h, im_w + pad_w, im_c)).astype(orig_im.dtype)
        for i in range(im_c):
            image[:, :, i] = np.pad(orig_im[:, :, i],
                                    pad_width=((0, pad_h), (0, pad_w)),
                                    mode='constant',
                                    constant_values=(self.im_padding_value[i],
                                                     self.im_padding_value[i]))
        return image

    def apply_mask(self, mask, pad_h, pad_w):
        mask = np.pad(mask,
                      pad_width=((0, pad_h), (0, pad_w)),
                      mode='constant',
                      constant_values=(self.label_padding_value,
                                       self.label_padding_value))
        return mask

    def apply(self, sample):
        """
        Args:
            sample (dict): Contains 'image' (np.ndarray) and, optionally, 'mask'
                (np.ndarray). If the image is smaller than the crop size, it is padded
                on the right and bottom first, then a random crop of crop_size is taken.

        Returns:
            dict: The sample with 'image' (and 'mask', if present) cropped to crop_size.
        """
        if isinstance(self.crop_size, int):
            crop_width = self.crop_size
            crop_height = self.crop_size
        else:
            crop_width = self.crop_size[0]
            crop_height = self.crop_size[1]
        im_h, im_w, im_c = sample['image'].shape
        if im_h == crop_height and im_w == crop_width:
            return sample
        else:
            pad_height = max(crop_height - im_h, 0)
            pad_width = max(crop_width - im_w, 0)
            if pad_height > 0 or pad_width > 0:
                sample['image'] = self.apply_im(sample['image'], pad_height,
                                                pad_width)
                if 'mask' in sample:
                    sample['mask'] = self.apply_mask(sample['mask'],
                                                     pad_height, pad_width)
                im_h = sample['image'].shape[0]
                im_w = sample['image'].shape[1]
            if crop_height > 0 and crop_width > 0:
                h_off = np.random.randint(im_h - crop_height + 1)
                w_off = np.random.randint(im_w - crop_width + 1)
                sample['image'] = sample['image'][h_off:(
                    crop_height + h_off), w_off:(w_off + crop_width), :]
                if 'mask' in sample:
                    sample['mask'] = sample['mask'][h_off:(
                        crop_height + h_off), w_off:(w_off + crop_width)]
        return sample
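

# Illustrative sketch (not part of the original module): `_example_random_padding_crop`
# is a hypothetical helper showing that an input smaller than crop_size is padded on the
# bottom/right before the random crop is taken.
def _example_random_padding_crop():
    sample = {
        'image': np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8),
        'mask': np.zeros((300, 400), dtype=np.uint8),
    }
    out = RandomPaddingCrop(crop_size=512).apply(sample)
    # 300x400 -> padded to 512x512 (image with 127.5, mask with 255), then cropped.
    assert out['image'].shape[:2] == (512, 512)
    assert out['mask'].shape == (512, 512)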


class RandomBlur(Transform):
    """Apply Gaussian blur to the image with a given probability.

    Args:
        prob (float): Probability of blurring the image. Defaults to 0.1.
    """

    def __init__(self, prob=0.1):
        super(RandomBlur, self).__init__()
        self.prob = prob

    def apply_im(self, image, radius):
        image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
        return image

    def apply(self, sample):
        if self.prob <= 0:
            n = 0
        elif self.prob >= 1:
            n = 1
        else:
            n = int(1.0 / self.prob)
        if n > 0:
            if np.random.randint(0, n) == 0:
                radius = np.random.randint(3, 10)
                if radius % 2 != 1:
                    radius = radius + 1
                if radius > 9:
                    radius = 9
                sample['image'] = self.apply_im(sample['image'], radius)
        return sample
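

# Illustrative sketch (not part of the original module): `_example_random_blur` is a
# hypothetical helper showing that the blur fires roughly once every int(1.0 / prob)
# samples, with an odd Gaussian kernel size between 3 and 9.
def _example_random_blur():
    blur = RandomBlur(prob=0.5)  # n = int(1.0 / 0.5) = 2 -> about half the samples
    image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    blurred = sum(
        not np.array_equal(blur.apply({'image': image.copy()})['image'], image)
        for _ in range(100))
    print('blurred samples out of 100:', blurred)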


class RandomRotate(Transform):
    """Randomly rotate the image; a data augmentation operation for model training.

    The image is rotated by a random angle in [-rotate_range, rotate_range]. If an
    annotation mask is present, it is rotated in the same way, and both the image and
    the mask are padded accordingly.

    Args:
        rotate_range (float): Maximum rotation angle in degrees. Defaults to 15.
        im_padding_value (list): Padding value for the image. Defaults to [127.5, 127.5, 127.5].
        label_padding_value (int): Padding value for the annotation mask. Defaults to 255.
    """

    def __init__(self,
                 rotate_range=15,
                 im_padding_value=[127.5, 127.5, 127.5],
                 label_padding_value=255):
        super(RandomRotate, self).__init__()
        self.rotate_range = rotate_range
        self.im_padding_value = im_padding_value
        self.label_padding_value = label_padding_value

    def apply(self, sample):
        if self.rotate_range > 0:
            h, w, c = sample['image'].shape
            do_rotation = np.random.uniform(-self.rotate_range,
                                            self.rotate_range)
            pc = (w // 2, h // 2)
            r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
            cos = np.abs(r[0, 0])
            sin = np.abs(r[0, 1])
            nw = int((h * sin) + (w * cos))
            nh = int((h * cos) + (w * sin))
            (cx, cy) = pc
            r[0, 2] += (nw / 2) - cx
            r[1, 2] += (nh / 2) - cy
            dsize = (nw, nh)
            rot_ims = list()
            for i in range(0, c, 3):
                ori_im = sample['image'][:, :, i:i + 3]
                rot_im = cv2.warpAffine(
                    ori_im,
                    r,
                    dsize=dsize,
                    flags=cv2.INTER_LINEAR,
                    borderMode=cv2.BORDER_CONSTANT,
                    borderValue=self.im_padding_value[i:i + 3])
                rot_ims.append(rot_im)
            sample['image'] = np.concatenate(rot_ims, axis=-1)
            if 'mask' in sample:
                sample['mask'] = cv2.warpAffine(
                    sample['mask'],
                    r,
                    dsize=dsize,
                    flags=cv2.INTER_NEAREST,
                    borderMode=cv2.BORDER_CONSTANT,
                    borderValue=self.label_padding_value)
        return sample
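

# Illustrative sketch (not part of the original module): `_example_random_rotate` is a
# hypothetical helper showing that the output canvas is resized so the rotated image
# fits, with the border filled by im_padding_value / label_padding_value.
def _example_random_rotate():
    sample = {
        'image': np.full((100, 200, 3), 64, dtype=np.uint8),
        'mask': np.zeros((100, 200), dtype=np.uint8),
    }
    out = RandomRotate(rotate_range=30).apply(sample)
    # The new size is (w*cos + h*sin, h*cos + w*sin) for the sampled angle.
    print(out['image'].shape, out['mask'].shape)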


class RandomScaleAspect(Transform):
    """Crop the image (and the annotation mask, if present) and resize it back to its
    original size.

    A sub-region is cropped according to a random area ratio and aspect ratio, then
    resized back to the original image size.

    Args:
        min_scale (float): Minimum area ratio of the cropped region to the original
            image, in [0, 1]. If 0, the original image is returned. Defaults to 0.5.
        aspect_ratio (float): Lower bound of the aspect ratio range of the cropped
            region, non-negative. If 0, the original image is returned. Defaults to 0.33.
    """

    def __init__(self, min_scale=0.5, aspect_ratio=0.33):
        super(RandomScaleAspect, self).__init__()
        self.min_scale = min_scale
        self.aspect_ratio = aspect_ratio

    def apply(self, sample):
        if self.min_scale != 0 and self.aspect_ratio != 0:
            img_height = sample['image'].shape[0]
            img_width = sample['image'].shape[1]
            for i in range(0, 10):
                area = img_height * img_width
                target_area = area * np.random.uniform(self.min_scale, 1.0)
                aspectRatio = np.random.uniform(self.aspect_ratio,
                                                1.0 / self.aspect_ratio)
                dw = int(np.sqrt(target_area * 1.0 * aspectRatio))
                dh = int(np.sqrt(target_area * 1.0 / aspectRatio))
                if np.random.randint(10) < 5:
                    tmp = dw
                    dw = dh
                    dh = tmp
                if dh < img_height and dw < img_width:
                    h1 = np.random.randint(0, img_height - dh)
                    w1 = np.random.randint(0, img_width - dw)
                    sample['image'] = sample['image'][h1:(h1 + dh), w1:(w1 + dw), :]
                    sample['image'] = cv2.resize(
                        sample['image'], (img_width, img_height),
                        interpolation=cv2.INTER_LINEAR)
                    if sample['image'].ndim < 3:
                        sample['image'] = np.expand_dims(
                            sample['image'], axis=-1)
                    if 'mask' in sample:
                        sample['mask'] = sample['mask'][h1:(h1 + dh), w1:(w1 + dw)]
                        sample['mask'] = cv2.resize(
                            sample['mask'], (img_width, img_height),
                            interpolation=cv2.INTER_NEAREST)
                    break
        return sample
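

# Illustrative sketch (not part of the original module): `_example_random_scale_aspect`
# is a hypothetical helper walking through the crop-size arithmetic: for a target area
# ratio s and aspect ratio a, the crop is dw = sqrt(area*s*a) by dh = sqrt(area*s/a),
# then resized back to the original size.
def _example_random_scale_aspect():
    sample = {'image': np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)}
    out = RandomScaleAspect(min_scale=0.5, aspect_ratio=0.33).apply(sample)
    # The crop is resized back, so the output keeps the original 240x320 size.
    print(out['image'].shape)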


class Clip(Transform):
    """Clip pixel values that fall outside the given range.

    Args:
        min_val (list): Lower bound per channel; values below min_val are set to
            min_val. Defaults to [0, 0, 0].
        max_val (list): Upper bound per channel; values above max_val are set to
            max_val. Defaults to [255.0, 255.0, 255.0].
    """

    def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
        if not (isinstance(min_val, list) and isinstance(max_val, list)):
            raise ValueError("{}: input type is invalid.".format(self))
        super(Clip, self).__init__()
        self.min_val = min_val
        self.max_val = max_val

    def apply_im(self, image):
        for k in range(image.shape[2]):
            np.clip(
                image[:, :, k],
                self.min_val[k],
                self.max_val[k],
                out=image[:, :, k])
        return image

    def apply(self, sample):
        sample['image'] = self.apply_im(sample['image'])
        return sample
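

# Illustrative sketch (not part of the original module): `_example_clip` is a
# hypothetical helper showing the per-channel clipping.
def _example_clip():
    image = np.array([[[0.0, 128.0, 255.0]]], dtype=np.float32)
    out = Clip(min_val=[10, 10, 10], max_val=[200, 200, 200]).apply({'image': image})
    print(out['image'])  # [[[ 10. 128. 200.]]]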


class ComposedSegTransforms(Compose):
    """Image processing pipeline for semantic segmentation models (UNet/DeepLabv3p).

    Training:
        1. Randomly flip the image horizontally with probability 0.5; skipped if
           random_horizontal_flip is False.
        2. Randomly resize the image, see
           [paddlex.seg.transforms.ResizeRangeScaling](#resizerangescaling); skipped if
           min_max_size is None.
        3. Randomly crop a sub-image of size train_crop_size; if the image is smaller
           than train_crop_size, it is padded to that size.
        4. Normalize the image.

    Prediction:
        1. Resize the long side of the image to (min_max_size[0] + min_max_size[1]) // 2
           and scale the short side proportionally; skipped if min_max_size is None.
        2. Normalize the image.

    Args:
        mode (str): Stage of the transforms, one of 'train', 'eval', or 'test'.
        min_max_size (list): Range used for resizing; see the steps above.
        train_crop_size (list): Crop size used to randomly crop the image during
            training; see the steps above. Only takes effect when mode is 'train'.
        mean (list): Image mean. Defaults to [0.5, 0.5, 0.5].
        std (list): Image standard deviation. Defaults to [0.5, 0.5, 0.5].
        random_horizontal_flip (bool): Whether to randomly flip the image horizontally
            as data augmentation. Only takes effect when mode is 'train'.
    """

    def __init__(self,
                 mode,
                 min_max_size=[400, 600],
                 train_crop_size=[512, 512],
                 mean=[0.5, 0.5, 0.5],
                 std=[0.5, 0.5, 0.5],
                 random_horizontal_flip=True):
        if mode == 'train':
            # Transforms used for training, including data augmentation
            if min_max_size is None:
                transforms = [
                    RandomPaddingCrop(crop_size=train_crop_size), Normalize(
                        mean=mean, std=std)
                ]
            else:
                transforms = [
                    ResizeRangeScaling(
                        min_value=min(min_max_size),
                        max_value=max(min_max_size)),
                    RandomPaddingCrop(crop_size=train_crop_size), Normalize(
                        mean=mean, std=std)
                ]
            if random_horizontal_flip:
                transforms.insert(0, RandomHorizontalFlip())
        else:
            # Transforms used for evaluation/prediction
            if min_max_size is None:
                transforms = [Normalize(mean=mean, std=std)]
            else:
                long_size = (min(min_max_size) + max(min_max_size)) // 2
                transforms = [
                    ResizeByLong(long_size=long_size), Normalize(
                        mean=mean, std=std)
                ]
        super(ComposedSegTransforms, self).__init__(transforms)
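

# Illustrative sketch (not part of the original module): `_example_composed_seg_transforms`
# is a hypothetical helper showing how the train and eval pipelines are assembled,
# assuming the operators imported above (Compose, Normalize, etc.) behave as in PaddleX.
def _example_composed_seg_transforms():
    # Train: RandomHorizontalFlip -> ResizeRangeScaling(400, 600)
    #        -> RandomPaddingCrop(512) -> Normalize
    train_transforms = ComposedSegTransforms(
        mode='train', min_max_size=[400, 600], train_crop_size=[512, 512])
    # Eval/test: ResizeByLong((400 + 600) // 2 = 500) -> Normalize
    eval_transforms = ComposedSegTransforms(mode='eval', min_max_size=[400, 600])
    return train_transforms, eval_transforms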


def _resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
    value = max(im.shape[0], im.shape[1])
    scale = float(long_size) / float(value)
    resized_width = int(round(im.shape[1] * scale))
    resized_height = int(round(im.shape[0] * scale))
    im_dims = im.ndim
    im = cv2.resize(
        im, (resized_width, resized_height), interpolation=interpolation)
    if im_dims >= 3 and im.ndim < 3:
        im = np.expand_dims(im, axis=-1)
    return im