seg_transforms.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940
  1. # coding: utf8
  2. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from .ops import *
  16. import random
  17. import os.path as osp
  18. import numpy as np
  19. from PIL import Image
  20. import cv2
  21. from collections import OrderedDict
  22. class Compose:
  23. """根据数据预处理/增强算子对输入数据进行操作。
  24. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  25. Args:
  26. transforms (list): 数据预处理/增强算子。
  27. Raises:
  28. TypeError: transforms不是list对象
  29. ValueError: transforms元素个数小于1。
  30. """
  31. def __init__(self, transforms):
  32. if not isinstance(transforms, list):
  33. raise TypeError('The transforms must be a list!')
  34. if len(transforms) < 1:
  35. raise ValueError('The length of transforms ' + \
  36. 'must be equal or larger than 1!')
  37. self.transforms = transforms
  38. self.to_rgb = False
  39. def __call__(self, im, im_info=None, label=None):
  40. """
  41. Args:
  42. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  43. im_info (dict): 存储与图像相关的信息,dict中的字段如下:
  44. - shape_before_resize (tuple): 图像resize之前的大小(h, w)。
  45. - shape_before_padding (tuple): 图像padding之前的大小(h, w)。
  46. label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
  47. Returns:
  48. tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
  49. """
  50. if im_info is None:
  51. im_info = dict()
  52. try:
  53. im = cv2.imread(im).astype('float32')
  54. except:
  55. raise ValueError('Can\'t read The image file {}!'.format(im))
  56. if self.to_rgb:
  57. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  58. if label is not None:
  59. if not isinstance(label, np.ndarray):
  60. label = np.asarray(Image.open(label))
  61. for op in self.transforms:
  62. outputs = op(im, im_info, label)
  63. im = outputs[0]
  64. if len(outputs) >= 2:
  65. im_info = outputs[1]
  66. if len(outputs) == 3:
  67. label = outputs[2]
  68. return outputs
  69. class RandomHorizontalFlip:
  70. """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
  71. Args:
  72. prob (float): 随机水平翻转的概率。默认值为0.5。
  73. """
  74. def __init__(self, prob=0.5):
  75. self.prob = prob
  76. def __call__(self, im, im_info=None, label=None):
  77. """
  78. Args:
  79. im (np.ndarray): 图像np.ndarray数据。
  80. im_info (dict): 存储与图像相关的信息。
  81. label (np.ndarray): 标注图像np.ndarray数据。
  82. Returns:
  83. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  84. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  85. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  86. """
  87. if random.random() < self.prob:
  88. im = horizontal_flip(im)
  89. if label is not None:
  90. label = horizontal_flip(label)
  91. if label is None:
  92. return (im, im_info)
  93. else:
  94. return (im, im_info, label)
  95. class RandomVerticalFlip:
  96. """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
  97. Args:
  98. prob (float): 随机垂直翻转的概率。默认值为0.1。
  99. """
  100. def __init__(self, prob=0.1):
  101. self.prob = prob
  102. def __call__(self, im, im_info=None, label=None):
  103. """
  104. Args:
  105. im (np.ndarray): 图像np.ndarray数据。
  106. im_info (dict): 存储与图像相关的信息。
  107. label (np.ndarray): 标注图像np.ndarray数据。
  108. Returns:
  109. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  110. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  111. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  112. """
  113. if random.random() < self.prob:
  114. im = vertical_flip(im)
  115. if label is not None:
  116. label = vertical_flip(label)
  117. if label is None:
  118. return (im, im_info)
  119. else:
  120. return (im, im_info, label)
  121. class Resize:
  122. """调整图像大小(resize),当存在标注图像时,则同步进行处理。
  123. - 当目标大小(target_size)类型为int时,根据插值方式,
  124. 将图像resize为[target_size, target_size]。
  125. - 当目标大小(target_size)类型为list或tuple时,根据插值方式,
  126. 将图像resize为target_size, target_size的输入应为[w, h]或(w, h)。
  127. Args:
  128. target_size (int|list|tuple): 目标大小。
  129. interp (str): resize的插值方式,与opencv的插值方式对应,
  130. 可选的值为['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4'],默认为"LINEAR"。
  131. Raises:
  132. TypeError: target_size不是int/list/tuple。
  133. ValueError: target_size为list/tuple时元素个数不等于2。
  134. AssertionError: interp的取值不在['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4']之内。
  135. """
  136. # The interpolation mode
  137. interp_dict = {
  138. 'NEAREST': cv2.INTER_NEAREST,
  139. 'LINEAR': cv2.INTER_LINEAR,
  140. 'CUBIC': cv2.INTER_CUBIC,
  141. 'AREA': cv2.INTER_AREA,
  142. 'LANCZOS4': cv2.INTER_LANCZOS4
  143. }
  144. def __init__(self, target_size, interp='LINEAR'):
  145. self.interp = interp
  146. assert interp in self.interp_dict, "interp should be one of {}".format(
  147. interp_dict.keys())
  148. if isinstance(target_size, list) or isinstance(target_size, tuple):
  149. if len(target_size) != 2:
  150. raise ValueError(
  151. 'when target is list or tuple, it should include 2 elements, but it is {}'
  152. .format(target_size))
  153. elif not isinstance(target_size, int):
  154. raise TypeError(
  155. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  156. .format(type(target_size)))
  157. self.target_size = target_size
  158. def __call__(self, im, im_info=None, label=None):
  159. """
  160. Args:
  161. im (np.ndarray): 图像np.ndarray数据。
  162. im_info (dict): 存储与图像相关的信息。
  163. label (np.ndarray): 标注图像np.ndarray数据。
  164. Returns:
  165. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  166. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  167. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  168. 其中,im_info跟新字段为:
  169. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  170. Raises:
  171. ZeroDivisionError: im的短边为0。
  172. TypeError: im不是np.ndarray数据。
  173. ValueError: im不是3维nd.ndarray。
  174. """
  175. if im_info is None:
  176. im_info = OrderedDict()
  177. im_info['shape_before_resize'] = im.shape[:2]
  178. if not isinstance(im, np.ndarray):
  179. raise TypeError("ResizeImage: image type is not np.ndarray.")
  180. if len(im.shape) != 3:
  181. raise ValueError('ResizeImage: image is not 3-dimensional.')
  182. im_shape = im.shape
  183. im_size_min = np.min(im_shape[0:2])
  184. im_size_max = np.max(im_shape[0:2])
  185. if float(im_size_min) == 0:
  186. raise ZeroDivisionError('ResizeImage: min size of image is 0')
  187. if isinstance(self.target_size, int):
  188. resize_w = self.target_size
  189. resize_h = self.target_size
  190. else:
  191. resize_w = self.target_size[0]
  192. resize_h = self.target_size[1]
  193. im_scale_x = float(resize_w) / float(im_shape[1])
  194. im_scale_y = float(resize_h) / float(im_shape[0])
  195. im = cv2.resize(
  196. im,
  197. None,
  198. None,
  199. fx=im_scale_x,
  200. fy=im_scale_y,
  201. interpolation=self.interp_dict[self.interp])
  202. if label is not None:
  203. label = cv2.resize(
  204. label,
  205. None,
  206. None,
  207. fx=im_scale_x,
  208. fy=im_scale_y,
  209. interpolation=self.interp_dict['NEAREST'])
  210. if label is None:
  211. return (im, im_info)
  212. else:
  213. return (im, im_info, label)
  214. class ResizeByLong:
  215. """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  216. Args:
  217. long_size (int): resize后图像的长边大小。
  218. """
  219. def __init__(self, long_size):
  220. self.long_size = long_size
  221. def __call__(self, im, im_info=None, label=None):
  222. """
  223. Args:
  224. im (np.ndarray): 图像np.ndarray数据。
  225. im_info (dict): 存储与图像相关的信息。
  226. label (np.ndarray): 标注图像np.ndarray数据。
  227. Returns:
  228. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  229. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  230. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  231. 其中,im_info新增字段为:
  232. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  233. """
  234. if im_info is None:
  235. im_info = OrderedDict()
  236. im_info['shape_before_resize'] = im.shape[:2]
  237. im = resize_long(im, self.long_size)
  238. if label is not None:
  239. label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
  240. if label is None:
  241. return (im, im_info)
  242. else:
  243. return (im, im_info, label)
  244. class ResizeRangeScaling:
  245. """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  246. Args:
  247. min_value (int): 图像长边resize后的最小值。默认值400。
  248. max_value (int): 图像长边resize后的最大值。默认值600。
  249. Raises:
  250. ValueError: min_value大于max_value
  251. """
  252. def __init__(self, min_value=400, max_value=600):
  253. if min_value > max_value:
  254. raise ValueError('min_value must be less than max_value, '
  255. 'but they are {} and {}.'.format(
  256. min_value, max_value))
  257. self.min_value = min_value
  258. self.max_value = max_value
  259. def __call__(self, im, im_info=None, label=None):
  260. """
  261. Args:
  262. im (np.ndarray): 图像np.ndarray数据。
  263. im_info (dict): 存储与图像相关的信息。
  264. label (np.ndarray): 标注图像np.ndarray数据。
  265. Returns:
  266. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  267. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  268. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  269. """
  270. if self.min_value == self.max_value:
  271. random_size = self.max_value
  272. else:
  273. random_size = int(
  274. np.random.uniform(self.min_value, self.max_value) + 0.5)
  275. im = resize_long(im, random_size, cv2.INTER_LINEAR)
  276. if label is not None:
  277. label = resize_long(label, random_size, cv2.INTER_NEAREST)
  278. if label is None:
  279. return (im, im_info)
  280. else:
  281. return (im, im_info, label)
  282. class ResizeStepScaling:
  283. """对图像按照某一个比例resize,这个比例以scale_step_size为步长
  284. 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
  285. Args:
  286. min_scale_factor(float), resize最小尺度。默认值0.75。
  287. max_scale_factor (float), resize最大尺度。默认值1.25。
  288. scale_step_size (float), resize尺度范围间隔。默认值0.25。
  289. Raises:
  290. ValueError: min_scale_factor大于max_scale_factor
  291. """
  292. def __init__(self,
  293. min_scale_factor=0.75,
  294. max_scale_factor=1.25,
  295. scale_step_size=0.25):
  296. if min_scale_factor > max_scale_factor:
  297. raise ValueError(
  298. 'min_scale_factor must be less than max_scale_factor, '
  299. 'but they are {} and {}.'.format(min_scale_factor,
  300. max_scale_factor))
  301. self.min_scale_factor = min_scale_factor
  302. self.max_scale_factor = max_scale_factor
  303. self.scale_step_size = scale_step_size
  304. def __call__(self, im, im_info=None, label=None):
  305. """
  306. Args:
  307. im (np.ndarray): 图像np.ndarray数据。
  308. im_info (dict): 存储与图像相关的信息。
  309. label (np.ndarray): 标注图像np.ndarray数据。
  310. Returns:
  311. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  312. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  313. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  314. """
  315. if self.min_scale_factor == self.max_scale_factor:
  316. scale_factor = self.min_scale_factor
  317. elif self.scale_step_size == 0:
  318. scale_factor = np.random.uniform(self.min_scale_factor,
  319. self.max_scale_factor)
  320. else:
  321. num_steps = int((self.max_scale_factor - self.min_scale_factor) /
  322. self.scale_step_size + 1)
  323. scale_factors = np.linspace(self.min_scale_factor,
  324. self.max_scale_factor,
  325. num_steps).tolist()
  326. np.random.shuffle(scale_factors)
  327. scale_factor = scale_factors[0]
  328. im = cv2.resize(
  329. im, (0, 0),
  330. fx=scale_factor,
  331. fy=scale_factor,
  332. interpolation=cv2.INTER_LINEAR)
  333. if label is not None:
  334. label = cv2.resize(
  335. label, (0, 0),
  336. fx=scale_factor,
  337. fy=scale_factor,
  338. interpolation=cv2.INTER_NEAREST)
  339. if label is None:
  340. return (im, im_info)
  341. else:
  342. return (im, im_info, label)
  343. class Normalize:
  344. """对图像进行标准化。
  345. 1.尺度缩放到 [0,1]。
  346. 2.对图像进行减均值除以标准差操作。
  347. Args:
  348. mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
  349. std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
  350. Raises:
  351. ValueError: mean或std不是list对象。std包含0。
  352. """
  353. def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
  354. self.mean = mean
  355. self.std = std
  356. if not (isinstance(self.mean, list) and isinstance(self.std, list)):
  357. raise ValueError("{}: input type is invalid.".format(self))
  358. from functools import reduce
  359. if reduce(lambda x, y: x * y, self.std) == 0:
  360. raise ValueError('{}: std is invalid!'.format(self))
  361. def __call__(self, im, im_info=None, label=None):
  362. """
  363. Args:
  364. im (np.ndarray): 图像np.ndarray数据。
  365. im_info (dict): 存储与图像相关的信息。
  366. label (np.ndarray): 标注图像np.ndarray数据。
  367. Returns:
  368. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  369. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  370. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  371. """
  372. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  373. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  374. im = normalize(im, mean, std)
  375. if label is None:
  376. return (im, im_info)
  377. else:
  378. return (im, im_info, label)
  379. class Padding:
  380. """对图像或标注图像进行padding,padding方向为右和下。
  381. 根据提供的值对图像或标注图像进行padding操作。
  382. Args:
  383. target_size (int|list|tuple): padding后图像的大小。
  384. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  385. label_padding_value (int): 标注图像padding的值。默认值为255。
  386. Raises:
  387. TypeError: target_size不是int|list|tuple。
  388. ValueError: target_size为list|tuple时元素个数不等于2。
  389. """
  390. def __init__(self,
  391. target_size,
  392. im_padding_value=[127.5, 127.5, 127.5],
  393. label_padding_value=255):
  394. if isinstance(target_size, list) or isinstance(target_size, tuple):
  395. if len(target_size) != 2:
  396. raise ValueError(
  397. 'when target is list or tuple, it should include 2 elements, but it is {}'
  398. .format(target_size))
  399. elif not isinstance(target_size, int):
  400. raise TypeError(
  401. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  402. .format(type(target_size)))
  403. self.target_size = target_size
  404. self.im_padding_value = im_padding_value
  405. self.label_padding_value = label_padding_value
  406. def __call__(self, im, im_info=None, label=None):
  407. """
  408. Args:
  409. im (np.ndarray): 图像np.ndarray数据。
  410. im_info (dict): 存储与图像相关的信息。
  411. label (np.ndarray): 标注图像np.ndarray数据。
  412. Returns:
  413. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  414. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  415. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  416. 其中,im_info新增字段为:
  417. -shape_before_padding (tuple): 保存padding之前图像的形状(h, w)。
  418. Raises:
  419. ValueError: 输入图像im或label的形状大于目标值
  420. """
  421. if im_info is None:
  422. im_info = OrderedDict()
  423. im_info['shape_before_padding'] = im.shape[:2]
  424. im_height, im_width = im.shape[0], im.shape[1]
  425. if isinstance(self.target_size, int):
  426. target_height = self.target_size
  427. target_width = self.target_size
  428. else:
  429. target_height = self.target_size[1]
  430. target_width = self.target_size[0]
  431. pad_height = target_height - im_height
  432. pad_width = target_width - im_width
  433. if pad_height < 0 or pad_width < 0:
  434. raise ValueError(
  435. 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
  436. .format(im_width, im_height, target_width, target_height))
  437. else:
  438. im = cv2.copyMakeBorder(
  439. im,
  440. 0,
  441. pad_height,
  442. 0,
  443. pad_width,
  444. cv2.BORDER_CONSTANT,
  445. value=self.im_padding_value)
  446. if label is not None:
  447. label = cv2.copyMakeBorder(
  448. label,
  449. 0,
  450. pad_height,
  451. 0,
  452. pad_width,
  453. cv2.BORDER_CONSTANT,
  454. value=self.label_padding_value)
  455. if label is None:
  456. return (im, im_info)
  457. else:
  458. return (im, im_info, label)
  459. class RandomPaddingCrop:
  460. """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
  461. Args:
  462. crop_size (int|list|tuple): 裁剪图像大小。默认为512。
  463. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  464. label_padding_value (int): 标注图像padding的值。默认值为255。
  465. Raises:
  466. TypeError: crop_size不是int/list/tuple。
  467. ValueError: target_size为list/tuple时元素个数不等于2。
  468. """
  469. def __init__(self,
  470. crop_size=512,
  471. im_padding_value=[127.5, 127.5, 127.5],
  472. label_padding_value=255):
  473. if isinstance(crop_size, list) or isinstance(crop_size, tuple):
  474. if len(crop_size) != 2:
  475. raise ValueError(
  476. 'when crop_size is list or tuple, it should include 2 elements, but it is {}'
  477. .format(crop_size))
  478. elif not isinstance(crop_size, int):
  479. raise TypeError(
  480. "Type of crop_size is invalid. Must be Integer or List or tuple, now is {}"
  481. .format(type(crop_size)))
  482. self.crop_size = crop_size
  483. self.im_padding_value = im_padding_value
  484. self.label_padding_value = label_padding_value
  485. def __call__(self, im, im_info=None, label=None):
  486. """
  487. Args:
  488. im (np.ndarray): 图像np.ndarray数据。
  489. im_info (dict): 存储与图像相关的信息。
  490. label (np.ndarray): 标注图像np.ndarray数据。
  491. Returns:
  492. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  493. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  494. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  495. """
  496. if isinstance(self.crop_size, int):
  497. crop_width = self.crop_size
  498. crop_height = self.crop_size
  499. else:
  500. crop_width = self.crop_size[0]
  501. crop_height = self.crop_size[1]
  502. img_height = im.shape[0]
  503. img_width = im.shape[1]
  504. if img_height == crop_height and img_width == crop_width:
  505. if label is None:
  506. return (im, im_info)
  507. else:
  508. return (im, im_info, label)
  509. else:
  510. pad_height = max(crop_height - img_height, 0)
  511. pad_width = max(crop_width - img_width, 0)
  512. if (pad_height > 0 or pad_width > 0):
  513. im = cv2.copyMakeBorder(
  514. im,
  515. 0,
  516. pad_height,
  517. 0,
  518. pad_width,
  519. cv2.BORDER_CONSTANT,
  520. value=self.im_padding_value)
  521. if label is not None:
  522. label = cv2.copyMakeBorder(
  523. label,
  524. 0,
  525. pad_height,
  526. 0,
  527. pad_width,
  528. cv2.BORDER_CONSTANT,
  529. value=self.label_padding_value)
  530. img_height = im.shape[0]
  531. img_width = im.shape[1]
  532. if crop_height > 0 and crop_width > 0:
  533. h_off = np.random.randint(img_height - crop_height + 1)
  534. w_off = np.random.randint(img_width - crop_width + 1)
  535. im = im[h_off:(crop_height + h_off), w_off:(
  536. w_off + crop_width), :]
  537. if label is not None:
  538. label = label[h_off:(crop_height + h_off), w_off:(
  539. w_off + crop_width)]
  540. if label is None:
  541. return (im, im_info)
  542. else:
  543. return (im, im_info, label)
  544. class RandomBlur:
  545. """以一定的概率对图像进行高斯模糊。
  546. Args:
  547. prob (float): 图像模糊概率。默认为0.1。
  548. """
  549. def __init__(self, prob=0.1):
  550. self.prob = prob
  551. def __call__(self, im, im_info=None, label=None):
  552. """
  553. Args:
  554. im (np.ndarray): 图像np.ndarray数据。
  555. im_info (dict): 存储与图像相关的信息。
  556. label (np.ndarray): 标注图像np.ndarray数据。
  557. Returns:
  558. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  559. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  560. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  561. """
  562. if self.prob <= 0:
  563. n = 0
  564. elif self.prob >= 1:
  565. n = 1
  566. else:
  567. n = int(1.0 / self.prob)
  568. if n > 0:
  569. if np.random.randint(0, n) == 0:
  570. radius = np.random.randint(3, 10)
  571. if radius % 2 != 1:
  572. radius = radius + 1
  573. if radius > 9:
  574. radius = 9
  575. im = cv2.GaussianBlur(im, (radius, radius), 0, 0)
  576. if label is None:
  577. return (im, im_info)
  578. else:
  579. return (im, im_info, label)
  580. class RandomRotate:
  581. """对图像进行随机旋转, 模型训练时的数据增强操作。
  582. 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
  583. 并对旋转后的图像和标注图像进行相应的padding。
  584. Args:
  585. rotate_range (float): 最大旋转角度。默认为15度。
  586. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  587. label_padding_value (int): 标注图像padding的值。默认为255。
  588. """
  589. def __init__(self,
  590. rotate_range=15,
  591. im_padding_value=[127.5, 127.5, 127.5],
  592. label_padding_value=255):
  593. self.rotate_range = rotate_range
  594. self.im_padding_value = im_padding_value
  595. self.label_padding_value = label_padding_value
  596. def __call__(self, im, im_info=None, label=None):
  597. """
  598. Args:
  599. im (np.ndarray): 图像np.ndarray数据。
  600. im_info (dict): 存储与图像相关的信息。
  601. label (np.ndarray): 标注图像np.ndarray数据。
  602. Returns:
  603. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  604. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  605. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  606. """
  607. if self.rotate_range > 0:
  608. (h, w) = im.shape[:2]
  609. do_rotation = np.random.uniform(-self.rotate_range,
  610. self.rotate_range)
  611. pc = (w // 2, h // 2)
  612. r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
  613. cos = np.abs(r[0, 0])
  614. sin = np.abs(r[0, 1])
  615. nw = int((h * sin) + (w * cos))
  616. nh = int((h * cos) + (w * sin))
  617. (cx, cy) = pc
  618. r[0, 2] += (nw / 2) - cx
  619. r[1, 2] += (nh / 2) - cy
  620. dsize = (nw, nh)
  621. im = cv2.warpAffine(
  622. im,
  623. r,
  624. dsize=dsize,
  625. flags=cv2.INTER_LINEAR,
  626. borderMode=cv2.BORDER_CONSTANT,
  627. borderValue=self.im_padding_value)
  628. label = cv2.warpAffine(
  629. label,
  630. r,
  631. dsize=dsize,
  632. flags=cv2.INTER_NEAREST,
  633. borderMode=cv2.BORDER_CONSTANT,
  634. borderValue=self.label_padding_value)
  635. if label is None:
  636. return (im, im_info)
  637. else:
  638. return (im, im_info, label)
  639. class RandomScaleAspect:
  640. """裁剪并resize回原始尺寸的图像和标注图像。
  641. 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
  642. Args:
  643. min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
  644. aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
  645. """
  646. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  647. self.min_scale = min_scale
  648. self.aspect_ratio = aspect_ratio
  649. def __call__(self, im, im_info=None, label=None):
  650. """
  651. Args:
  652. im (np.ndarray): 图像np.ndarray数据。
  653. im_info (dict): 存储与图像相关的信息。
  654. label (np.ndarray): 标注图像np.ndarray数据。
  655. Returns:
  656. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  657. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  658. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  659. """
  660. if self.min_scale != 0 and self.aspect_ratio != 0:
  661. img_height = im.shape[0]
  662. img_width = im.shape[1]
  663. for i in range(0, 10):
  664. area = img_height * img_width
  665. target_area = area * np.random.uniform(self.min_scale, 1.0)
  666. aspectRatio = np.random.uniform(self.aspect_ratio,
  667. 1.0 / self.aspect_ratio)
  668. dw = int(np.sqrt(target_area * 1.0 * aspectRatio))
  669. dh = int(np.sqrt(target_area * 1.0 / aspectRatio))
  670. if (np.random.randint(10) < 5):
  671. tmp = dw
  672. dw = dh
  673. dh = tmp
  674. if (dh < img_height and dw < img_width):
  675. h1 = np.random.randint(0, img_height - dh)
  676. w1 = np.random.randint(0, img_width - dw)
  677. im = im[h1:(h1 + dh), w1:(w1 + dw), :]
  678. label = label[h1:(h1 + dh), w1:(w1 + dw)]
  679. im = cv2.resize(
  680. im, (img_width, img_height),
  681. interpolation=cv2.INTER_LINEAR)
  682. label = cv2.resize(
  683. label, (img_width, img_height),
  684. interpolation=cv2.INTER_NEAREST)
  685. break
  686. if label is None:
  687. return (im, im_info)
  688. else:
  689. return (im, im_info, label)
  690. class RandomDistort:
  691. """对图像进行随机失真。
  692. 1. 对变换的操作顺序进行随机化操作。
  693. 2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
  694. Args:
  695. brightness_range (float): 明亮度因子的范围。默认为0.5。
  696. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  697. contrast_range (float): 对比度因子的范围。默认为0.5。
  698. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  699. saturation_range (float): 饱和度因子的范围。默认为0.5。
  700. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  701. hue_range (int): 色调因子的范围。默认为18。
  702. hue_prob (float): 随机调整色调的概率。默认为0.5。
  703. """
  704. def __init__(self,
  705. brightness_range=0.5,
  706. brightness_prob=0.5,
  707. contrast_range=0.5,
  708. contrast_prob=0.5,
  709. saturation_range=0.5,
  710. saturation_prob=0.5,
  711. hue_range=18,
  712. hue_prob=0.5):
  713. self.brightness_range = brightness_range
  714. self.brightness_prob = brightness_prob
  715. self.contrast_range = contrast_range
  716. self.contrast_prob = contrast_prob
  717. self.saturation_range = saturation_range
  718. self.saturation_prob = saturation_prob
  719. self.hue_range = hue_range
  720. self.hue_prob = hue_prob
  721. def __call__(self, im, im_info=None, label=None):
  722. """
  723. Args:
  724. im (np.ndarray): 图像np.ndarray数据。
  725. im_info (dict): 存储与图像相关的信息。
  726. label (np.ndarray): 标注图像np.ndarray数据。
  727. Returns:
  728. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  729. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  730. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  731. """
  732. brightness_lower = 1 - self.brightness_range
  733. brightness_upper = 1 + self.brightness_range
  734. contrast_lower = 1 - self.contrast_range
  735. contrast_upper = 1 + self.contrast_range
  736. saturation_lower = 1 - self.saturation_range
  737. saturation_upper = 1 + self.saturation_range
  738. hue_lower = -self.hue_range
  739. hue_upper = self.hue_range
  740. ops = [brightness, contrast, saturation, hue]
  741. random.shuffle(ops)
  742. params_dict = {
  743. 'brightness': {
  744. 'brightness_lower': brightness_lower,
  745. 'brightness_upper': brightness_upper
  746. },
  747. 'contrast': {
  748. 'contrast_lower': contrast_lower,
  749. 'contrast_upper': contrast_upper
  750. },
  751. 'saturation': {
  752. 'saturation_lower': saturation_lower,
  753. 'saturation_upper': saturation_upper
  754. },
  755. 'hue': {
  756. 'hue_lower': hue_lower,
  757. 'hue_upper': hue_upper
  758. }
  759. }
  760. prob_dict = {
  761. 'brightness': self.brightness_prob,
  762. 'contrast': self.contrast_prob,
  763. 'saturation': self.saturation_prob,
  764. 'hue': self.hue_prob
  765. }
  766. for id in range(4):
  767. params = params_dict[ops[id].__name__]
  768. prob = prob_dict[ops[id].__name__]
  769. params['im'] = im
  770. if np.random.uniform(0, 1) < prob:
  771. im = ops[id](**params)
  772. if label is None:
  773. return (im, im_info)
  774. else:
  775. return (im, im_info, label)
  776. class ArrangeSegmenter:
  777. """获取训练/验证/预测所需的信息。
  778. Args:
  779. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  780. Raises:
  781. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内
  782. """
  783. def __init__(self, mode):
  784. if mode not in ['train', 'eval', 'test', 'quant']:
  785. raise ValueError(
  786. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  787. )
  788. self.mode = mode
  789. def __call__(self, im, im_info, label=None):
  790. """
  791. Args:
  792. im (np.ndarray): 图像np.ndarray数据。
  793. im_info (dict): 存储与图像相关的信息。
  794. label (np.ndarray): 标注图像np.ndarray数据。
  795. Returns:
  796. tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  797. 当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为
  798. 'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
  799. """
  800. im = permute(im, False)
  801. if self.mode == 'train' or self.mode == 'eval':
  802. label = label[np.newaxis, :, :]
  803. return (im, label)
  804. elif self.mode == 'test':
  805. return (im, im_info)
  806. else:
  807. return (im, )