operators.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# function:
#    operators to process sample,
#    eg: decode/resize/crop image

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

from numbers import Number, Integral

import uuid
import random
import math
import numpy as np
import os
import copy
import cv2

from PIL import Image, ImageEnhance, ImageDraw

from paddlex.ppdet.core.workspace import serializable
from paddlex.ppdet.modeling.layers import AnchorGrid
from paddlex.ppdet.modeling import bbox_utils

from .op_helper import (
    satisfy_sample_constraint, filter_and_process, generate_sample_bbox,
    clip_bbox, data_anchor_sampling, satisfy_sample_constraint_coverage,
    crop_image_sampling, generate_sample_bbox_square, bbox_area_sampling,
    is_poly, gaussian_radius, draw_gaussian)

from paddlex.ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

registered_ops = []


def register_op(cls):
    registered_ops.append(cls.__name__)
    if not hasattr(BaseOperator, cls.__name__):
        setattr(BaseOperator, cls.__name__, cls)
    else:
        raise KeyError("The {} class has been registered.".format(
            cls.__name__))
    return serializable(cls)


class BboxError(ValueError):
    pass


class ImageError(ValueError):
    pass


class BaseOperator(object):
    def __init__(self, name=None):
        if name is None:
            name = self.__class__.__name__
        self._id = name + '_' + str(uuid.uuid4())[-6:]

    def apply(self, sample, context=None):
        """ Process a sample.
        Args:
            sample (dict): a dict of sample, eg: {'image': xx, 'label': xxx}
            context (dict): info about this sample processing
        Returns:
            result (dict): a processed sample
        """
        return sample

    def __call__(self, sample, context=None):
        """ Process a sample.
        Args:
            sample (dict): a dict of sample, eg: {'image': xx, 'label': xxx}
            context (dict): info about this sample processing
        Returns:
            result (dict): a processed sample
        """
        if isinstance(sample, Sequence):
            for i in range(len(sample)):
                sample[i] = self.apply(sample[i], context)
        else:
            sample = self.apply(sample, context)
        return sample

    def __str__(self):
        return str(self._id)


@register_op
class Decode(BaseOperator):
    def __init__(self):
        """ Transform the image data to numpy format following the rgb format
        """
        super(Decode, self).__init__()

    def apply(self, sample, context=None):
        """ load image if 'im_file' field is not empty but 'image' is"""
        if 'image' not in sample:
            with open(sample['im_file'], 'rb') as f:
                sample['image'] = f.read()
            sample.pop('im_file')

        im = sample['image']
        data = np.frombuffer(im, dtype='uint8')
        im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
        if 'keep_ori_im' in sample and sample['keep_ori_im']:
            sample['ori_image'] = im
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        sample['image'] = im
        if 'h' not in sample:
            sample['h'] = im.shape[0]
        elif sample['h'] != im.shape[0]:
            logger.warn(
                "The actual image height: {} is not equal to the "
                "height: {} in annotation, and update sample['h'] by actual "
                "image height.".format(im.shape[0], sample['h']))
            sample['h'] = im.shape[0]
        if 'w' not in sample:
            sample['w'] = im.shape[1]
        elif sample['w'] != im.shape[1]:
            logger.warn(
                "The actual image width: {} is not equal to the "
                "width: {} in annotation, and update sample['w'] by actual "
                "image width.".format(im.shape[1], sample['w']))
            sample['w'] = im.shape[1]

        sample['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
        sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
        return sample
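
# A minimal usage sketch for Decode (illustrative, not part of the pipeline;
# the file path is hypothetical):
#
#     decode = Decode()
#     sample = decode({'im_file': 'demo.jpg'})
#     # sample['image'] is now an RGB ndarray of shape (h, w, 3), and
#     # sample['h'], sample['w'], sample['im_shape'] and
#     # sample['scale_factor'] are populated for downstream operators.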


@register_op
class Permute(BaseOperator):
    def __init__(self):
        """
        Change the channel to be (C, H, W)
        """
        super(Permute, self).__init__()

    def apply(self, sample, context=None):
        im = sample['image']
        im = im.transpose((2, 0, 1))
        sample['image'] = im
        return sample


@register_op
class Lighting(BaseOperator):
    """
    Lighting the image by eigenvalues and eigenvectors
    Args:
        eigval (list): eigenvalues
        eigvec (list): eigenvectors
        alphastd (float): random weight of lighting, 0.1 by default
    """

    def __init__(self, eigval, eigvec, alphastd=0.1):
        super(Lighting, self).__init__()
        self.alphastd = alphastd
        self.eigval = np.array(eigval).astype('float32')
        self.eigvec = np.array(eigvec).astype('float32')

    def apply(self, sample, context=None):
        alpha = np.random.normal(scale=self.alphastd, size=(3, ))
        sample['image'] += np.dot(self.eigvec, self.eigval * alpha)
        return sample


@register_op
class RandomErasingImage(BaseOperator):
    def __init__(self, prob=0.5, lower=0.02, higher=0.4, aspect_ratio=0.3):
        """
        Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896
        Args:
            prob (float): probability to carry out random erasing
            lower (float): lower limit of the erasing area ratio
            higher (float): upper limit of the erasing area ratio
            aspect_ratio (float): aspect ratio of the erasing region
        """
        super(RandomErasingImage, self).__init__()
        self.prob = prob
        self.lower = lower
        self.higher = higher
        self.aspect_ratio = aspect_ratio

    def apply(self, sample, context=None):
        gt_bbox = sample['gt_bbox']
        im = sample['image']
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image is not a numpy array.".format(self))
        if len(im.shape) != 3:
            raise ImageError("{}: image is not 3-dimensional.".format(self))

        for idx in range(gt_bbox.shape[0]):
            if self.prob <= np.random.rand():
                continue

            x1, y1, x2, y2 = gt_bbox[idx, :]
            w_bbox = x2 - x1
            h_bbox = y2 - y1
            area = w_bbox * h_bbox
            target_area = random.uniform(self.lower, self.higher) * area
            aspect_ratio = random.uniform(self.aspect_ratio,
                                          1 / self.aspect_ratio)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < w_bbox and h < h_bbox:
                off_y1 = random.randint(0, int(h_bbox - h))
                off_x1 = random.randint(0, int(w_bbox - w))
                im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1):int(
                    x1 + off_x1 + w), :] = 0
        sample['image'] = im
        return sample


@register_op
class NormalizeImage(BaseOperator):
    def __init__(self,
                 mean=[0.485, 0.456, 0.406],
                 std=[1, 1, 1],
                 is_scale=True):
        """
        Args:
            mean (list): the pixel mean
            std (list): the pixel standard deviation
        """
        super(NormalizeImage, self).__init__()
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        if not (isinstance(self.mean, list) and isinstance(self.std, list) and
                isinstance(self.is_scale, bool)):
            raise TypeError("{}: input type is invalid.".format(self))
        from functools import reduce
        if reduce(lambda x, y: x * y, self.std) == 0:
            raise ValueError('{}: std is invalid!'.format(self))

    def apply(self, sample, context=None):
        """Normalize the image.
        Operators:
            1. (optional) Scale the pixel values to [0, 1]
            2. Subtract the mean from each pixel and divide by std
        """
        im = sample['image']
        im = im.astype(np.float32, copy=False)
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
        std = np.array(self.std)[np.newaxis, np.newaxis, :]
        if self.is_scale:
            im = im / 255.0
        im -= mean
        im /= std
        sample['image'] = im
        return sample
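
# Worked example for NormalizeImage (values illustrative): with
# is_scale=True, mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225],
# a channel-0 pixel of 127 becomes (127 / 255.0 - 0.485) / 0.229 ≈ 0.057:
#
#     norm = NormalizeImage(mean=[0.485, 0.456, 0.406],
#                           std=[0.229, 0.224, 0.225])
#     out = norm({'image': np.full((2, 2, 3), 127, dtype=np.uint8)})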


@register_op
class GridMask(BaseOperator):
    def __init__(self,
                 use_h=True,
                 use_w=True,
                 rotate=1,
                 offset=False,
                 ratio=0.5,
                 mode=1,
                 prob=0.7,
                 upper_iter=360000):
        """
        GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086
        Args:
            use_h (bool): whether to mask vertically
            use_w (bool): whether to mask horizontally
            rotate (float): angle for the mask to rotate
            offset (float): mask offset
            ratio (float): mask ratio
            mode (int): gridmask mode
            prob (float): max probability to carry out gridmask
            upper_iter (int): suggested to be equal to global max_iter
        """
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.prob = prob
        self.upper_iter = upper_iter
        from .gridmask_utils import Gridmask
        self.gridmask_op = Gridmask(
            use_h,
            use_w,
            rotate=rotate,
            offset=offset,
            ratio=ratio,
            mode=mode,
            prob=prob,
            upper_iter=upper_iter)

    def apply(self, sample, context=None):
        sample['image'] = self.gridmask_op(sample['image'],
                                           sample['curr_iter'])
        return sample


@register_op
class RandomDistort(BaseOperator):
    """Random color distortion.
    Args:
        hue (list): hue settings. in [lower, upper, probability] format.
        saturation (list): saturation settings. in [lower, upper, probability] format.
        contrast (list): contrast settings. in [lower, upper, probability] format.
        brightness (list): brightness settings. in [lower, upper, probability] format.
        random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
            order.
        count (int): the number of distortion ops to apply
        random_channel (bool): whether to swap channels randomly
    """

    def __init__(self,
                 hue=[-18, 18, 0.5],
                 saturation=[0.5, 1.5, 0.5],
                 contrast=[0.5, 1.5, 0.5],
                 brightness=[0.5, 1.5, 0.5],
                 random_apply=True,
                 count=4,
                 random_channel=False):
        super(RandomDistort, self).__init__()
        self.hue = hue
        self.saturation = saturation
        self.contrast = contrast
        self.brightness = brightness
        self.random_apply = random_apply
        self.count = count
        self.random_channel = random_channel

    def apply_hue(self, img):
        low, high, prob = self.hue
        if np.random.uniform(0., 1.) < prob:
            return img

        img = img.astype(np.float32)
        # it works, but result differ from HSV version
        delta = np.random.uniform(low, high)
        u = np.cos(delta * np.pi)
        w = np.sin(delta * np.pi)
        bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
        tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
                         [0.211, -0.523, 0.311]])
        ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
                          [1.0, -1.107, 1.705]])
        t = np.dot(np.dot(ityiq, bt), tyiq).T
        img = np.dot(img, t)
        return img

    def apply_saturation(self, img):
        low, high, prob = self.saturation
        if np.random.uniform(0., 1.) < prob:
            return img
        delta = np.random.uniform(low, high)
        img = img.astype(np.float32)
        # it works, but result differ from HSV version
        gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
        gray = gray.sum(axis=2, keepdims=True)
        gray *= (1.0 - delta)
        img *= delta
        img += gray
        return img

    def apply_contrast(self, img):
        low, high, prob = self.contrast
        if np.random.uniform(0., 1.) < prob:
            return img
        delta = np.random.uniform(low, high)
        img = img.astype(np.float32)
        img *= delta
        return img

    def apply_brightness(self, img):
        low, high, prob = self.brightness
        if np.random.uniform(0., 1.) < prob:
            return img
        delta = np.random.uniform(low, high)
        img = img.astype(np.float32)
        img += delta
        return img

    def apply(self, sample, context=None):
        img = sample['image']
        if self.random_apply:
            functions = [
                self.apply_brightness, self.apply_contrast,
                self.apply_saturation, self.apply_hue
            ]
            distortions = np.random.permutation(functions)[:self.count]
            for func in distortions:
                img = func(img)
            sample['image'] = img
            return sample

        img = self.apply_brightness(img)
        mode = np.random.randint(0, 2)
        if mode:
            img = self.apply_contrast(img)
        img = self.apply_saturation(img)
        img = self.apply_hue(img)
        if not mode:
            img = self.apply_contrast(img)
        if self.random_channel:
            if np.random.randint(0, 2):
                img = img[..., np.random.permutation(3)]
        sample['image'] = img
        return sample
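
# Note on RandomDistort.apply_hue above: instead of a round trip through
# HSV, the hue shift is a rotation in YIQ space. tyiq converts RGB -> YIQ,
# bt rotates the I/Q chroma plane by delta * pi, and ityiq converts back,
# so t = (ityiq @ bt @ tyiq).T applies the whole chain with a single
# np.dot(img, t) over the trailing channel axis. This is why the inline
# comment warns that results differ slightly from an HSV implementation.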


@register_op
class AutoAugment(BaseOperator):
    def __init__(self, autoaug_type="v1"):
        """
        Args:
            autoaug_type (str): autoaug type, support v0, v1, v2, v3, test
        """
        super(AutoAugment, self).__init__()
        self.autoaug_type = autoaug_type

    def apply(self, sample, context=None):
        """
        Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172
        """
        im = sample['image']
        gt_bbox = sample['gt_bbox']
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image is not a numpy array.".format(self))
        if len(im.shape) != 3:
            raise ImageError("{}: image is not 3-dimensional.".format(self))
        if len(gt_bbox) == 0:
            return sample

        height, width, _ = im.shape
        norm_gt_bbox = np.ones_like(gt_bbox, dtype=np.float32)
        norm_gt_bbox[:, 0] = gt_bbox[:, 1] / float(height)
        norm_gt_bbox[:, 1] = gt_bbox[:, 0] / float(width)
        norm_gt_bbox[:, 2] = gt_bbox[:, 3] / float(height)
        norm_gt_bbox[:, 3] = gt_bbox[:, 2] / float(width)

        from .autoaugment_utils import distort_image_with_autoaugment
        im, norm_gt_bbox = distort_image_with_autoaugment(im, norm_gt_bbox,
                                                          self.autoaug_type)

        gt_bbox[:, 0] = norm_gt_bbox[:, 1] * float(width)
        gt_bbox[:, 1] = norm_gt_bbox[:, 0] * float(height)
        gt_bbox[:, 2] = norm_gt_bbox[:, 3] * float(width)
        gt_bbox[:, 3] = norm_gt_bbox[:, 2] * float(height)

        sample['image'] = im
        sample['gt_bbox'] = gt_bbox
        return sample


@register_op
class RandomFlip(BaseOperator):
    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): the probability of flipping image
        """
        super(RandomFlip, self).__init__()
        self.prob = prob
        if not (isinstance(self.prob, float)):
            raise TypeError("{}: input type is invalid.".format(self))

    def apply_segm(self, segms, height, width):
        def _flip_poly(poly, width):
            flipped_poly = np.array(poly)
            flipped_poly[0::2] = width - np.array(poly[0::2])
            return flipped_poly.tolist()

        def _flip_rle(rle, height, width):
            if 'counts' in rle and type(rle['counts']) == list:
                rle = mask_util.frPyObjects(rle, height, width)
            mask = mask_util.decode(rle)
            mask = mask[:, ::-1]
            rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
            return rle

        flipped_segms = []
        for segm in segms:
            if is_poly(segm):
                # Polygon format
                flipped_segms.append(
                    [_flip_poly(poly, width) for poly in segm])
            else:
                # RLE format
                import pycocotools.mask as mask_util
                flipped_segms.append(_flip_rle(segm, height, width))
        return flipped_segms

    def apply_keypoint(self, gt_keypoint, width):
        for i in range(gt_keypoint.shape[1]):
            if i % 2 == 0:
                old_x = gt_keypoint[:, i].copy()
                gt_keypoint[:, i] = width - old_x
        return gt_keypoint

    def apply_image(self, image):
        return image[:, ::-1, :]

    def apply_bbox(self, bbox, width):
        oldx1 = bbox[:, 0].copy()
        oldx2 = bbox[:, 2].copy()
        bbox[:, 0] = width - oldx2
        bbox[:, 2] = width - oldx1
        return bbox

    def apply_rbox(self, bbox, width):
        oldx1 = bbox[:, 0].copy()
        oldx2 = bbox[:, 2].copy()
        oldx3 = bbox[:, 4].copy()
        oldx4 = bbox[:, 6].copy()
        bbox[:, 0] = width - oldx1
        bbox[:, 2] = width - oldx2
        bbox[:, 4] = width - oldx3
        bbox[:, 6] = width - oldx4
        bbox = [bbox_utils.get_best_begin_point_single(e) for e in bbox]
        return bbox

    def apply(self, sample, context=None):
        """Flip the image and bounding box.
        Operators:
            1. Flip the image numpy.
            2. Transform the bboxes' x coordinates.
               (Must judge whether the coordinates are normalized!)
            3. Transform the segmentations' x coordinates.
               (Must judge whether the coordinates are normalized!)
        Output:
            sample: the image, bounding box and segmentation part
                in sample are flipped.
        """
        if np.random.uniform(0, 1) < self.prob:
            im = sample['image']
            height, width = im.shape[:2]
            im = self.apply_image(im)
            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
                sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], width)
            if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                sample['gt_poly'] = self.apply_segm(sample['gt_poly'], height,
                                                    width)
            if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
                sample['gt_keypoint'] = self.apply_keypoint(
                    sample['gt_keypoint'], width)
            if 'semantic' in sample and sample['semantic']:
                sample['semantic'] = sample['semantic'][:, ::-1]
            if 'gt_segm' in sample and sample['gt_segm'].any():
                sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
            if 'gt_rbox2poly' in sample and sample['gt_rbox2poly'].any():
                sample['gt_rbox2poly'] = self.apply_rbox(
                    sample['gt_rbox2poly'], width)
            sample['flipped'] = True
            sample['image'] = im
        return sample
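
# Example of RandomFlip.apply_bbox (coordinates in pixels): for width=100,
# a box [x1=10, y1=5, x2=30, y2=25] maps to [70, 5, 90, 25], since
# x1' = width - x2 and x2' = width - x1; y coordinates are untouched:
#
#     flip = RandomFlip(prob=1.0)  # uniform(0, 1) < 1.0 always, so it flips
#     out = flip({'image': np.zeros((50, 100, 3), dtype=np.uint8),
#                 'gt_bbox': np.array([[10., 5., 30., 25.]])})
#     # out['gt_bbox'] -> [[70., 5., 90., 25.]]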


@register_op
class Resize(BaseOperator):
    def __init__(self, target_size, keep_ratio, interp=cv2.INTER_LINEAR):
        """
        Resize image to target size. If keep_ratio is True,
        the image is scaled so that its long side does not exceed
        max(target_size) and its short side does not exceed min(target_size);
        if keep_ratio is False, the image is resized to exactly
        target_size (h, w).
        Args:
            target_size (int|list): image target size
            keep_ratio (bool): whether keep_ratio or not, default true
            interp (int): the interpolation method
        """
        super(Resize, self).__init__()
        self.keep_ratio = keep_ratio
        self.interp = interp
        if not isinstance(target_size, (Integral, Sequence)):
            raise TypeError(
                "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
                format(type(target_size)))
        if isinstance(target_size, Integral):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def apply_image(self, image, scale):
        im_scale_x, im_scale_y = scale
        return cv2.resize(
            image,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)

    def apply_bbox(self, bbox, scale, size):
        im_scale_x, im_scale_y = scale
        resize_w, resize_h = size
        bbox[:, 0::2] *= im_scale_x
        bbox[:, 1::2] *= im_scale_y
        bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, resize_w)
        bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h)
        return bbox

    def apply_segm(self, segms, im_size, scale):
        def _resize_poly(poly, im_scale_x, im_scale_y):
            resized_poly = np.array(poly).astype('float32')
            resized_poly[0::2] *= im_scale_x
            resized_poly[1::2] *= im_scale_y
            return resized_poly.tolist()

        def _resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y):
            if 'counts' in rle and type(rle['counts']) == list:
                rle = mask_util.frPyObjects(rle, im_h, im_w)

            mask = mask_util.decode(rle)
            mask = cv2.resize(
                mask,
                None,
                None,
                fx=im_scale_x,
                fy=im_scale_y,
                interpolation=self.interp)
            rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
            return rle

        im_h, im_w = im_size
        im_scale_x, im_scale_y = scale
        resized_segms = []
        for segm in segms:
            if is_poly(segm):
                # Polygon format
                resized_segms.append([
                    _resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
                ])
            else:
                # RLE format
                import pycocotools.mask as mask_util
                resized_segms.append(
                    _resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
        return resized_segms

    def apply(self, sample, context=None):
        """ Resize the image numpy.
        """
        im = sample['image']
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image type is not numpy.".format(self))
        if len(im.shape) != 3:
            raise ImageError('{}: image is not 3-dimensional.'.format(self))

        # apply image
        im_shape = im.shape
        if self.keep_ratio:
            im_size_min = np.min(im_shape[0:2])
            im_size_max = np.max(im_shape[0:2])

            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)

            im_scale = min(target_size_min / im_size_min,
                           target_size_max / im_size_max)

            resize_h = im_scale * float(im_shape[0])
            resize_w = im_scale * float(im_shape[1])

            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / im_shape[0]
            im_scale_x = resize_w / im_shape[1]

        im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
        sample['image'] = im
        sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32)
        if 'scale_factor' in sample:
            scale_factor = sample['scale_factor']
            sample['scale_factor'] = np.asarray(
                [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
                dtype=np.float32)
        else:
            sample['scale_factor'] = np.asarray(
                [im_scale_y, im_scale_x], dtype=np.float32)

        # apply bbox
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'],
                                                [im_scale_x, im_scale_y],
                                                [resize_w, resize_h])

        # apply rbox
        if 'gt_rbox2poly' in sample:
            if np.array(sample['gt_rbox2poly']).shape[1] != 8:
                logger.warn(
                    "gt_rbox2poly's length should be 8, but actually is {}".
                    format(len(sample['gt_rbox2poly'])))
            sample['gt_rbox2poly'] = self.apply_bbox(sample['gt_rbox2poly'],
                                                     [im_scale_x, im_scale_y],
                                                     [resize_w, resize_h])

        # apply polygon
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(
                sample['gt_poly'], im_shape[:2], [im_scale_x, im_scale_y])

        # apply semantic
        if 'semantic' in sample and sample['semantic']:
            semantic = sample['semantic']
            semantic = cv2.resize(
                semantic.astype('float32'),
                None,
                None,
                fx=im_scale_x,
                fy=im_scale_y,
                interpolation=self.interp)
            semantic = np.asarray(semantic).astype('int32')
            semantic = np.expand_dims(semantic, 0)
            sample['semantic'] = semantic

        # apply gt_segm
        if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
            masks = [
                cv2.resize(
                    gt_segm,
                    None,
                    None,
                    fx=im_scale_x,
                    fy=im_scale_y,
                    interpolation=cv2.INTER_NEAREST)
                for gt_segm in sample['gt_segm']
            ]
            sample['gt_segm'] = np.asarray(masks).astype(np.uint8)

        return sample
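
# Sizing math in Resize.apply with keep_ratio=True (numbers illustrative):
# for a 480x640 (h x w) image and target_size=[800, 1333],
# im_scale = min(800 / 480, 1333 / 640) ≈ 1.667, so the output is roughly
# 800x1067 and scale_factor becomes [1.667, 1.667]:
#
#     resizer = Resize(target_size=[800, 1333], keep_ratio=True)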


@register_op
class MultiscaleTestResize(BaseOperator):
    def __init__(self,
                 origin_target_size=[800, 1333],
                 target_size=[],
                 interp=cv2.INTER_LINEAR,
                 use_flip=True):
        """
        Rescale image to each size in target_size, capped at max_size.
        Args:
            origin_target_size (list): origin target size of image
            target_size (list): A list of target sizes of image.
            interp (int): the interpolation method.
            use_flip (bool): whether use flip augmentation.
        """
        super(MultiscaleTestResize, self).__init__()
        self.interp = interp
        self.use_flip = use_flip

        if not isinstance(target_size, Sequence):
            raise TypeError(
                "Type of target_size is invalid. Must be List or Tuple, now is {}".
                format(type(target_size)))
        self.target_size = target_size

        if not isinstance(origin_target_size, Sequence):
            raise TypeError(
                "Type of origin_target_size is invalid. Must be List or Tuple, now is {}".
                format(type(origin_target_size)))
        self.origin_target_size = origin_target_size

    def apply(self, sample, context=None):
        """ Resize the image numpy for multi-scale test.
        """
        samples = []
        resizer = Resize(
            self.origin_target_size, keep_ratio=True, interp=self.interp)
        samples.append(resizer(sample.copy(), context))
        if self.use_flip:
            flipper = RandomFlip(1.1)
            samples.append(flipper(sample.copy(), context=context))

        for size in self.target_size:
            resizer = Resize(size, keep_ratio=True, interp=self.interp)
            samples.append(resizer(sample.copy(), context))

        return samples


@register_op
class RandomResize(BaseOperator):
    def __init__(self,
                 target_size,
                 keep_ratio=True,
                 interp=cv2.INTER_LINEAR,
                 random_size=True,
                 random_interp=False):
        """
        Resize image to target size randomly. random target_size and interpolation method
        Args:
            target_size (int, list, tuple): image target size, if random size is True, must be list or tuple
            keep_ratio (bool): whether keep_ratio or not, default true
            interp (int): the interpolation method
            random_size (bool): whether random select target size of image
            random_interp (bool): whether random select interpolation method
        """
        super(RandomResize, self).__init__()
        self.keep_ratio = keep_ratio
        self.interp = interp
        self.interps = [
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_AREA,
            cv2.INTER_CUBIC,
            cv2.INTER_LANCZOS4,
        ]
        assert isinstance(target_size, (
            Integral, Sequence)), "target_size must be Integer, List or Tuple"
        if random_size and not isinstance(target_size, Sequence):
            raise TypeError(
                "Type of target_size is invalid when random_size is True. Must be List or Tuple, now is {}".
                format(type(target_size)))
        self.target_size = target_size
        self.random_size = random_size
        self.random_interp = random_interp

    def apply(self, sample, context=None):
        """ Resize the image numpy.
        """
        if self.random_size:
            target_size = random.choice(self.target_size)
        else:
            target_size = self.target_size
        if self.random_interp:
            interp = random.choice(self.interps)
        else:
            interp = self.interp
        resizer = Resize(target_size, self.keep_ratio, interp)
        return resizer(sample, context=context)
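
# Usage sketch for RandomResize (candidate sizes illustrative): with
# random_size=True one of the listed sizes is drawn per sample, and the
# actual work is delegated to the Resize operator above:
#
#     rr = RandomResize(target_size=[[640, 640], [704, 704], [768, 768]],
#                       keep_ratio=False, random_interp=True)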


@register_op
class RandomExpand(BaseOperator):
    """Randomly expand the canvas.
    Args:
        ratio (float): maximum expansion ratio.
        prob (float): probability to expand.
        fill_value (list): color value used to fill the canvas. in RGB order.
    """

    def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)):
        super(RandomExpand, self).__init__()
        assert ratio > 1.01, "expand ratio must be larger than 1.01"
        self.ratio = ratio
        self.prob = prob
        assert isinstance(fill_value, (Number, Sequence)), \
            "fill value must be either float or sequence"
        if isinstance(fill_value, Number):
            fill_value = (fill_value, ) * 3
        if not isinstance(fill_value, tuple):
            fill_value = tuple(fill_value)
        self.fill_value = fill_value

    def apply(self, sample, context=None):
        if np.random.uniform(0., 1.) < self.prob:
            return sample

        im = sample['image']
        height, width = im.shape[:2]
        ratio = np.random.uniform(1., self.ratio)
        h = int(height * ratio)
        w = int(width * ratio)
        if not h > height or not w > width:
            return sample
        y = np.random.randint(0, h - height)
        x = np.random.randint(0, w - width)
        offsets, size = [x, y], [h, w]

        pad = Pad(size,
                  pad_mode=-1,
                  offsets=offsets,
                  fill_value=self.fill_value)

        return pad(sample, context=context)
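
# Note on RandomExpand.apply above: the expansion itself is delegated to the
# Pad operator (defined later in this file) with a random offset, so the
# original image lands at a random position inside the enlarged fill_value
# canvas. As written, apply returns early when the drawn uniform is below
# prob, i.e. the expansion happens with probability 1 - prob.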


@register_op
class CropWithSampling(BaseOperator):
    def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True):
        """
        Args:
            batch_sampler (list): Multiple sets of different
                                  parameters for cropping.
            satisfy_all (bool): whether all boxes must satisfy.
            e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
            [max sample, max trial, min scale, max scale,
             min aspect ratio, max aspect ratio,
             min overlap, max overlap]
            avoid_no_bbox (bool): whether to avoid the
                                  situation where the box does not appear.
        """
        super(CropWithSampling, self).__init__()
        self.batch_sampler = batch_sampler
        self.satisfy_all = satisfy_all
        self.avoid_no_bbox = avoid_no_bbox

    def apply(self, sample, context):
        """
        Crop the image and modify bounding box.
        Operators:
            1. Scale the image width and height.
            2. Crop the image according to a random sample.
            3. Rescale the bounding box.
            4. Determine if the new bbox is satisfied in the new image.
        Returns:
            sample: the image, bounding box are replaced.
        """
        assert 'image' in sample, "image data not found"
        im = sample['image']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
        im_height, im_width = im.shape[:2]
        gt_score = None
        if 'gt_score' in sample:
            gt_score = sample['gt_score']
        sampled_bbox = []
        gt_bbox = gt_bbox.tolist()
        for sampler in self.batch_sampler:
            found = 0
            for i in range(sampler[1]):
                if found >= sampler[0]:
                    break
                sample_bbox = generate_sample_bbox(sampler)
                if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox,
                                             self.satisfy_all):
                    sampled_bbox.append(sample_bbox)
                    found = found + 1
        im = np.array(im)
        while sampled_bbox:
            idx = int(np.random.uniform(0, len(sampled_bbox)))
            sample_bbox = sampled_bbox.pop(idx)
            sample_bbox = clip_bbox(sample_bbox)
            crop_bbox, crop_class, crop_score = \
                filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score)
            if self.avoid_no_bbox:
                if len(crop_bbox) < 1:
                    continue
            xmin = int(sample_bbox[0] * im_width)
            xmax = int(sample_bbox[2] * im_width)
            ymin = int(sample_bbox[1] * im_height)
            ymax = int(sample_bbox[3] * im_height)
            im = im[ymin:ymax, xmin:xmax]
            sample['image'] = im
            sample['gt_bbox'] = crop_bbox
            sample['gt_class'] = crop_class
            sample['gt_score'] = crop_score
            return sample
        return sample
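
# Reading one batch_sampler row from the docstring above, e.g.
# [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0]: draw at most 1 crop out of up to
# 50 trials, with scale sampled in [0.3, 1.0], aspect ratio in [0.5, 2.0],
# and required overlap with a gt box between 0.5 and 1.0.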


@register_op
class CropWithDataAchorSampling(BaseOperator):
    def __init__(self,
                 batch_sampler,
                 anchor_sampler=None,
                 target_size=None,
                 das_anchor_scales=[16, 32, 64, 128],
                 sampling_prob=0.5,
                 min_size=8.,
                 avoid_no_bbox=True):
        """
        Args:
            anchor_sampler (list): anchor_sampling sets of different
                                   parameters for cropping.
            batch_sampler (list): Multiple sets of different
                                  parameters for cropping.
              e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
                  [[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                   [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                   [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                   [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
                   [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
              [max sample, max trial, min scale, max scale,
               min aspect ratio, max aspect ratio,
               min overlap, max overlap, min coverage, max coverage]
            target_size (int): target image size.
            das_anchor_scales (list[float]): a list of anchor scales in data
                anchor sampling.
            min_size (float): minimum size of sampled bbox.
            avoid_no_bbox (bool): whether to avoid the
                                  situation where the box does not appear.
        """
        super(CropWithDataAchorSampling, self).__init__()
        self.anchor_sampler = anchor_sampler
        self.batch_sampler = batch_sampler
        self.target_size = target_size
        self.sampling_prob = sampling_prob
        self.min_size = min_size
        self.avoid_no_bbox = avoid_no_bbox
        self.das_anchor_scales = np.array(das_anchor_scales)

    def apply(self, sample, context):
        """
        Crop the image and modify bounding box.
        Operators:
            1. Scale the image width and height.
            2. Crop the image according to a random sample.
            3. Rescale the bounding box.
            4. Determine if the new bbox is satisfied in the new image.
        Returns:
            sample: the image, bounding box are replaced.
        """
        assert 'image' in sample, "image data not found"
        im = sample['image']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
        image_height, image_width = im.shape[:2]
        gt_bbox[:, 0] /= image_width
        gt_bbox[:, 1] /= image_height
        gt_bbox[:, 2] /= image_width
        gt_bbox[:, 3] /= image_height
        gt_score = None
        if 'gt_score' in sample:
            gt_score = sample['gt_score']
        sampled_bbox = []
        gt_bbox = gt_bbox.tolist()

        prob = np.random.uniform(0., 1.)
        if prob > self.sampling_prob:  # anchor sampling
            assert self.anchor_sampler
            for sampler in self.anchor_sampler:
                found = 0
                for i in range(sampler[1]):
                    if found >= sampler[0]:
                        break
                    sample_bbox = data_anchor_sampling(
                        gt_bbox, image_width, image_height,
                        self.das_anchor_scales, self.target_size)
                    if sample_bbox == 0:
                        break
                    if satisfy_sample_constraint_coverage(sampler, sample_bbox,
                                                          gt_bbox):
                        sampled_bbox.append(sample_bbox)
                        found = found + 1
            im = np.array(im)
            while sampled_bbox:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                sample_bbox = sampled_bbox.pop(idx)

                if 'gt_keypoint' in sample.keys():
                    keypoints = (sample['gt_keypoint'],
                                 sample['keypoint_ignore'])
                    crop_bbox, crop_class, crop_score, gt_keypoints = \
                        filter_and_process(sample_bbox, gt_bbox, gt_class,
                                           scores=gt_score,
                                           keypoints=keypoints)
                else:
                    crop_bbox, crop_class, crop_score = filter_and_process(
                        sample_bbox, gt_bbox, gt_class, scores=gt_score)
                crop_bbox, crop_class, crop_score = bbox_area_sampling(
                    crop_bbox, crop_class, crop_score, self.target_size,
                    self.min_size)

                if self.avoid_no_bbox:
                    if len(crop_bbox) < 1:
                        continue
                im = crop_image_sampling(im, sample_bbox, image_width,
                                         image_height, self.target_size)
                height, width = im.shape[:2]
                crop_bbox[:, 0] *= width
                crop_bbox[:, 1] *= height
                crop_bbox[:, 2] *= width
                crop_bbox[:, 3] *= height
                sample['image'] = im
                sample['gt_bbox'] = crop_bbox
                sample['gt_class'] = crop_class
                if 'gt_score' in sample:
                    sample['gt_score'] = crop_score
                if 'gt_keypoint' in sample.keys():
                    sample['gt_keypoint'] = gt_keypoints[0]
                    sample['keypoint_ignore'] = gt_keypoints[1]
                return sample
            return sample

        else:
            for sampler in self.batch_sampler:
                found = 0
                for i in range(sampler[1]):
                    if found >= sampler[0]:
                        break
                    sample_bbox = generate_sample_bbox_square(
                        sampler, image_width, image_height)
                    if satisfy_sample_constraint_coverage(sampler, sample_bbox,
                                                          gt_bbox):
                        sampled_bbox.append(sample_bbox)
                        found = found + 1
            im = np.array(im)
            while sampled_bbox:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                sample_bbox = sampled_bbox.pop(idx)
                sample_bbox = clip_bbox(sample_bbox)

                if 'gt_keypoint' in sample.keys():
                    keypoints = (sample['gt_keypoint'],
                                 sample['keypoint_ignore'])
                    crop_bbox, crop_class, crop_score, gt_keypoints = \
                        filter_and_process(sample_bbox, gt_bbox, gt_class,
                                           scores=gt_score,
                                           keypoints=keypoints)
                else:
                    crop_bbox, crop_class, crop_score = filter_and_process(
                        sample_bbox, gt_bbox, gt_class, scores=gt_score)
                # sampling bbox according the bbox area
                crop_bbox, crop_class, crop_score = bbox_area_sampling(
                    crop_bbox, crop_class, crop_score, self.target_size,
                    self.min_size)

                if self.avoid_no_bbox:
                    if len(crop_bbox) < 1:
                        continue
                xmin = int(sample_bbox[0] * image_width)
                xmax = int(sample_bbox[2] * image_width)
                ymin = int(sample_bbox[1] * image_height)
                ymax = int(sample_bbox[3] * image_height)
                im = im[ymin:ymax, xmin:xmax]
                height, width = im.shape[:2]
                crop_bbox[:, 0] *= width
                crop_bbox[:, 1] *= height
                crop_bbox[:, 2] *= width
                crop_bbox[:, 3] *= height
                sample['image'] = im
                sample['gt_bbox'] = crop_bbox
                sample['gt_class'] = crop_class
                if 'gt_score' in sample:
                    sample['gt_score'] = crop_score
                if 'gt_keypoint' in sample.keys():
                    sample['gt_keypoint'] = gt_keypoints[0]
                    sample['keypoint_ignore'] = gt_keypoints[1]
                return sample
            return sample


@register_op
class RandomCrop(BaseOperator):
    """Random crop image and bboxes.
    Args:
        aspect_ratio (list): aspect ratio of cropped region.
            in [min, max] format.
        thresholds (list): iou thresholds for deciding a valid bbox crop.
        scaling (list): ratio between a cropped region and the original image.
            in [min, max] format.
        num_attempts (int): number of tries before giving up.
        allow_no_crop (bool): allow return without actually cropping them.
        cover_all_box (bool): ensure all bboxes are covered in the final crop.
        is_mask_crop (bool): whether crop the segmentation.
    """

    def __init__(self,
                 aspect_ratio=[.5, 2.],
                 thresholds=[.0, .1, .3, .5, .7, .9],
                 scaling=[.3, 1.],
                 num_attempts=50,
                 allow_no_crop=True,
                 cover_all_box=False,
                 is_mask_crop=False):
        super(RandomCrop, self).__init__()
        self.aspect_ratio = aspect_ratio
        self.thresholds = thresholds
        self.scaling = scaling
        self.num_attempts = num_attempts
        self.allow_no_crop = allow_no_crop
        self.cover_all_box = cover_all_box
        self.is_mask_crop = is_mask_crop

    def crop_segms(self, segms, valid_ids, crop, height, width):
        def _crop_poly(segm, crop):
            xmin, ymin, xmax, ymax = crop
            crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]
            crop_p = np.array(crop_coord).reshape(4, 2)
            crop_p = Polygon(crop_p)

            crop_segm = list()
            for poly in segm:
                poly = np.array(poly).reshape(len(poly) // 2, 2)
                polygon = Polygon(poly)
                if not polygon.is_valid:
                    exterior = polygon.exterior
                    multi_lines = exterior.intersection(exterior)
                    polygons = shapely.ops.polygonize(multi_lines)
                    polygon = MultiPolygon(polygons)
                multi_polygon = list()
                if isinstance(polygon, MultiPolygon):
                    multi_polygon = copy.deepcopy(polygon)
                else:
                    multi_polygon.append(copy.deepcopy(polygon))
                for per_polygon in multi_polygon:
                    inter = per_polygon.intersection(crop_p)
                    if not inter:
                        continue
                    if isinstance(inter, (MultiPolygon, GeometryCollection)):
                        for part in inter:
                            if not isinstance(part, Polygon):
                                continue
                            part = np.squeeze(
                                np.array(part.exterior.coords[:-1]).reshape(
                                    1, -1))
                            part[0::2] -= xmin
                            part[1::2] -= ymin
                            crop_segm.append(part.tolist())
                    elif isinstance(inter, Polygon):
                        crop_poly = np.squeeze(
                            np.array(inter.exterior.coords[:-1]).reshape(1,
                                                                         -1))
                        crop_poly[0::2] -= xmin
                        crop_poly[1::2] -= ymin
                        crop_segm.append(crop_poly.tolist())
                    else:
                        continue
            return crop_segm

        def _crop_rle(rle, crop, height, width):
            if 'counts' in rle and type(rle['counts']) == list:
                rle = mask_util.frPyObjects(rle, height, width)
            mask = mask_util.decode(rle)
            mask = mask[crop[1]:crop[3], crop[0]:crop[2]]
            rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
            return rle

        crop_segms = []
        for id in valid_ids:
            segm = segms[id]
            if is_poly(segm):
                import copy
                import logging
                import shapely.ops
                from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
                logging.getLogger("shapely").setLevel(logging.WARNING)
                # Polygon format
                crop_segms.append(_crop_poly(segm, crop))
            else:
                # RLE format
                import pycocotools.mask as mask_util
                crop_segms.append(_crop_rle(segm, crop, height, width))
        return crop_segms

    def apply(self, sample, context=None):
        if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
            return sample

        h, w = sample['image'].shape[:2]
        gt_bbox = sample['gt_bbox']

        # NOTE Original method attempts to generate one candidate for each
        # threshold then randomly sample one from the resulting list.
        # Here a short circuit approach is taken, i.e., randomly choose a
        # threshold and attempt to find a valid crop, and simply return the
        # first one found.
        # The probability is not exactly the same, kinda resembling the
        # "Monty Hall" problem. Actually carrying out the attempts will affect
        # observability (just like opening doors in the "Monty Hall" game).
        thresholds = list(self.thresholds)
        if self.allow_no_crop:
            thresholds.append('no_crop')
        np.random.shuffle(thresholds)

        for thresh in thresholds:
            if thresh == 'no_crop':
                return sample

            found = False
            for i in range(self.num_attempts):
                scale = np.random.uniform(*self.scaling)
                if self.aspect_ratio is not None:
                    min_ar, max_ar = self.aspect_ratio
                    aspect_ratio = np.random.uniform(
                        max(min_ar, scale**2), min(max_ar, scale**-2))
                    h_scale = scale / np.sqrt(aspect_ratio)
                    w_scale = scale * np.sqrt(aspect_ratio)
                else:
                    h_scale = np.random.uniform(*self.scaling)
                    w_scale = np.random.uniform(*self.scaling)
                crop_h = h * h_scale
                crop_w = w * w_scale
                if self.aspect_ratio is None:
                    if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
                        continue

                crop_h = int(crop_h)
                crop_w = int(crop_w)
                crop_y = np.random.randint(0, h - crop_h)
                crop_x = np.random.randint(0, w - crop_w)
                crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
                iou = self._iou_matrix(
                    gt_bbox, np.array(
                        [crop_box], dtype=np.float32))
                if iou.max() < thresh:
                    continue

                if self.cover_all_box and iou.min() < thresh:
                    continue

                cropped_box, valid_ids = self._crop_box_with_center_constraint(
                    gt_bbox, np.array(
                        crop_box, dtype=np.float32))
                if valid_ids.size > 0:
                    found = True
                    break

            if found:
                if self.is_mask_crop and 'gt_poly' in sample and len(sample[
                        'gt_poly']) > 0:
                    crop_polys = self.crop_segms(
                        sample['gt_poly'],
                        valid_ids,
                        np.array(
                            crop_box, dtype=np.int64),
                        h,
                        w)
                    if [] in crop_polys:
                        delete_id = list()
                        valid_polys = list()
                        for id, crop_poly in enumerate(crop_polys):
                            if crop_poly == []:
                                delete_id.append(id)
                            else:
                                valid_polys.append(crop_poly)
                        valid_ids = np.delete(valid_ids, delete_id)
                        if len(valid_polys) == 0:
                            return sample
                        sample['gt_poly'] = valid_polys
                    else:
                        sample['gt_poly'] = crop_polys

                if 'gt_segm' in sample:
                    sample['gt_segm'] = self._crop_segm(sample['gt_segm'],
                                                        crop_box)
                    sample['gt_segm'] = np.take(
                        sample['gt_segm'], valid_ids, axis=0)

                sample['image'] = self._crop_image(sample['image'], crop_box)
                sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
                sample['gt_class'] = np.take(
                    sample['gt_class'], valid_ids, axis=0)
                if 'gt_score' in sample:
                    sample['gt_score'] = np.take(
                        sample['gt_score'], valid_ids, axis=0)

                if 'is_crowd' in sample:
                    sample['is_crowd'] = np.take(
                        sample['is_crowd'], valid_ids, axis=0)
                return sample

        return sample

    def _iou_matrix(self, a, b):
        tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
        br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
        area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
        area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
        area_o = (area_a[:, np.newaxis] + area_b - area_i)
        return area_i / (area_o + 1e-10)

    def _crop_box_with_center_constraint(self, box, crop):
        cropped_box = box.copy()
        cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
        cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
        cropped_box[:, :2] -= crop[:2]
        cropped_box[:, 2:] -= crop[:2]
        centers = (box[:, :2] + box[:, 2:]) / 2
        valid = np.logical_and(crop[:2] <= centers,
                               centers < crop[2:]).all(axis=1)
        valid = np.logical_and(
            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
        return cropped_box, np.where(valid)[0]

    def _crop_image(self, img, crop):
        x1, y1, x2, y2 = crop
        return img[y1:y2, x1:x2, :]

    def _crop_segm(self, segm, crop):
        x1, y1, x2, y2 = crop
        return segm[:, y1:y2, x1:x2]
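
# Worked example for RandomCrop._iou_matrix (boxes in xyxy): for
# a = [[0, 0, 10, 10]] and b = [[5, 5, 15, 15]] the intersection is
# 5 * 5 = 25 and the union 100 + 100 - 25 = 175, giving IoU ≈ 0.143:
#
#     rc = RandomCrop()
#     iou = rc._iou_matrix(np.array([[0., 0., 10., 10.]]),
#                          np.array([[5., 5., 15., 15.]]))
#     # iou -> array([[0.1429]])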
  1260. @register_op
  1261. class RandomScaledCrop(BaseOperator):
  1262. """Resize image and bbox based on long side (with optional random scaling),
  1263. then crop or pad image to target size.
  1264. Args:
  1265. target_dim (int): target size.
  1266. scale_range (list): random scale range.
  1267. interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
  1268. """
  1269. def __init__(self,
  1270. target_dim=512,
  1271. scale_range=[.1, 2.],
  1272. interp=cv2.INTER_LINEAR):
  1273. super(RandomScaledCrop, self).__init__()
  1274. self.target_dim = target_dim
  1275. self.scale_range = scale_range
  1276. self.interp = interp
    def apply(self, sample, context=None):
        img = sample['image']
        h, w = img.shape[:2]
        random_scale = np.random.uniform(*self.scale_range)
        dim = self.target_dim
        random_dim = int(dim * random_scale)
        dim_max = max(h, w)
        scale = random_dim / dim_max
        # cv2.resize and the canvas slicing below require integer sizes
        resize_w = int(round(w * scale))
        resize_h = int(round(h * scale))
        offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
        offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
        img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
        img = np.array(img)
        canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
        canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
            offset_y:offset_y + dim, offset_x:offset_x + dim, :]
        sample['image'] = canvas
        sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32)
        scale_factor = sample['scale_factor']
        sample['scale_factor'] = np.asarray(
            [scale_factor[0] * scale, scale_factor[1] * scale],
            dtype=np.float32)
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            scale_array = np.array([scale, scale] * 2, dtype=np.float32)
            shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
            boxes = sample['gt_bbox'] * scale_array - shift_array
            boxes = np.clip(boxes, 0, dim - 1)
            # filter boxes with no area
            area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
            valid = (area > 1.).nonzero()[0]
            sample['gt_bbox'] = boxes[valid]
            sample['gt_class'] = sample['gt_class'][valid]
        return sample
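
# Usage sketch (illustrative, not part of the original file): the operator
# rescales the long side to target_dim * random_scale, then crops or pads to
# a square target_dim canvas; `img`, `boxes`, and `classes` below are assumed
# pre-decoded inputs.
#
#   >>> op = RandomScaledCrop(target_dim=512, scale_range=[0.5, 1.5])
#   >>> out = op.apply({'image': img, 'gt_bbox': boxes, 'gt_class': classes,
#   ...                 'scale_factor': np.ones(2, dtype=np.float32)})
#   >>> out['image'].shape
#   (512, 512, 3)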

@register_op
class Cutmix(BaseOperator):
    def __init__(self, alpha=1.5, beta=1.5):
        """
        CutMix: Regularization Strategy to Train Strong Classifiers with
        Localizable Features, see https://arxiv.org/abs/1905.04899
        Cutmix image and gt_bbox/gt_score
        Args:
            alpha (float): alpha parameter of beta distribution
            beta (float): beta parameter of beta distribution
        """
        super(Cutmix, self).__init__()
        self.alpha = alpha
        self.beta = beta
        if self.alpha <= 0.0:
            raise ValueError("alpha should be positive in {}".format(self))
        if self.beta <= 0.0:
            raise ValueError("beta should be positive in {}".format(self))
    def apply_image(self, img1, img2, factor):
        """Cut a random box from img2 and paste it onto img1 (_rand_bbox)."""
        h = max(img1.shape[0], img2.shape[0])
        w = max(img1.shape[1], img2.shape[1])
        cut_rat = np.sqrt(1. - factor)
        # use builtin int: np.int was removed in NumPy 1.24
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)
        # uniform
        cx = np.random.randint(w)
        cy = np.random.randint(h)
        bbx1 = np.clip(cx - cut_w // 2, 0, w - 1)
        bby1 = np.clip(cy - cut_h // 2, 0, h - 1)
        bbx2 = np.clip(cx + cut_w // 2, 0, w - 1)
        bby2 = np.clip(cy + cut_h // 2, 0, h - 1)
        img_1_pad = np.zeros((h, w, img1.shape[2]), 'float32')
        img_1_pad[:img1.shape[0], :img1.shape[1], :] = \
            img1.astype('float32')
        img_2_pad = np.zeros((h, w, img2.shape[2]), 'float32')
        img_2_pad[:img2.shape[0], :img2.shape[1], :] = \
            img2.astype('float32')
        img_1_pad[bby1:bby2, bbx1:bbx2, :] = img_2_pad[bby1:bby2, bbx1:bbx2, :]
        return img_1_pad
    def __call__(self, sample, context=None):
        if not isinstance(sample, Sequence):
            return sample
        assert len(sample) == 2, 'cutmix needs two samples'
        factor = np.random.beta(self.alpha, self.beta)
        factor = max(0.0, min(1.0, factor))
        if factor >= 1.0:
            return sample[0]
        if factor <= 0.0:
            return sample[1]
        img1 = sample[0]['image']
        img2 = sample[1]['image']
        img = self.apply_image(img1, img2, factor)
        gt_bbox1 = sample[0]['gt_bbox']
        gt_bbox2 = sample[1]['gt_bbox']
        gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
        gt_class1 = sample[0]['gt_class']
        gt_class2 = sample[1]['gt_class']
        gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
        gt_score1 = np.ones_like(sample[0]['gt_class'])
        gt_score2 = np.ones_like(sample[1]['gt_class'])
        gt_score = np.concatenate(
            (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
        result = copy.deepcopy(sample[0])
        result['image'] = img
        result['gt_bbox'] = gt_bbox
        result['gt_score'] = gt_score
        result['gt_class'] = gt_class
        if 'is_crowd' in sample[0]:
            is_crowd1 = sample[0]['is_crowd']
            is_crowd2 = sample[1]['is_crowd']
            is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
            result['is_crowd'] = is_crowd
        if 'difficult' in sample[0]:
            is_difficult1 = sample[0]['difficult']
            is_difficult2 = sample[1]['difficult']
            is_difficult = np.concatenate(
                (is_difficult1, is_difficult2), axis=0)
            result['difficult'] = is_difficult
        return result
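
# Usage sketch (illustrative): Cutmix consumes a pair of decoded samples and
# returns a single sample whose gt_score weights each source by factor and
# (1 - factor) respectively.
#
#   >>> cutmix = Cutmix(alpha=1.5, beta=1.5)
#   >>> mixed = cutmix([sample_a, sample_b])   # samples assumed pre-decoded
#   >>> len(mixed['gt_bbox']) == len(sample_a['gt_bbox']) + len(sample_b['gt_bbox'])
#   True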

@register_op
class Mixup(BaseOperator):
    def __init__(self, alpha=1.5, beta=1.5):
        """ Mixup image and gt_bbox/gt_score
        Args:
            alpha (float): alpha parameter of beta distribution
            beta (float): beta parameter of beta distribution
        """
        super(Mixup, self).__init__()
        self.alpha = alpha
        self.beta = beta
        if self.alpha <= 0.0:
            raise ValueError("alpha should be positive in {}".format(self))
        if self.beta <= 0.0:
            raise ValueError("beta should be positive in {}".format(self))
    def apply_image(self, img1, img2, factor):
        h = max(img1.shape[0], img2.shape[0])
        w = max(img1.shape[1], img2.shape[1])
        img = np.zeros((h, w, img1.shape[2]), 'float32')
        img[:img1.shape[0], :img1.shape[1], :] = \
            img1.astype('float32') * factor
        img[:img2.shape[0], :img2.shape[1], :] += \
            img2.astype('float32') * (1.0 - factor)
        return img.astype('uint8')
    def __call__(self, sample, context=None):
        if not isinstance(sample, Sequence):
            return sample
        assert len(sample) == 2, 'mixup needs two samples'
        factor = np.random.beta(self.alpha, self.beta)
        factor = max(0.0, min(1.0, factor))
        if factor >= 1.0:
            return sample[0]
        if factor <= 0.0:
            return sample[1]
        im = self.apply_image(sample[0]['image'], sample[1]['image'], factor)
        result = copy.deepcopy(sample[0])
        result['image'] = im
        # apply bbox and score
        if 'gt_bbox' in sample[0]:
            gt_bbox1 = sample[0]['gt_bbox']
            gt_bbox2 = sample[1]['gt_bbox']
            gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
            result['gt_bbox'] = gt_bbox
        if 'gt_class' in sample[0]:
            gt_class1 = sample[0]['gt_class']
            gt_class2 = sample[1]['gt_class']
            gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
            result['gt_class'] = gt_class

            gt_score1 = np.ones_like(sample[0]['gt_class'])
            gt_score2 = np.ones_like(sample[1]['gt_class'])
            gt_score = np.concatenate(
                (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
            result['gt_score'] = gt_score
        if 'is_crowd' in sample[0]:
            is_crowd1 = sample[0]['is_crowd']
            is_crowd2 = sample[1]['is_crowd']
            is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
            result['is_crowd'] = is_crowd
        if 'difficult' in sample[0]:
            is_difficult1 = sample[0]['difficult']
            is_difficult2 = sample[1]['difficult']
            is_difficult = np.concatenate(
                (is_difficult1, is_difficult2), axis=0)
            result['difficult'] = is_difficult
        if 'gt_ide' in sample[0]:
            gt_ide1 = sample[0]['gt_ide']
            gt_ide2 = sample[1]['gt_ide']
            gt_ide = np.concatenate((gt_ide1, gt_ide2), axis=0)
            result['gt_ide'] = gt_ide
        return result
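
# Usage sketch (illustrative): unlike Cutmix's box-paste, Mixup blends whole
# images pixel-wise as img1 * factor + img2 * (1 - factor) before
# concatenating the annotations of both sources.
#
#   >>> mixup = Mixup(alpha=1.5, beta=1.5)
#   >>> mixed = mixup([sample_a, sample_b])    # samples assumed pre-decoded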

@register_op
class NormalizeBox(BaseOperator):
    """Transform the bounding box's coordinates to [0, 1]."""

    def __init__(self):
        super(NormalizeBox, self).__init__()

    def apply(self, sample, context):
        im = sample['image']
        gt_bbox = sample['gt_bbox']
        height, width, _ = im.shape
        for i in range(gt_bbox.shape[0]):
            gt_bbox[i][0] = gt_bbox[i][0] / width
            gt_bbox[i][1] = gt_bbox[i][1] / height
            gt_bbox[i][2] = gt_bbox[i][2] / width
            gt_bbox[i][3] = gt_bbox[i][3] / height
        sample['gt_bbox'] = gt_bbox
        if 'gt_keypoint' in sample.keys():
            gt_keypoint = sample['gt_keypoint']
            for i in range(gt_keypoint.shape[1]):
                if i % 2:
                    gt_keypoint[:, i] = gt_keypoint[:, i] / height
                else:
                    gt_keypoint[:, i] = gt_keypoint[:, i] / width
            sample['gt_keypoint'] = gt_keypoint
        return sample
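
# Worked example (illustrative): on a 100 x 200 (h x w) image, the pixel box
# [20, 10, 60, 50] becomes [0.1, 0.1, 0.3, 0.5] -- x coordinates are divided
# by the width, y coordinates by the height.
#
#   >>> s = {'image': np.zeros((100, 200, 3)),
#   ...      'gt_bbox': np.array([[20., 10., 60., 50.]])}
#   >>> NormalizeBox().apply(s, None)['gt_bbox']
#   array([[0.1, 0.1, 0.3, 0.5]])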

@register_op
class BboxXYXY2XYWH(BaseOperator):
    """
    Convert bbox from XYXY format to XYWH format,
    i.e. [x0, y0, x1, y1] -> [center_x, center_y, width, height].
    """

    def __init__(self):
        super(BboxXYXY2XYWH, self).__init__()

    def apply(self, sample, context=None):
        assert 'gt_bbox' in sample
        bbox = sample['gt_bbox']
        bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
        bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
        sample['gt_bbox'] = bbox
        return sample
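
# Worked example (illustrative): the conversion is in place and order matters:
# widths/heights are written first, then centers are derived from them.
# [0, 0, 4, 2] -> w, h = (4, 2), center = (2, 1), giving [2, 1, 4, 2].
#
#   >>> s = {'gt_bbox': np.array([[0., 0., 4., 2.]])}
#   >>> BboxXYXY2XYWH().apply(s)['gt_bbox']
#   array([[2., 1., 4., 2.]])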

@register_op
class PadBox(BaseOperator):
    def __init__(self, num_max_boxes=50):
        """
        Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
        Args:
            num_max_boxes (int): the max number of bboxes
        """
        self.num_max_boxes = num_max_boxes
        super(PadBox, self).__init__()

    def apply(self, sample, context=None):
        assert 'gt_bbox' in sample
        bbox = sample['gt_bbox']
        gt_num = min(self.num_max_boxes, len(bbox))
        num_max = self.num_max_boxes
        # fields = context['fields'] if context else []
        pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
        if gt_num > 0:
            pad_bbox[:gt_num, :] = bbox[:gt_num, :]
        sample['gt_bbox'] = pad_bbox
        if 'gt_class' in sample:
            pad_class = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
            sample['gt_class'] = pad_class
        if 'gt_score' in sample:
            pad_score = np.zeros((num_max, ), dtype=np.float32)
            if gt_num > 0:
                pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
            sample['gt_score'] = pad_score
        # in training, an op such as ExpandImage expands gt_bbox and gt_class
        # but not difficult, so guard on gt_num (the common valid length)
        if 'difficult' in sample:
            pad_diff = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
            sample['difficult'] = pad_diff
        if 'is_crowd' in sample:
            pad_crowd = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
            sample['is_crowd'] = pad_crowd
        if 'gt_ide' in sample:
            pad_ide = np.zeros((num_max, ), dtype=np.int32)
            if gt_num > 0:
                pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
            sample['gt_ide'] = pad_ide
        return sample
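
# Usage sketch (illustrative): with num_max_boxes=3 and two ground-truth
# boxes, PadBox appends one all-zero row and flattens per-box fields such as
# gt_class from shape (N, 1) to a fixed (3,) vector.
#
#   >>> s = {'gt_bbox': np.ones((2, 4), dtype=np.float32),
#   ...      'gt_class': np.array([[1], [2]])}
#   >>> out = PadBox(num_max_boxes=3).apply(s)
#   >>> out['gt_bbox'].shape, out['gt_class']
#   ((3, 4), array([1, 2, 0], dtype=int32))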

@register_op
class DebugVisibleImage(BaseOperator):
    """
    In debug mode, visualize images according to `gt_bbox`.
    (Currently only supported when the image has not been cropped or flipped.)
    """

    def __init__(self, output_dir='output/debug', is_normalized=False):
        super(DebugVisibleImage, self).__init__()
        self.is_normalized = is_normalized
        self.output_dir = output_dir
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if not isinstance(self.is_normalized, bool):
            raise TypeError("{}: input type is invalid.".format(self))

    def apply(self, sample, context=None):
        image = Image.open(sample['im_file']).convert('RGB')
        out_file_name = sample['im_file'].split('/')[-1]
        width = sample['w']
        height = sample['h']
        gt_bbox = sample['gt_bbox']
        gt_class = sample['gt_class']
        draw = ImageDraw.Draw(image)
        for i in range(gt_bbox.shape[0]):
            if self.is_normalized:
                gt_bbox[i][0] = gt_bbox[i][0] * width
                gt_bbox[i][1] = gt_bbox[i][1] * height
                gt_bbox[i][2] = gt_bbox[i][2] * width
                gt_bbox[i][3] = gt_bbox[i][3] * height
            xmin, ymin, xmax, ymax = gt_bbox[i]
            draw.line(
                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                 (xmin, ymin)],
                width=2,
                fill='green')
            # draw label
            text = str(gt_class[i][0])
            tw, th = draw.textsize(text)
            draw.rectangle(
                [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green')
            draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
        if 'gt_keypoint' in sample.keys():
            gt_keypoint = sample['gt_keypoint']
            if self.is_normalized:
                for i in range(gt_keypoint.shape[1]):
                    if i % 2:
                        gt_keypoint[:, i] = gt_keypoint[:, i] * height
                    else:
                        gt_keypoint[:, i] = gt_keypoint[:, i] * width
            for i in range(gt_keypoint.shape[0]):
                keypoint = gt_keypoint[i]
                for j in range(int(keypoint.shape[0] / 2)):
                    x1 = round(keypoint[2 * j]).astype(np.int32)
                    y1 = round(keypoint[2 * j + 1]).astype(np.int32)
                    draw.ellipse(
                        (x1, y1, x1 + 5, y1 + 5),
                        fill='green',
                        outline='green')
        save_path = os.path.join(self.output_dir, out_file_name)
        image.save(save_path, quality=95)
        return sample

@register_op
class Pad(BaseOperator):
    def __init__(self,
                 size=None,
                 size_divisor=32,
                 pad_mode=0,
                 offsets=None,
                 fill_value=(127.5, 127.5, 127.5)):
        """
        Pad image to a specified size or multiple of size_divisor.
        Args:
            size (int, Sequence): image target size; if None, pad to a
                multiple of size_divisor, default None
            size_divisor (int): size divisor, default 32
            pad_mode (int): pad mode, currently only supports four modes
                [-1, 0, 1, 2]. if -1, use specified offsets; if 0, only pad
                to right and bottom; if 1, pad according to center; if 2,
                only pad left and top
            offsets (list): [offset_x, offset_y], specified offsets while
                padding, only supported when pad_mode is -1
            fill_value (tuple): rgb value of the padded area,
                default (127.5, 127.5, 127.5)
        """
        super(Pad, self).__init__()
        if size is not None and not isinstance(size, (int, Sequence)):
            raise TypeError(
                "Type of size is invalid. Must be Integer or Sequence, "
                "now is {}".format(type(size)))
        if isinstance(size, int):
            size = [size, size]
        assert pad_mode in [
            -1, 0, 1, 2
        ], 'currently only supports four modes [-1, 0, 1, 2]'
        if pad_mode == -1:
            assert offsets, 'if pad_mode is -1, offsets should not be None'
        self.size = size
        self.size_divisor = size_divisor
        self.pad_mode = pad_mode
        self.fill_value = fill_value
        self.offsets = offsets
    def apply_segm(self, segms, offsets, im_size, size):
        def _expand_poly(poly, x, y):
            expanded_poly = np.array(poly)
            expanded_poly[0::2] += x
            expanded_poly[1::2] += y
            return expanded_poly.tolist()

        def _expand_rle(rle, x, y, height, width, h, w):
            if 'counts' in rle and type(rle['counts']) == list:
                rle = mask_util.frPyObjects(rle, height, width)
            mask = mask_util.decode(rle)
            expanded_mask = np.full((h, w), 0).astype(mask.dtype)
            expanded_mask[y:y + height, x:x + width] = mask
            rle = mask_util.encode(
                np.array(
                    expanded_mask, order='F', dtype=np.uint8))
            return rle

        x, y = offsets
        height, width = im_size
        h, w = size
        expanded_segms = []
        for segm in segms:
            if is_poly(segm):
                # Polygon format
                expanded_segms.append(
                    [_expand_poly(poly, x, y) for poly in segm])
            else:
                # RLE format: lazy import so pycocotools is only required
                # when RLE annotations are present; _expand_rle resolves
                # the mask_util name as a closure at call time
                import pycocotools.mask as mask_util
                expanded_segms.append(
                    _expand_rle(segm, x, y, height, width, h, w))
        return expanded_segms
    def apply_bbox(self, bbox, offsets):
        return bbox + np.array(offsets * 2, dtype=np.float32)

    def apply_keypoint(self, keypoints, offsets):
        n = len(keypoints[0]) // 2
        return keypoints + np.array(offsets * n, dtype=np.float32)

    def apply_image(self, image, offsets, im_size, size):
        x, y = offsets
        im_h, im_w = im_size
        h, w = size
        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
        return canvas
    def apply(self, sample, context=None):
        im = sample['image']
        im_h, im_w = im.shape[:2]
        if self.size:
            h, w = self.size
            assert (
                im_h <= h and im_w <= w
            ), '(h, w) of target size should be no less than (im_h, im_w)'
        else:
            h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
            w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
        if h == im_h and w == im_w:
            return sample

        if self.pad_mode == -1:
            offset_x, offset_y = self.offsets
        elif self.pad_mode == 0:
            offset_y, offset_x = 0, 0
        elif self.pad_mode == 1:
            offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
        else:
            offset_y, offset_x = h - im_h, w - im_w

        offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
        sample['image'] = self.apply_image(im, offsets, im_size, size)

        if self.pad_mode == 0:
            return sample
        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
            sample['gt_poly'] = self.apply_segm(sample['gt_poly'], offsets,
                                                im_size, size)
        if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
            sample['gt_keypoint'] = self.apply_keypoint(sample['gt_keypoint'],
                                                        offsets)
        return sample
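
# Worked example (illustrative): with size=None and size_divisor=32, a
# 200 x 300 image is padded to ceil(200/32)*32 x ceil(300/32)*32 = 224 x 320.
# With the default pad_mode=0 the image sits at the top-left, so boxes,
# polygons, and keypoints need no offset and the op returns early.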

@register_op
class Poly2Mask(BaseOperator):
    """
    Convert gt poly to mask annotations.
    """

    def __init__(self):
        super(Poly2Mask, self).__init__()
        import pycocotools.mask as maskUtils
        self.maskutils = maskUtils

    def _poly2mask(self, mask_ann, img_h, img_w):
        if isinstance(mask_ann, list):
            # polygon -- a single object might consist of multiple parts,
            # we merge all parts into one mask rle code
            rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
            rle = self.maskutils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
        else:
            # compressed RLE
            rle = mask_ann
        mask = self.maskutils.decode(rle)
        return mask

    def apply(self, sample, context=None):
        assert 'gt_poly' in sample
        im_h = sample['h']
        im_w = sample['w']
        masks = [
            self._poly2mask(gt_poly, im_h, im_w)
            for gt_poly in sample['gt_poly']
        ]
        sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
        return sample
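
# Worked example (illustrative): a single triangular polygon on a 4 x 4 image
# round-trips through pycocotools RLE into a binary uint8 mask of shape
# (num_instances, h, w).
#
#   >>> s = {'h': 4, 'w': 4, 'gt_poly': [[[0., 0., 4., 0., 0., 4.]]]}
#   >>> Poly2Mask().apply(s)['gt_segm'].shape
#   (1, 4, 4)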

@register_op
class Rbox2Poly(BaseOperator):
    """
    Convert rotated bbox (rbox) format to polygon format.
    """

    def __init__(self):
        super(Rbox2Poly, self).__init__()

    def apply(self, sample, context=None):
        assert 'gt_rbox' in sample
        assert sample['gt_rbox'].shape[1] == 5
        rrects = sample['gt_rbox']
        x_ctr = rrects[:, 0]
        y_ctr = rrects[:, 1]
        width = rrects[:, 2]
        height = rrects[:, 3]
        x1 = x_ctr - width / 2.0
        y1 = y_ctr - height / 2.0
        x2 = x_ctr + width / 2.0
        y2 = y_ctr + height / 2.0
        sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1)
        polys = bbox_utils.rbox2poly_np(rrects)
        sample['gt_rbox2poly'] = polys
        return sample
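
# Worked example (illustrative): besides the 8-point polygons produced by
# bbox_utils.rbox2poly_np, the op derives an axis-aligned gt_bbox from the
# rotated box's center and unrotated width/height; e.g. a gt_rbox row
# (cx, cy, w, h, angle) = (10, 10, 4, 2, 0.3) yields gt_bbox [8., 9., 12., 11.].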

@register_op
class AugmentHSV(BaseOperator):
    def __init__(self, fraction=0.50, is_bgr=False):
        """
        Augment the S and V channels of image data in HSV space.
        Args:
            fraction (float): the fraction used to jitter saturation/value
            is_bgr (bool): whether the image is in BGR mode
        """
        super(AugmentHSV, self).__init__()
        self.fraction = fraction
        self.is_bgr = is_bgr

    def apply(self, sample, context=None):
        img = sample['image']
        if self.is_bgr:
            img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        else:
            img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        S = img_hsv[:, :, 1].astype(np.float32)
        V = img_hsv[:, :, 2].astype(np.float32)

        a = (random.random() * 2 - 1) * self.fraction + 1
        S *= a
        if a > 1:
            np.clip(S, a_min=0, a_max=255, out=S)

        a = (random.random() * 2 - 1) * self.fraction + 1
        V *= a
        if a > 1:
            np.clip(V, a_min=0, a_max=255, out=V)

        img_hsv[:, :, 1] = S.astype(np.uint8)
        img_hsv[:, :, 2] = V.astype(np.uint8)
        if self.is_bgr:
            cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
        else:
            cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB, dst=img)
        sample['image'] = img
        return sample
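
# Worked example (illustrative): with fraction=0.5 the multiplicative jitter
# a = (random() * 2 - 1) * 0.5 + 1 lies in [0.5, 1.5]; clipping to [0, 255]
# is only needed when a > 1, since scaling S or V down cannot overflow.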

@register_op
class Norm2PixelBbox(BaseOperator):
    """
    Transform the bounding box's coordinates from normalized [0, 1] values
    to pixel values.
    """

    def __init__(self):
        super(Norm2PixelBbox, self).__init__()

    def apply(self, sample, context=None):
        assert 'gt_bbox' in sample
        bbox = sample['gt_bbox']
        height, width = sample['image'].shape[:2]
        bbox[:, 0::2] = bbox[:, 0::2] * width
        bbox[:, 1::2] = bbox[:, 1::2] * height
        sample['gt_bbox'] = bbox
        return sample

@register_op
class BboxCXCYWH2XYXY(BaseOperator):
    """
    Convert bbox CXCYWH format to XYXY format.
    [center_x, center_y, width, height] -> [x0, y0, x1, y1]
    """

    def __init__(self):
        super(BboxCXCYWH2XYXY, self).__init__()

    def apply(self, sample, context=None):
        assert 'gt_bbox' in sample
        bbox0 = sample['gt_bbox']
        bbox = bbox0.copy()
        bbox[:, :2] = bbox0[:, :2] - bbox0[:, 2:4] / 2.
        bbox[:, 2:4] = bbox0[:, :2] + bbox0[:, 2:4] / 2.
        sample['gt_bbox'] = bbox
        return sample
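
# Worked example (illustrative): the exact inverse of BboxXYXY2XYWH above;
# note this op copies the array instead of mutating it in place.
#
#   >>> s = {'gt_bbox': np.array([[2., 1., 4., 2.]])}
#   >>> BboxCXCYWH2XYXY().apply(s)['gt_bbox']
#   array([[0., 0., 4., 2.]])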