seg_transforms.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import cv2
  16. import copy
  17. from .operators import Transform, Compose, RandomHorizontalFlip, RandomVerticalFlip, Resize, \
  18. ResizeByShort, Normalize, RandomDistort, ArrangeSegmenter
  19. from .operators import Padding as dy_Padding
  20. class ResizeByLong(Transform):
  21. """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  22. Args:
  23. long_size (int): resize后图像的长边大小。
  24. """
  25. def __init__(self, long_size=256):
  26. super(ResizeByLong, self).__init__()
  27. self.long_size = long_size
  28. def apply_im(self, image):
  29. image = _resize_long(image, long_size=self.long_size)
  30. return image
  31. def apply_mask(self, mask):
  32. mask = _resize_long(
  33. mask, long_size=self.long_size, interpolation=cv2.INTER_NEAREST)
  34. return mask
  35. def apply(self, sample):
  36. sample['image'] = self.apply_im(sample['image'])
  37. if 'mask' in sample:
  38. sample['mask'] = self.apply_mask(sample['mask'])
  39. return sample
  40. class ResizeRangeScaling(Transform):
  41. """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  42. Args:
  43. min_value (int): 图像长边resize后的最小值。默认值400。
  44. max_value (int): 图像长边resize后的最大值。默认值600。
  45. Raises:
  46. ValueError: min_value大于max_value
  47. """
  48. def __init__(self, min_value=400, max_value=600):
  49. super(ResizeRangeScaling, self).__init__()
  50. if min_value > max_value:
  51. raise ValueError('min_value must be less than max_value, '
  52. 'but they are {} and {}.'.format(min_value,
  53. max_value))
  54. self.min_value = min_value
  55. self.max_value = max_value
  56. def apply_im(self, image, random_size):
  57. image = _resize_long(image, long_size=random_size)
  58. return image
  59. def apply_mask(self, mask, random_size):
  60. mask = _resize_long(
  61. mask, long_size=random_size, interpolation=cv2.INTER_NEAREST)
  62. return mask
  63. def apply(self, sample):
  64. if self.min_value == self.max_value:
  65. random_size = self.max_value
  66. else:
  67. random_size = int(
  68. np.random.uniform(self.min_value, self.max_value) + 0.5)
  69. sample['image'] = self.apply_im(sample['image'], random_size)
  70. if 'mask' in sample:
  71. sample['mask'] = self.apply_mask(sample['mask'], random_size)
  72. return sample
  73. class ResizeStepScaling(Transform):
  74. """对图像按照某一个比例resize,这个比例以scale_step_size为步长
  75. 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
  76. Args:
  77. min_scale_factor(float), resize最小尺度。默认值0.75。
  78. max_scale_factor (float), resize最大尺度。默认值1.25。
  79. scale_step_size (float), resize尺度范围间隔。默认值0.25。
  80. Raises:
  81. ValueError: min_scale_factor大于max_scale_factor
  82. """
  83. def __init__(self,
  84. min_scale_factor=0.75,
  85. max_scale_factor=1.25,
  86. scale_step_size=0.25):
  87. if min_scale_factor > max_scale_factor:
  88. raise ValueError(
  89. 'min_scale_factor must be less than max_scale_factor, '
  90. 'but they are {} and {}.'.format(min_scale_factor,
  91. max_scale_factor))
  92. super(ResizeStepScaling, self).__init__()
  93. self.min_scale_factor = min_scale_factor
  94. self.max_scale_factor = max_scale_factor
  95. self.scale_step_size = scale_step_size
  96. def apply_im(self, image, scale_factor):
  97. image = cv2.resize(
  98. image, (0, 0),
  99. fx=scale_factor,
  100. fy=scale_factor,
  101. interpolation=cv2.INTER_LINEAR)
  102. if image.ndim < 3:
  103. image = np.expand_dims(image, axis=-1)
  104. return image
  105. def apply_mask(self, mask, scale_factor):
  106. mask = cv2.resize(
  107. mask, (0, 0),
  108. fx=scale_factor,
  109. fy=scale_factor,
  110. interpolation=cv2.INTER_NEAREST)
  111. return mask
  112. def apply(self, sample):
  113. if self.min_scale_factor == self.max_scale_factor:
  114. scale_factor = self.min_scale_factor
  115. elif self.scale_step_size == 0:
  116. scale_factor = np.random.uniform(self.min_scale_factor,
  117. self.max_scale_factor)
  118. else:
  119. num_steps = int((self.max_scale_factor - self.min_scale_factor) /
  120. self.scale_step_size + 1)
  121. scale_factors = np.linspace(self.min_scale_factor,
  122. self.max_scale_factor,
  123. num_steps).tolist()
  124. np.random.shuffle(scale_factors)
  125. scale_factor = scale_factors[0]
  126. sample['image'] = self.apply_im(sample['image'], scale_factor)
  127. if 'mask' in sample:
  128. sample['mask'] = self.apply_mask(sample['mask'], scale_factor)
  129. return sample
  130. class Padding(dy_Padding):
  131. """对图像或标注图像进行padding,padding方向为右和下。
  132. 根据提供的值对图像或标注图像进行padding操作。
  133. Args:
  134. target_size (int|list|tuple): padding后图像的大小。
  135. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  136. label_padding_value (int): 标注图像padding的值。默认值为255。
  137. Raises:
  138. TypeError: target_size不是int|list|tuple。
  139. ValueError: target_size为list|tuple时元素个数不等于2。
  140. """
  141. def __init__(self,
  142. target_size,
  143. im_padding_value=[127.5, 127.5, 127.5],
  144. label_padding_value=255):
  145. super(Padding, self).__init__(
  146. target_size=target_size,
  147. pad_mode=0,
  148. offsets=None,
  149. im_padding_value=im_padding_value,
  150. label_padding_value=label_padding_value)
  151. class RandomPaddingCrop(Transform):
  152. """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
  153. Args:
  154. crop_size (int|list|tuple): 裁剪图像大小。默认为512。
  155. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  156. label_padding_value (int): 标注图像padding的值。默认值为255。
  157. Raises:
  158. TypeError: crop_size不是int/list/tuple。
  159. ValueError: target_size为list/tuple时元素个数不等于2。
  160. """
  161. def __init__(self,
  162. crop_size=512,
  163. im_padding_value=[127.5, 127.5, 127.5],
  164. label_padding_value=255):
  165. if isinstance(crop_size, list) or isinstance(crop_size, tuple):
  166. if len(crop_size) != 2:
  167. raise ValueError(
  168. 'when crop_size is list or tuple, it should include 2 elements, but it is {}'
  169. .format(crop_size))
  170. elif not isinstance(crop_size, int):
  171. raise TypeError(
  172. "Type of crop_size is invalid. Must be Integer or List or tuple, now is {}"
  173. .format(type(crop_size)))
  174. super(RandomPaddingCrop, self).__init__()
  175. self.crop_size = crop_size
  176. self.im_padding_value = im_padding_value
  177. self.label_padding_value = label_padding_value
  178. def apply_im(self, image, pad_h, pad_w):
  179. im_h, im_w, im_c = image.shape
  180. orig_im = copy.deepcopy(image)
  181. image = np.zeros(
  182. (im_h + pad_h, im_w + pad_w, im_c)).astype(orig_im.dtype)
  183. for i in range(im_c):
  184. image[:, :, i] = np.pad(orig_im[:, :, i],
  185. pad_width=((0, pad_h), (0, pad_w)),
  186. mode='constant',
  187. constant_values=(self.im_padding_value[i],
  188. self.im_padding_value[i]))
  189. return image
  190. def apply_mask(self, mask, pad_h, pad_w):
  191. mask = np.pad(mask,
  192. pad_width=((0, pad_h), (0, pad_w)),
  193. mode='constant',
  194. constant_values=(self.label_padding_value,
  195. self.label_padding_value))
  196. return mask
  197. def apply(self, sample):
  198. """
  199. Args:
  200. im (np.ndarray): 图像np.ndarray数据。
  201. im_info (list): 存储图像reisze或padding前的shape信息,如
  202. [('resize', [200, 300]), ('padding', [400, 600])]表示
  203. 图像在过resize前shape为(200, 300), 过padding前shape为
  204. (400, 600)
  205. label (np.ndarray): 标注图像np.ndarray数据。
  206. Returns:
  207. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  208. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  209. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  210. """
  211. if isinstance(self.crop_size, int):
  212. crop_width = self.crop_size
  213. crop_height = self.crop_size
  214. else:
  215. crop_width = self.crop_size[0]
  216. crop_height = self.crop_size[1]
  217. im_h, im_w, im_c = sample['image'].shape
  218. if im_h == crop_height and im_w == crop_width:
  219. return sample
  220. else:
  221. pad_height = max(crop_height - im_h, 0)
  222. pad_width = max(crop_width - im_w, 0)
  223. if pad_height > 0 or pad_width > 0:
  224. sample['image'] = self.apply_im(sample['image'], pad_height,
  225. pad_width)
  226. if 'mask' in sample:
  227. sample['mask'] = self.apply_mask(sample['mask'],
  228. pad_height, pad_width)
  229. im_h = sample['image'].shape[0]
  230. im_w = sample['image'].shape[1]
  231. if crop_height > 0 and crop_width > 0:
  232. h_off = np.random.randint(im_h - crop_height + 1)
  233. w_off = np.random.randint(im_w - crop_width + 1)
  234. sample['image'] = sample['image'][h_off:(
  235. crop_height + h_off), w_off:(w_off + crop_width), :]
  236. if 'mask' in sample:
  237. sample['mask'] = sample['mask'][h_off:(
  238. crop_height + h_off), w_off:(w_off + crop_width)]
  239. return sample
  240. class RandomBlur(Transform):
  241. """以一定的概率对图像进行高斯模糊。
  242. Args:
  243. prob (float): 图像模糊概率。默认为0.1。
  244. """
  245. def __init__(self, prob=0.1):
  246. super(RandomBlur, self).__init__()
  247. self.prob = prob
  248. def apply_im(self, image, radius):
  249. image = cv2.GaussianBlur(image, (radius, radius), 0, 0)
  250. return image
  251. def apply(self, sample):
  252. if self.prob <= 0:
  253. n = 0
  254. elif self.prob >= 1:
  255. n = 1
  256. else:
  257. n = int(1.0 / self.prob)
  258. if n > 0:
  259. if np.random.randint(0, n) == 0:
  260. radius = np.random.randint(3, 10)
  261. if radius % 2 != 1:
  262. radius = radius + 1
  263. if radius > 9:
  264. radius = 9
  265. sample['image'] = self.apply_im(sample['image'], radius)
  266. return sample
  267. class RandomRotate(Transform):
  268. """对图像进行随机旋转, 模型训练时的数据增强操作。
  269. 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
  270. 并对旋转后的图像和标注图像进行相应的padding。
  271. Args:
  272. rotate_range (float): 最大旋转角度。默认为15度。
  273. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  274. label_padding_value (int): 标注图像padding的值。默认为255。
  275. """
  276. def __init__(self,
  277. rotate_range=15,
  278. im_padding_value=[127.5, 127.5, 127.5],
  279. label_padding_value=255):
  280. super(RandomRotate, self).__init__()
  281. self.rotate_range = rotate_range
  282. self.im_padding_value = im_padding_value
  283. self.label_padding_value = label_padding_value
  284. def apply(self, sample):
  285. if self.rotate_range > 0:
  286. h, w, c = sample['image'].shape
  287. do_rotation = np.random.uniform(-self.rotate_range,
  288. self.rotate_range)
  289. pc = (w // 2, h // 2)
  290. r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
  291. cos = np.abs(r[0, 0])
  292. sin = np.abs(r[0, 1])
  293. nw = int((h * sin) + (w * cos))
  294. nh = int((h * cos) + (w * sin))
  295. (cx, cy) = pc
  296. r[0, 2] += (nw / 2) - cx
  297. r[1, 2] += (nh / 2) - cy
  298. dsize = (nw, nh)
  299. rot_ims = list()
  300. for i in range(0, c, 3):
  301. ori_im = sample['image'][:, :, i:i + 3]
  302. rot_im = cv2.warpAffine(
  303. ori_im,
  304. r,
  305. dsize=dsize,
  306. flags=cv2.INTER_LINEAR,
  307. borderMode=cv2.BORDER_CONSTANT,
  308. borderValue=self.im_padding_value[i:i + 3])
  309. rot_ims.append(rot_im)
  310. sample['image'] = np.concatenate(rot_ims, axis=-1)
  311. if 'mask' in sample:
  312. sample['mask'] = cv2.warpAffine(
  313. sample['mask'],
  314. r,
  315. dsize=dsize,
  316. flags=cv2.INTER_NEAREST,
  317. borderMode=cv2.BORDER_CONSTANT,
  318. borderValue=self.label_padding_value)
  319. return sample
  320. class RandomScaleAspect(Transform):
  321. """裁剪并resize回原始尺寸的图像和标注图像。
  322. 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
  323. Args:
  324. min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
  325. aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
  326. """
  327. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  328. super(RandomScaleAspect, self).__init__()
  329. self.min_scale = min_scale
  330. self.aspect_ratio = aspect_ratio
  331. def apply(self, sample):
  332. if self.min_scale != 0 and self.aspect_ratio != 0:
  333. img_height = sample['image'].shape[0]
  334. img_width = sample['image'].shape[1]
  335. for i in range(0, 10):
  336. area = img_height * img_width
  337. target_area = area * np.random.uniform(self.min_scale, 1.0)
  338. aspectRatio = np.random.uniform(self.aspect_ratio,
  339. 1.0 / self.aspect_ratio)
  340. dw = int(np.sqrt(target_area * 1.0 * aspectRatio))
  341. dh = int(np.sqrt(target_area * 1.0 / aspectRatio))
  342. if (np.random.randint(10) < 5):
  343. tmp = dw
  344. dw = dh
  345. dh = tmp
  346. if (dh < img_height and dw < img_width):
  347. h1 = np.random.randint(0, img_height - dh)
  348. w1 = np.random.randint(0, img_width - dw)
  349. sample['image'] = sample['image'][h1:(h1 + dh), w1:(w1 + dw
  350. ), :]
  351. sample['image'] = cv2.resize(
  352. sample['image'], (img_width, img_height),
  353. interpolation=cv2.INTER_LINEAR)
  354. if sample['image'].ndim < 3:
  355. sample['image'] = np.expand_dims(
  356. sample['image'], axis=-1)
  357. if 'mask' in sample:
  358. sample['mask'] = sample['mask'][h1:(h1 + dh), w1:(w1 +
  359. dw)]
  360. sample['mask'] = cv2.resize(
  361. sample['mask'], (img_width, img_height),
  362. interpolation=cv2.INTER_NEAREST)
  363. break
  364. return sample
  365. class Clip(Transform):
  366. """
  367. 对图像上超出一定范围的数据进行截断。
  368. Args:
  369. min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
  370. max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
  371. """
  372. def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
  373. if not (isinstance(min_val, list) and isinstance(max_val, list)):
  374. raise ValueError("{}: input type is invalid.".format(self))
  375. super(Clip, self).__init__()
  376. self.min_val = min_val
  377. self.max_val = max_val
  378. def apply_im(self, image):
  379. for k in range(image.shape[2]):
  380. np.clip(
  381. image[:, :, k],
  382. self.min_val[k],
  383. self.max_val[k],
  384. out=image[:, :, k])
  385. return image
  386. def apply(self, sample):
  387. sample['image'] = self.apply_im(sample['image'])
  388. return sample
  389. class ComposedSegTransforms(Compose):
  390. """ 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
  391. 训练阶段:
  392. 1. 随机对图像以0.5的概率水平翻转,若random_horizontal_flip为False,则跳过此步骤
  393. 2. 按不同的比例随机Resize原图, 处理方式参考[paddlex.seg.transforms.ResizeRangeScaling](#resizerangescaling)。若min_max_size为None,则跳过此步骤
  394. 3. 从原图中随机crop出大小为train_crop_size大小的子图,如若crop出来的图小于train_crop_size,则会将图padding到对应大小
  395. 4. 图像归一化
  396. 预测阶段:
  397. 1. 将图像的最长边resize至(min_max_size[0] + min_max_size[1])//2, 短边按比例resize。若min_max_size为None,则跳过此步骤
  398. 2. 图像归一化
  399. Args:
  400. mode(str): Transforms所处的阶段,包括`train', 'eval'或'test'
  401. min_max_size(list): 用于对图像进行resize,具体作用参见上述步骤。
  402. train_crop_size(list): 训练过程中随机裁剪原图用于训练,具体作用参见上述步骤。此参数仅在mode为`train`时生效。
  403. mean(list): 图像均值, 默认为[0.485, 0.456, 0.406]。
  404. std(list): 图像方差,默认为[0.229, 0.224, 0.225]。
  405. random_horizontal_flip(bool): 数据增强,是否随机水平翻转图像,此参数仅在mode为`train`时生效。
  406. """
  407. def __init__(self,
  408. mode,
  409. min_max_size=[400, 600],
  410. train_crop_size=[512, 512],
  411. mean=[0.5, 0.5, 0.5],
  412. std=[0.5, 0.5, 0.5],
  413. random_horizontal_flip=True):
  414. if mode == 'train':
  415. # 训练时的transforms,包含数据增强
  416. if min_max_size is None:
  417. transforms = [
  418. RandomPaddingCrop(crop_size=train_crop_size), Normalize(
  419. mean=mean, std=std)
  420. ]
  421. else:
  422. transforms = [
  423. ResizeRangeScaling(
  424. min_value=min(min_max_size),
  425. max_value=max(min_max_size)),
  426. RandomPaddingCrop(crop_size=train_crop_size), Normalize(
  427. mean=mean, std=std)
  428. ]
  429. if random_horizontal_flip:
  430. transforms.insert(0, RandomHorizontalFlip())
  431. else:
  432. # 验证/预测时的transforms
  433. if min_max_size is None:
  434. transforms = [Normalize(mean=mean, std=std)]
  435. else:
  436. long_size = (min(min_max_size) + max(min_max_size)) // 2
  437. transforms = [
  438. ResizeByLong(long_size=long_size), Normalize(
  439. mean=mean, std=std)
  440. ]
  441. super(ComposedSegTransforms, self).__init__(transforms)
  442. def _resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
  443. value = max(im.shape[0], im.shape[1])
  444. scale = float(long_size) / float(value)
  445. resized_width = int(round(im.shape[1] * scale))
  446. resized_height = int(round(im.shape[0] * scale))
  447. im_dims = im.ndim
  448. im = cv2.resize(
  449. im, (resized_width, resized_height), interpolation=interpolation)
  450. if im_dims >= 3 and im.ndim < 3:
  451. im = np.expand_dims(im, axis=-1)
  452. return im