transforms.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import os.path as osp
  3. import imghdr
  4. import gdal
  5. import numpy as np
  6. from PIL import Image
  7. from paddlex.seg import transforms
  8. def read_img(img_path):
  9. img_format = imghdr.what(img_path)
  10. name, ext = osp.splitext(img_path)
  11. if img_format == 'tiff' or ext == '.img':
  12. dataset = gdal.Open(img_path)
  13. if dataset == None:
  14. raise Exception('Can not open', img_path)
  15. im_data = dataset.ReadAsArray()
  16. return im_data.transpose((1, 2, 0))
  17. elif img_format == 'png':
  18. return np.asarray(Image.open(img_path))
  19. elif ext == '.npy':
  20. return np.load(img_path)
  21. else:
  22. raise Exception('Image format {} is not supported!'.format(ext))
  23. def decode_image(im, label):
  24. if isinstance(im, np.ndarray):
  25. if len(im.shape) != 3:
  26. raise Exception(
  27. "im should be 3-dimensions, but now is {}-dimensions".format(
  28. len(im.shape)))
  29. else:
  30. try:
  31. im = read_img(im)
  32. except:
  33. raise ValueError('Can\'t read The image file {}!'.format(im))
  34. if label is not None:
  35. if not isinstance(label, np.ndarray):
  36. label = read_img(label)
  37. return (im, label)
  38. class Clip(transforms.SegTransform):
  39. """
  40. 对图像上超出一定范围的数据进行裁剪。
  41. Args:
  42. min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
  43. max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
  44. """
  45. def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
  46. self.min_val = min_val
  47. self.max_val = max_val
  48. if not (isinstance(self.min_val, list) and isinstance(self.max_val,
  49. list)):
  50. raise ValueError("{}: input type is invalid.".format(self))
  51. def __call__(self, im, im_info=None, label=None):
  52. for k in range(im.shape[2]):
  53. np.clip(
  54. im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
  55. if label is None:
  56. return (im, im_info)
  57. else:
  58. return (im, im_info, label)