seg_transforms.py 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983
  1. # coding: utf8
  2. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from .ops import *
  16. import random
  17. import os.path as osp
  18. import numpy as np
  19. from PIL import Image
  20. import cv2
  21. from collections import OrderedDict
  22. class Compose:
  23. """根据数据预处理/增强算子对输入数据进行操作。
  24. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  25. Args:
  26. transforms (list): 数据预处理/增强算子。
  27. Raises:
  28. TypeError: transforms不是list对象
  29. ValueError: transforms元素个数小于1。
  30. """
  31. def __init__(self, transforms):
  32. if not isinstance(transforms, list):
  33. raise TypeError('The transforms must be a list!')
  34. if len(transforms) < 1:
  35. raise ValueError('The length of transforms ' + \
  36. 'must be equal or larger than 1!')
  37. self.transforms = transforms
  38. self.to_rgb = False
  39. def __call__(self, im, im_info=None, label=None):
  40. """
  41. Args:
  42. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  43. im_info (list): 存储图像reisze或padding前的shape信息,如
  44. [('resize', [200, 300]), ('padding', [400, 600])]表示
  45. 图像在过resize前shape为(200, 300), 过padding前shape为
  46. (400, 600)
  47. label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
  48. Returns:
  49. tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
  50. """
  51. if im_info is None:
  52. im_info = list()
  53. try:
  54. im = cv2.imread(im).astype('float32')
  55. except:
  56. raise ValueError('Can\'t read The image file {}!'.format(im))
  57. if self.to_rgb:
  58. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  59. if label is not None:
  60. if not isinstance(label, np.ndarray):
  61. label = np.asarray(Image.open(label))
  62. for op in self.transforms:
  63. outputs = op(im, im_info, label)
  64. im = outputs[0]
  65. if len(outputs) >= 2:
  66. im_info = outputs[1]
  67. if len(outputs) == 3:
  68. label = outputs[2]
  69. return outputs
  70. class RandomHorizontalFlip:
  71. """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
  72. Args:
  73. prob (float): 随机水平翻转的概率。默认值为0.5。
  74. """
  75. def __init__(self, prob=0.5):
  76. self.prob = prob
  77. def __call__(self, im, im_info=None, label=None):
  78. """
  79. Args:
  80. im (np.ndarray): 图像np.ndarray数据。
  81. im_info (list): 存储图像reisze或padding前的shape信息,如
  82. [('resize', [200, 300]), ('padding', [400, 600])]表示
  83. 图像在过resize前shape为(200, 300), 过padding前shape为
  84. (400, 600)
  85. label (np.ndarray): 标注图像np.ndarray数据。
  86. Returns:
  87. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  88. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  89. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  90. """
  91. if random.random() < self.prob:
  92. im = horizontal_flip(im)
  93. if label is not None:
  94. label = horizontal_flip(label)
  95. if label is None:
  96. return (im, im_info)
  97. else:
  98. return (im, im_info, label)
  99. class RandomVerticalFlip:
  100. """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
  101. Args:
  102. prob (float): 随机垂直翻转的概率。默认值为0.1。
  103. """
  104. def __init__(self, prob=0.1):
  105. self.prob = prob
  106. def __call__(self, im, im_info=None, label=None):
  107. """
  108. Args:
  109. im (np.ndarray): 图像np.ndarray数据。
  110. im_info (list): 存储图像reisze或padding前的shape信息,如
  111. [('resize', [200, 300]), ('padding', [400, 600])]表示
  112. 图像在过resize前shape为(200, 300), 过padding前shape为
  113. (400, 600)
  114. label (np.ndarray): 标注图像np.ndarray数据。
  115. Returns:
  116. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  117. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  118. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  119. """
  120. if random.random() < self.prob:
  121. im = vertical_flip(im)
  122. if label is not None:
  123. label = vertical_flip(label)
  124. if label is None:
  125. return (im, im_info)
  126. else:
  127. return (im, im_info, label)
  128. class Resize:
  129. """调整图像大小(resize),当存在标注图像时,则同步进行处理。
  130. - 当目标大小(target_size)类型为int时,根据插值方式,
  131. 将图像resize为[target_size, target_size]。
  132. - 当目标大小(target_size)类型为list或tuple时,根据插值方式,
  133. 将图像resize为target_size, target_size的输入应为[w, h]或(w, h)。
  134. Args:
  135. target_size (int|list|tuple): 目标大小。
  136. interp (str): resize的插值方式,与opencv的插值方式对应,
  137. 可选的值为['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4'],默认为"LINEAR"。
  138. Raises:
  139. TypeError: target_size不是int/list/tuple。
  140. ValueError: target_size为list/tuple时元素个数不等于2。
  141. AssertionError: interp的取值不在['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4']之内。
  142. """
  143. # The interpolation mode
  144. interp_dict = {
  145. 'NEAREST': cv2.INTER_NEAREST,
  146. 'LINEAR': cv2.INTER_LINEAR,
  147. 'CUBIC': cv2.INTER_CUBIC,
  148. 'AREA': cv2.INTER_AREA,
  149. 'LANCZOS4': cv2.INTER_LANCZOS4
  150. }
  151. def __init__(self, target_size, interp='LINEAR'):
  152. self.interp = interp
  153. assert interp in self.interp_dict, "interp should be one of {}".format(
  154. interp_dict.keys())
  155. if isinstance(target_size, list) or isinstance(target_size, tuple):
  156. if len(target_size) != 2:
  157. raise ValueError(
  158. 'when target is list or tuple, it should include 2 elements, but it is {}'
  159. .format(target_size))
  160. elif not isinstance(target_size, int):
  161. raise TypeError(
  162. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  163. .format(type(target_size)))
  164. self.target_size = target_size
  165. def __call__(self, im, im_info=None, label=None):
  166. """
  167. Args:
  168. im (np.ndarray): 图像np.ndarray数据。
  169. im_info (list): 存储图像reisze或padding前的shape信息,如
  170. [('resize', [200, 300]), ('padding', [400, 600])]表示
  171. 图像在过resize前shape为(200, 300), 过padding前shape为
  172. (400, 600)
  173. label (np.ndarray): 标注图像np.ndarray数据。
  174. Returns:
  175. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  176. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  177. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  178. 其中,im_info跟新字段为:
  179. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  180. Raises:
  181. ZeroDivisionError: im的短边为0。
  182. TypeError: im不是np.ndarray数据。
  183. ValueError: im不是3维nd.ndarray。
  184. """
  185. if im_info is None:
  186. im_info = OrderedDict()
  187. im_info.append(('resize', im.shape[:2]))
  188. if not isinstance(im, np.ndarray):
  189. raise TypeError("ResizeImage: image type is not np.ndarray.")
  190. if len(im.shape) != 3:
  191. raise ValueError('ResizeImage: image is not 3-dimensional.')
  192. im_shape = im.shape
  193. im_size_min = np.min(im_shape[0:2])
  194. im_size_max = np.max(im_shape[0:2])
  195. if float(im_size_min) == 0:
  196. raise ZeroDivisionError('ResizeImage: min size of image is 0')
  197. if isinstance(self.target_size, int):
  198. resize_w = self.target_size
  199. resize_h = self.target_size
  200. else:
  201. resize_w = self.target_size[0]
  202. resize_h = self.target_size[1]
  203. im_scale_x = float(resize_w) / float(im_shape[1])
  204. im_scale_y = float(resize_h) / float(im_shape[0])
  205. im = cv2.resize(
  206. im,
  207. None,
  208. None,
  209. fx=im_scale_x,
  210. fy=im_scale_y,
  211. interpolation=self.interp_dict[self.interp])
  212. if label is not None:
  213. label = cv2.resize(
  214. label,
  215. None,
  216. None,
  217. fx=im_scale_x,
  218. fy=im_scale_y,
  219. interpolation=self.interp_dict['NEAREST'])
  220. if label is None:
  221. return (im, im_info)
  222. else:
  223. return (im, im_info, label)
  224. class ResizeByLong:
  225. """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  226. Args:
  227. long_size (int): resize后图像的长边大小。
  228. """
  229. def __init__(self, long_size):
  230. self.long_size = long_size
  231. def __call__(self, im, im_info=None, label=None):
  232. """
  233. Args:
  234. im (np.ndarray): 图像np.ndarray数据。
  235. im_info (list): 存储图像reisze或padding前的shape信息,如
  236. [('resize', [200, 300]), ('padding', [400, 600])]表示
  237. 图像在过resize前shape为(200, 300), 过padding前shape为
  238. (400, 600)
  239. label (np.ndarray): 标注图像np.ndarray数据。
  240. Returns:
  241. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  242. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  243. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  244. 其中,im_info新增字段为:
  245. -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
  246. """
  247. if im_info is None:
  248. im_info = OrderedDict()
  249. im_info.append(('resize', im.shape[:2]))
  250. im = resize_long(im, self.long_size)
  251. if label is not None:
  252. label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
  253. if label is None:
  254. return (im, im_info)
  255. else:
  256. return (im, im_info, label)
  257. class ResizeRangeScaling:
  258. """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
  259. Args:
  260. min_value (int): 图像长边resize后的最小值。默认值400。
  261. max_value (int): 图像长边resize后的最大值。默认值600。
  262. Raises:
  263. ValueError: min_value大于max_value
  264. """
  265. def __init__(self, min_value=400, max_value=600):
  266. if min_value > max_value:
  267. raise ValueError('min_value must be less than max_value, '
  268. 'but they are {} and {}.'.format(
  269. min_value, max_value))
  270. self.min_value = min_value
  271. self.max_value = max_value
  272. def __call__(self, im, im_info=None, label=None):
  273. """
  274. Args:
  275. im (np.ndarray): 图像np.ndarray数据。
  276. im_info (list): 存储图像reisze或padding前的shape信息,如
  277. [('resize', [200, 300]), ('padding', [400, 600])]表示
  278. 图像在过resize前shape为(200, 300), 过padding前shape为
  279. (400, 600)
  280. label (np.ndarray): 标注图像np.ndarray数据。
  281. Returns:
  282. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  283. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  284. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  285. """
  286. if self.min_value == self.max_value:
  287. random_size = self.max_value
  288. else:
  289. random_size = int(
  290. np.random.uniform(self.min_value, self.max_value) + 0.5)
  291. im = resize_long(im, random_size, cv2.INTER_LINEAR)
  292. if label is not None:
  293. label = resize_long(label, random_size, cv2.INTER_NEAREST)
  294. if label is None:
  295. return (im, im_info)
  296. else:
  297. return (im, im_info, label)
  298. class ResizeStepScaling:
  299. """对图像按照某一个比例resize,这个比例以scale_step_size为步长
  300. 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
  301. Args:
  302. min_scale_factor(float), resize最小尺度。默认值0.75。
  303. max_scale_factor (float), resize最大尺度。默认值1.25。
  304. scale_step_size (float), resize尺度范围间隔。默认值0.25。
  305. Raises:
  306. ValueError: min_scale_factor大于max_scale_factor
  307. """
  308. def __init__(self,
  309. min_scale_factor=0.75,
  310. max_scale_factor=1.25,
  311. scale_step_size=0.25):
  312. if min_scale_factor > max_scale_factor:
  313. raise ValueError(
  314. 'min_scale_factor must be less than max_scale_factor, '
  315. 'but they are {} and {}.'.format(min_scale_factor,
  316. max_scale_factor))
  317. self.min_scale_factor = min_scale_factor
  318. self.max_scale_factor = max_scale_factor
  319. self.scale_step_size = scale_step_size
  320. def __call__(self, im, im_info=None, label=None):
  321. """
  322. Args:
  323. im (np.ndarray): 图像np.ndarray数据。
  324. im_info (list): 存储图像reisze或padding前的shape信息,如
  325. [('resize', [200, 300]), ('padding', [400, 600])]表示
  326. 图像在过resize前shape为(200, 300), 过padding前shape为
  327. (400, 600)
  328. label (np.ndarray): 标注图像np.ndarray数据。
  329. Returns:
  330. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  331. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  332. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  333. """
  334. if self.min_scale_factor == self.max_scale_factor:
  335. scale_factor = self.min_scale_factor
  336. elif self.scale_step_size == 0:
  337. scale_factor = np.random.uniform(self.min_scale_factor,
  338. self.max_scale_factor)
  339. else:
  340. num_steps = int((self.max_scale_factor - self.min_scale_factor) /
  341. self.scale_step_size + 1)
  342. scale_factors = np.linspace(self.min_scale_factor,
  343. self.max_scale_factor,
  344. num_steps).tolist()
  345. np.random.shuffle(scale_factors)
  346. scale_factor = scale_factors[0]
  347. im = cv2.resize(
  348. im, (0, 0),
  349. fx=scale_factor,
  350. fy=scale_factor,
  351. interpolation=cv2.INTER_LINEAR)
  352. if label is not None:
  353. label = cv2.resize(
  354. label, (0, 0),
  355. fx=scale_factor,
  356. fy=scale_factor,
  357. interpolation=cv2.INTER_NEAREST)
  358. if label is None:
  359. return (im, im_info)
  360. else:
  361. return (im, im_info, label)
  362. class Normalize:
  363. """对图像进行标准化。
  364. 1.尺度缩放到 [0,1]。
  365. 2.对图像进行减均值除以标准差操作。
  366. Args:
  367. mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
  368. std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
  369. Raises:
  370. ValueError: mean或std不是list对象。std包含0。
  371. """
  372. def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
  373. self.mean = mean
  374. self.std = std
  375. if not (isinstance(self.mean, list) and isinstance(self.std, list)):
  376. raise ValueError("{}: input type is invalid.".format(self))
  377. from functools import reduce
  378. if reduce(lambda x, y: x * y, self.std) == 0:
  379. raise ValueError('{}: std is invalid!'.format(self))
  380. def __call__(self, im, im_info=None, label=None):
  381. """
  382. Args:
  383. im (np.ndarray): 图像np.ndarray数据。
  384. im_info (list): 存储图像reisze或padding前的shape信息,如
  385. [('resize', [200, 300]), ('padding', [400, 600])]表示
  386. 图像在过resize前shape为(200, 300), 过padding前shape为
  387. (400, 600)
  388. label (np.ndarray): 标注图像np.ndarray数据。
  389. Returns:
  390. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  391. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  392. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  393. """
  394. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  395. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  396. im = normalize(im, mean, std)
  397. if label is None:
  398. return (im, im_info)
  399. else:
  400. return (im, im_info, label)
  401. class Padding:
  402. """对图像或标注图像进行padding,padding方向为右和下。
  403. 根据提供的值对图像或标注图像进行padding操作。
  404. Args:
  405. target_size (int|list|tuple): padding后图像的大小。
  406. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  407. label_padding_value (int): 标注图像padding的值。默认值为255。
  408. Raises:
  409. TypeError: target_size不是int|list|tuple。
  410. ValueError: target_size为list|tuple时元素个数不等于2。
  411. """
  412. def __init__(self,
  413. target_size,
  414. im_padding_value=[127.5, 127.5, 127.5],
  415. label_padding_value=255):
  416. if isinstance(target_size, list) or isinstance(target_size, tuple):
  417. if len(target_size) != 2:
  418. raise ValueError(
  419. 'when target is list or tuple, it should include 2 elements, but it is {}'
  420. .format(target_size))
  421. elif not isinstance(target_size, int):
  422. raise TypeError(
  423. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  424. .format(type(target_size)))
  425. self.target_size = target_size
  426. self.im_padding_value = im_padding_value
  427. self.label_padding_value = label_padding_value
  428. def __call__(self, im, im_info=None, label=None):
  429. """
  430. Args:
  431. im (np.ndarray): 图像np.ndarray数据。
  432. im_info (list): 存储图像reisze或padding前的shape信息,如
  433. [('resize', [200, 300]), ('padding', [400, 600])]表示
  434. 图像在过resize前shape为(200, 300), 过padding前shape为
  435. (400, 600)
  436. label (np.ndarray): 标注图像np.ndarray数据。
  437. Returns:
  438. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  439. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  440. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  441. 其中,im_info新增字段为:
  442. -shape_before_padding (tuple): 保存padding之前图像的形状(h, w)。
  443. Raises:
  444. ValueError: 输入图像im或label的形状大于目标值
  445. """
  446. if im_info is None:
  447. im_info = OrderedDict()
  448. im_info.append(('padding', im.shape[:2]))
  449. im_height, im_width = im.shape[0], im.shape[1]
  450. if isinstance(self.target_size, int):
  451. target_height = self.target_size
  452. target_width = self.target_size
  453. else:
  454. target_height = self.target_size[1]
  455. target_width = self.target_size[0]
  456. pad_height = target_height - im_height
  457. pad_width = target_width - im_width
  458. if pad_height < 0 or pad_width < 0:
  459. raise ValueError(
  460. 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
  461. .format(im_width, im_height, target_width, target_height))
  462. else:
  463. im = cv2.copyMakeBorder(
  464. im,
  465. 0,
  466. pad_height,
  467. 0,
  468. pad_width,
  469. cv2.BORDER_CONSTANT,
  470. value=self.im_padding_value)
  471. if label is not None:
  472. label = cv2.copyMakeBorder(
  473. label,
  474. 0,
  475. pad_height,
  476. 0,
  477. pad_width,
  478. cv2.BORDER_CONSTANT,
  479. value=self.label_padding_value)
  480. if label is None:
  481. return (im, im_info)
  482. else:
  483. return (im, im_info, label)
  484. class RandomPaddingCrop:
  485. """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
  486. Args:
  487. crop_size (int|list|tuple): 裁剪图像大小。默认为512。
  488. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  489. label_padding_value (int): 标注图像padding的值。默认值为255。
  490. Raises:
  491. TypeError: crop_size不是int/list/tuple。
  492. ValueError: target_size为list/tuple时元素个数不等于2。
  493. """
  494. def __init__(self,
  495. crop_size=512,
  496. im_padding_value=[127.5, 127.5, 127.5],
  497. label_padding_value=255):
  498. if isinstance(crop_size, list) or isinstance(crop_size, tuple):
  499. if len(crop_size) != 2:
  500. raise ValueError(
  501. 'when crop_size is list or tuple, it should include 2 elements, but it is {}'
  502. .format(crop_size))
  503. elif not isinstance(crop_size, int):
  504. raise TypeError(
  505. "Type of crop_size is invalid. Must be Integer or List or tuple, now is {}"
  506. .format(type(crop_size)))
  507. self.crop_size = crop_size
  508. self.im_padding_value = im_padding_value
  509. self.label_padding_value = label_padding_value
  510. def __call__(self, im, im_info=None, label=None):
  511. """
  512. Args:
  513. im (np.ndarray): 图像np.ndarray数据。
  514. im_info (list): 存储图像reisze或padding前的shape信息,如
  515. [('resize', [200, 300]), ('padding', [400, 600])]表示
  516. 图像在过resize前shape为(200, 300), 过padding前shape为
  517. (400, 600)
  518. label (np.ndarray): 标注图像np.ndarray数据。
  519. Returns:
  520. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  521. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  522. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  523. """
  524. if isinstance(self.crop_size, int):
  525. crop_width = self.crop_size
  526. crop_height = self.crop_size
  527. else:
  528. crop_width = self.crop_size[0]
  529. crop_height = self.crop_size[1]
  530. img_height = im.shape[0]
  531. img_width = im.shape[1]
  532. if img_height == crop_height and img_width == crop_width:
  533. if label is None:
  534. return (im, im_info)
  535. else:
  536. return (im, im_info, label)
  537. else:
  538. pad_height = max(crop_height - img_height, 0)
  539. pad_width = max(crop_width - img_width, 0)
  540. if (pad_height > 0 or pad_width > 0):
  541. im = cv2.copyMakeBorder(
  542. im,
  543. 0,
  544. pad_height,
  545. 0,
  546. pad_width,
  547. cv2.BORDER_CONSTANT,
  548. value=self.im_padding_value)
  549. if label is not None:
  550. label = cv2.copyMakeBorder(
  551. label,
  552. 0,
  553. pad_height,
  554. 0,
  555. pad_width,
  556. cv2.BORDER_CONSTANT,
  557. value=self.label_padding_value)
  558. img_height = im.shape[0]
  559. img_width = im.shape[1]
  560. if crop_height > 0 and crop_width > 0:
  561. h_off = np.random.randint(img_height - crop_height + 1)
  562. w_off = np.random.randint(img_width - crop_width + 1)
  563. im = im[h_off:(crop_height + h_off), w_off:(
  564. w_off + crop_width), :]
  565. if label is not None:
  566. label = label[h_off:(crop_height + h_off), w_off:(
  567. w_off + crop_width)]
  568. if label is None:
  569. return (im, im_info)
  570. else:
  571. return (im, im_info, label)
  572. class RandomBlur:
  573. """以一定的概率对图像进行高斯模糊。
  574. Args:
  575. prob (float): 图像模糊概率。默认为0.1。
  576. """
  577. def __init__(self, prob=0.1):
  578. self.prob = prob
  579. def __call__(self, im, im_info=None, label=None):
  580. """
  581. Args:
  582. im (np.ndarray): 图像np.ndarray数据。
  583. im_info (list): 存储图像reisze或padding前的shape信息,如
  584. [('resize', [200, 300]), ('padding', [400, 600])]表示
  585. 图像在过resize前shape为(200, 300), 过padding前shape为
  586. (400, 600)
  587. label (np.ndarray): 标注图像np.ndarray数据。
  588. Returns:
  589. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  590. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  591. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  592. """
  593. if self.prob <= 0:
  594. n = 0
  595. elif self.prob >= 1:
  596. n = 1
  597. else:
  598. n = int(1.0 / self.prob)
  599. if n > 0:
  600. if np.random.randint(0, n) == 0:
  601. radius = np.random.randint(3, 10)
  602. if radius % 2 != 1:
  603. radius = radius + 1
  604. if radius > 9:
  605. radius = 9
  606. im = cv2.GaussianBlur(im, (radius, radius), 0, 0)
  607. if label is None:
  608. return (im, im_info)
  609. else:
  610. return (im, im_info, label)
  611. class RandomRotate:
  612. """对图像进行随机旋转, 模型训练时的数据增强操作。
  613. 在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
  614. 并对旋转后的图像和标注图像进行相应的padding。
  615. Args:
  616. rotate_range (float): 最大旋转角度。默认为15度。
  617. im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
  618. label_padding_value (int): 标注图像padding的值。默认为255。
  619. """
  620. def __init__(self,
  621. rotate_range=15,
  622. im_padding_value=[127.5, 127.5, 127.5],
  623. label_padding_value=255):
  624. self.rotate_range = rotate_range
  625. self.im_padding_value = im_padding_value
  626. self.label_padding_value = label_padding_value
  627. def __call__(self, im, im_info=None, label=None):
  628. """
  629. Args:
  630. im (np.ndarray): 图像np.ndarray数据。
  631. im_info (list): 存储图像reisze或padding前的shape信息,如
  632. [('resize', [200, 300]), ('padding', [400, 600])]表示
  633. 图像在过resize前shape为(200, 300), 过padding前shape为
  634. (400, 600)
  635. label (np.ndarray): 标注图像np.ndarray数据。
  636. Returns:
  637. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  638. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  639. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  640. """
  641. if self.rotate_range > 0:
  642. (h, w) = im.shape[:2]
  643. do_rotation = np.random.uniform(-self.rotate_range,
  644. self.rotate_range)
  645. pc = (w // 2, h // 2)
  646. r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
  647. cos = np.abs(r[0, 0])
  648. sin = np.abs(r[0, 1])
  649. nw = int((h * sin) + (w * cos))
  650. nh = int((h * cos) + (w * sin))
  651. (cx, cy) = pc
  652. r[0, 2] += (nw / 2) - cx
  653. r[1, 2] += (nh / 2) - cy
  654. dsize = (nw, nh)
  655. im = cv2.warpAffine(
  656. im,
  657. r,
  658. dsize=dsize,
  659. flags=cv2.INTER_LINEAR,
  660. borderMode=cv2.BORDER_CONSTANT,
  661. borderValue=self.im_padding_value)
  662. label = cv2.warpAffine(
  663. label,
  664. r,
  665. dsize=dsize,
  666. flags=cv2.INTER_NEAREST,
  667. borderMode=cv2.BORDER_CONSTANT,
  668. borderValue=self.label_padding_value)
  669. if label is None:
  670. return (im, im_info)
  671. else:
  672. return (im, im_info, label)
  673. class RandomScaleAspect:
  674. """裁剪并resize回原始尺寸的图像和标注图像。
  675. 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
  676. Args:
  677. min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
  678. aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
  679. """
  680. def __init__(self, min_scale=0.5, aspect_ratio=0.33):
  681. self.min_scale = min_scale
  682. self.aspect_ratio = aspect_ratio
  683. def __call__(self, im, im_info=None, label=None):
  684. """
  685. Args:
  686. im (np.ndarray): 图像np.ndarray数据。
  687. im_info (list): 存储图像reisze或padding前的shape信息,如
  688. [('resize', [200, 300]), ('padding', [400, 600])]表示
  689. 图像在过resize前shape为(200, 300), 过padding前shape为
  690. (400, 600)
  691. label (np.ndarray): 标注图像np.ndarray数据。
  692. Returns:
  693. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  694. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  695. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  696. """
  697. if self.min_scale != 0 and self.aspect_ratio != 0:
  698. img_height = im.shape[0]
  699. img_width = im.shape[1]
  700. for i in range(0, 10):
  701. area = img_height * img_width
  702. target_area = area * np.random.uniform(self.min_scale, 1.0)
  703. aspectRatio = np.random.uniform(self.aspect_ratio,
  704. 1.0 / self.aspect_ratio)
  705. dw = int(np.sqrt(target_area * 1.0 * aspectRatio))
  706. dh = int(np.sqrt(target_area * 1.0 / aspectRatio))
  707. if (np.random.randint(10) < 5):
  708. tmp = dw
  709. dw = dh
  710. dh = tmp
  711. if (dh < img_height and dw < img_width):
  712. h1 = np.random.randint(0, img_height - dh)
  713. w1 = np.random.randint(0, img_width - dw)
  714. im = im[h1:(h1 + dh), w1:(w1 + dw), :]
  715. label = label[h1:(h1 + dh), w1:(w1 + dw)]
  716. im = cv2.resize(
  717. im, (img_width, img_height),
  718. interpolation=cv2.INTER_LINEAR)
  719. label = cv2.resize(
  720. label, (img_width, img_height),
  721. interpolation=cv2.INTER_NEAREST)
  722. break
  723. if label is None:
  724. return (im, im_info)
  725. else:
  726. return (im, im_info, label)
  727. class RandomDistort:
  728. """对图像进行随机失真。
  729. 1. 对变换的操作顺序进行随机化操作。
  730. 2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
  731. Args:
  732. brightness_range (float): 明亮度因子的范围。默认为0.5。
  733. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  734. contrast_range (float): 对比度因子的范围。默认为0.5。
  735. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  736. saturation_range (float): 饱和度因子的范围。默认为0.5。
  737. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  738. hue_range (int): 色调因子的范围。默认为18。
  739. hue_prob (float): 随机调整色调的概率。默认为0.5。
  740. """
  741. def __init__(self,
  742. brightness_range=0.5,
  743. brightness_prob=0.5,
  744. contrast_range=0.5,
  745. contrast_prob=0.5,
  746. saturation_range=0.5,
  747. saturation_prob=0.5,
  748. hue_range=18,
  749. hue_prob=0.5):
  750. self.brightness_range = brightness_range
  751. self.brightness_prob = brightness_prob
  752. self.contrast_range = contrast_range
  753. self.contrast_prob = contrast_prob
  754. self.saturation_range = saturation_range
  755. self.saturation_prob = saturation_prob
  756. self.hue_range = hue_range
  757. self.hue_prob = hue_prob
  758. def __call__(self, im, im_info=None, label=None):
  759. """
  760. Args:
  761. im (np.ndarray): 图像np.ndarray数据。
  762. im_info (list): 存储图像reisze或padding前的shape信息,如
  763. [('resize', [200, 300]), ('padding', [400, 600])]表示
  764. 图像在过resize前shape为(200, 300), 过padding前shape为
  765. (400, 600)
  766. label (np.ndarray): 标注图像np.ndarray数据。
  767. Returns:
  768. tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  769. 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
  770. 存储与图像相关信息的字典和标注图像np.ndarray数据。
  771. """
  772. brightness_lower = 1 - self.brightness_range
  773. brightness_upper = 1 + self.brightness_range
  774. contrast_lower = 1 - self.contrast_range
  775. contrast_upper = 1 + self.contrast_range
  776. saturation_lower = 1 - self.saturation_range
  777. saturation_upper = 1 + self.saturation_range
  778. hue_lower = -self.hue_range
  779. hue_upper = self.hue_range
  780. ops = [brightness, contrast, saturation, hue]
  781. random.shuffle(ops)
  782. params_dict = {
  783. 'brightness': {
  784. 'brightness_lower': brightness_lower,
  785. 'brightness_upper': brightness_upper
  786. },
  787. 'contrast': {
  788. 'contrast_lower': contrast_lower,
  789. 'contrast_upper': contrast_upper
  790. },
  791. 'saturation': {
  792. 'saturation_lower': saturation_lower,
  793. 'saturation_upper': saturation_upper
  794. },
  795. 'hue': {
  796. 'hue_lower': hue_lower,
  797. 'hue_upper': hue_upper
  798. }
  799. }
  800. prob_dict = {
  801. 'brightness': self.brightness_prob,
  802. 'contrast': self.contrast_prob,
  803. 'saturation': self.saturation_prob,
  804. 'hue': self.hue_prob
  805. }
  806. for id in range(4):
  807. params = params_dict[ops[id].__name__]
  808. prob = prob_dict[ops[id].__name__]
  809. params['im'] = im
  810. if np.random.uniform(0, 1) < prob:
  811. im = ops[id](**params)
  812. if label is None:
  813. return (im, im_info)
  814. else:
  815. return (im, im_info, label)
  816. class ArrangeSegmenter:
  817. """获取训练/验证/预测所需的信息。
  818. Args:
  819. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  820. Raises:
  821. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内
  822. """
  823. def __init__(self, mode):
  824. if mode not in ['train', 'eval', 'test', 'quant']:
  825. raise ValueError(
  826. "mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
  827. )
  828. self.mode = mode
  829. def __call__(self, im, im_info, label=None):
  830. """
  831. Args:
  832. im (np.ndarray): 图像np.ndarray数据。
  833. im_info (list): 存储图像reisze或padding前的shape信息,如
  834. [('resize', [200, 300]), ('padding', [400, 600])]表示
  835. 图像在过resize前shape为(200, 300), 过padding前shape为
  836. (400, 600)
  837. label (np.ndarray): 标注图像np.ndarray数据。
  838. Returns:
  839. tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  840. 当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为
  841. 'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
  842. """
  843. im = permute(im, False)
  844. if self.mode == 'train' or self.mode == 'eval':
  845. label = label[np.newaxis, :, :]
  846. return (im, label)
  847. elif self.mode == 'test':
  848. return (im, im_info)
  849. else:
  850. return (im, )