- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # function:
- # operators to process sample,
- # e.g. decode/resize/crop image
- from __future__ import absolute_import
- from __future__ import print_function
- from __future__ import division
- try:
- from collections.abc import Sequence
- except Exception:
- from collections import Sequence
- from numbers import Number, Integral
- import uuid
- import random
- import math
- import numpy as np
- import os
- import copy
- import cv2
- from PIL import Image, ImageEnhance, ImageDraw
- from paddlex.ppdet.core.workspace import serializable
- from paddlex.ppdet.modeling.layers import AnchorGrid
- from paddlex.ppdet.modeling import bbox_utils
- from .op_helper import (
- satisfy_sample_constraint, filter_and_process, generate_sample_bbox,
- clip_bbox, data_anchor_sampling, satisfy_sample_constraint_coverage,
- crop_image_sampling, generate_sample_bbox_square, bbox_area_sampling,
- is_poly, gaussian_radius, draw_gaussian)
- from paddlex.ppdet.utils.logger import setup_logger
- logger = setup_logger(__name__)
- registered_ops = []
- def register_op(cls):
- registered_ops.append(cls.__name__)
- if not hasattr(BaseOperator, cls.__name__):
- setattr(BaseOperator, cls.__name__, cls)
- else:
- raise KeyError("The {} class has been registered.".format(
- cls.__name__))
- return serializable(cls)
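- # Hypothetical usage sketch (not the library's config mechanism): registered
- # ops are serializable and are applied to a sample in sequence, e.g.
- #     ops = [Decode(), Resize(target_size=608, keep_ratio=True), Permute()]
- #     for op in ops:
- #         sample = op(sample)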
- class BboxError(ValueError):
- pass
- class ImageError(ValueError):
- pass
- class BaseOperator(object):
- def __init__(self, name=None):
- if name is None:
- name = self.__class__.__name__
- self._id = name + '_' + str(uuid.uuid4())[-6:]
- def apply(self, sample, context=None):
- """ Process a sample.
- Args:
- sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
- context (dict): info about this sample processing
- Returns:
- result (dict): a processed sample
- """
- return sample
- def __call__(self, sample, context=None):
- """ Process a sample.
- Args:
- sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
- context (dict): info about this sample processing
- Returns:
- result (dict): a processed sample
- """
- if isinstance(sample, Sequence):
- for i in range(len(sample)):
- sample[i] = self.apply(sample[i], context)
- else:
- sample = self.apply(sample, context)
- return sample
- def __str__(self):
- return str(self._id)
- @register_op
- class Decode(BaseOperator):
- def __init__(self):
- """ Transform the image data to numpy format following the rgb format
- """
- super(Decode, self).__init__()
- def apply(self, sample, context=None):
- """ load image if 'im_file' field is not empty but 'image' is"""
- if 'image' not in sample:
- with open(sample['im_file'], 'rb') as f:
- sample['image'] = f.read()
- sample.pop('im_file')
- im = sample['image']
- data = np.frombuffer(im, dtype='uint8')
- im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
- if 'keep_ori_im' in sample and sample['keep_ori_im']:
- sample['ori_image'] = im
- im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
- sample['image'] = im
- if 'h' not in sample:
- sample['h'] = im.shape[0]
- elif sample['h'] != im.shape[0]:
- logger.warning(
- "The actual image height {} does not match the height {} in the "
- "annotation; updating sample['h'] to the actual image "
- "height.".format(im.shape[0], sample['h']))
- sample['h'] = im.shape[0]
- if 'w' not in sample:
- sample['w'] = im.shape[1]
- elif sample['w'] != im.shape[1]:
- logger.warning(
- "The actual image width {} does not match the width {} in the "
- "annotation; updating sample['w'] to the actual image "
- "width.".format(im.shape[1], sample['w']))
- sample['w'] = im.shape[1]
- sample['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
- sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
- return sample
- @register_op
- class Permute(BaseOperator):
- def __init__(self):
- """
- Change the image layout from (H, W, C) to (C, H, W)
- """
- super(Permute, self).__init__()
- def apply(self, sample, context=None):
- im = sample['image']
- im = im.transpose((2, 0, 1))
- sample['image'] = im
- return sample
- @register_op
- class Lighting(BaseOperator):
- """
- Apply lighting noise to the image using PCA eigenvalues and eigenvectors
- Args:
- eigval (list): eigenvalues
- eigvec (list): eigenvectors
- alphastd (float): random weight of lighting, 0.1 by default
- """
- def __init__(self, eigval, eigvec, alphastd=0.1):
- super(Lighting, self).__init__()
- self.alphastd = alphastd
- self.eigval = np.array(eigval).astype('float32')
- self.eigvec = np.array(eigvec).astype('float32')
- def apply(self, sample, context=None):
- alpha = np.random.normal(scale=self.alphastd, size=(3, ))
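- # AlexNet-style PCA lighting noise: shift every pixel along the color
- # eigenvectors, scaled by the eigenvalues and the random alpha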
- sample['image'] += np.dot(self.eigvec, self.eigval * alpha)
- return sample
- @register_op
- class RandomErasingImage(BaseOperator):
- def __init__(self, prob=0.5, lower=0.02, higher=0.4, aspect_ratio=0.3):
- """
- Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896
- Args:
- prob (float): probability to carry out random erasing
- lower (float): lower limit of the erasing area ratio
- higher (float): upper limit of the erasing area ratio
- aspect_ratio (float): aspect ratio of the erasing region
- """
- super(RandomErasingImage, self).__init__()
- self.prob = prob
- self.lower = lower
- self.higher = higher
- self.aspect_ratio = aspect_ratio
- def apply(self, sample, context=None):
- gt_bbox = sample['gt_bbox']
- im = sample['image']
- if not isinstance(im, np.ndarray):
- raise TypeError("{}: image is not a numpy array.".format(self))
- if len(im.shape) != 3:
- raise ImageError("{}: image is not 3-dimensional.".format(self))
- for idx in range(gt_bbox.shape[0]):
- if self.prob <= np.random.rand():
- continue
- x1, y1, x2, y2 = gt_bbox[idx, :]
- w_bbox = x2 - x1
- h_bbox = y2 - y1
- area = w_bbox * h_bbox
- target_area = random.uniform(self.lower, self.higher) * area
- aspect_ratio = random.uniform(self.aspect_ratio,
- 1 / self.aspect_ratio)
- h = int(round(math.sqrt(target_area * aspect_ratio)))
- w = int(round(math.sqrt(target_area / aspect_ratio)))
- if w < w_bbox and h < h_bbox:
- off_y1 = random.randint(0, int(h_bbox - h))
- off_x1 = random.randint(0, int(w_bbox - w))
- im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1):int(
- x1 + off_x1 + w), :] = 0
- sample['image'] = im
- return sample
- @register_op
- class NormalizeImage(BaseOperator):
- def __init__(self,
- mean=[0.485, 0.456, 0.406],
- std=[1, 1, 1],
- is_scale=True):
- """
- Args:
- mean (list): the pixel mean
- std (list): the pixel standard deviation
- """
- super(NormalizeImage, self).__init__()
- self.mean = mean
- self.std = std
- self.is_scale = is_scale
- if not (isinstance(self.mean, list) and isinstance(self.std, list) and
- isinstance(self.is_scale, bool)):
- raise TypeError("{}: input type is invalid.".format(self))
- from functools import reduce
- if reduce(lambda x, y: x * y, self.std) == 0:
- raise ValueError('{}: std is invalid!'.format(self))
- def apply(self, sample, context=None):
- """Normalize the image.
- Operators:
- 1. (optional) Scale the image to [0, 1]
- 2. Subtract the mean from each pixel and divide by std
- """
- im = sample['image']
- im = im.astype(np.float32, copy=False)
- mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
- std = np.array(self.std)[np.newaxis, np.newaxis, :]
- if self.is_scale:
- im = im / 255.0
- im -= mean
- im /= std
- sample['image'] = im
- return sample
- @register_op
- class GridMask(BaseOperator):
- def __init__(self,
- use_h=True,
- use_w=True,
- rotate=1,
- offset=False,
- ratio=0.5,
- mode=1,
- prob=0.7,
- upper_iter=360000):
- """
- GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086
- Args:
- use_h (bool): whether to mask vertically
- use_w (bool): whether to mask horizontally
- rotate (float): angle for the mask to rotate
- offset (float): mask offset
- ratio (float): mask ratio
- mode (int): gridmask mode
- prob (float): max probability to carry out gridmask
- upper_iter (int): suggested to be equal to global max_iter
- """
- super(GridMask, self).__init__()
- self.use_h = use_h
- self.use_w = use_w
- self.rotate = rotate
- self.offset = offset
- self.ratio = ratio
- self.mode = mode
- self.prob = prob
- self.upper_iter = upper_iter
- from .gridmask_utils import Gridmask
- self.gridmask_op = Gridmask(
- use_h,
- use_w,
- rotate=rotate,
- offset=offset,
- ratio=ratio,
- mode=mode,
- prob=prob,
- upper_iter=upper_iter)
- def apply(self, sample, context=None):
- sample['image'] = self.gridmask_op(sample['image'],
- sample['curr_iter'])
- return sample
- @register_op
- class RandomDistort(BaseOperator):
- """Random color distortion.
- Args:
- hue (list): hue settings. in [lower, upper, probability] format.
- saturation (list): saturation settings. in [lower, upper, probability] format.
- contrast (list): contrast settings. in [lower, upper, probability] format.
- brightness (list): brightness settings. in [lower, upper, probability] format.
- random_apply (bool): whether to apply in random (YOLO) or fixed (SSD)
- order.
- count (int): the number of distortion functions to apply
- random_channel (bool): whether to swap channels randomly
- """
- def __init__(self,
- hue=[-18, 18, 0.5],
- saturation=[0.5, 1.5, 0.5],
- contrast=[0.5, 1.5, 0.5],
- brightness=[0.5, 1.5, 0.5],
- random_apply=True,
- count=4,
- random_channel=False):
- super(RandomDistort, self).__init__()
- self.hue = hue
- self.saturation = saturation
- self.contrast = contrast
- self.brightness = brightness
- self.random_apply = random_apply
- self.count = count
- self.random_channel = random_channel
- def apply_hue(self, img):
- low, high, prob = self.hue
- if np.random.uniform(0., 1.) < prob:
- return img
- img = img.astype(np.float32)
- # it works, but the result differs from the HSV version
- delta = np.random.uniform(low, high)
- u = np.cos(delta * np.pi)
- w = np.sin(delta * np.pi)
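- # hue rotation in YIQ space: RGB -> YIQ (tyiq), rotate the chroma
- # plane by delta (bt), then YIQ -> RGB (ityiq)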
- bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
- tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
- [0.211, -0.523, 0.311]])
- ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
- [1.0, -1.107, 1.705]])
- t = np.dot(np.dot(ityiq, bt), tyiq).T
- img = np.dot(img, t)
- return img
- def apply_saturation(self, img):
- low, high, prob = self.saturation
- if np.random.uniform(0., 1.) < prob:
- return img
- delta = np.random.uniform(low, high)
- img = img.astype(np.float32)
- # it works, but the result differs from the HSV version
- gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
- gray = gray.sum(axis=2, keepdims=True)
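- # blend toward the BT.601 grayscale image:
- # out = delta * img + (1 - delta) * gray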
- gray *= (1.0 - delta)
- img *= delta
- img += gray
- return img
- def apply_contrast(self, img):
- low, high, prob = self.contrast
- if np.random.uniform(0., 1.) < prob:
- return img
- delta = np.random.uniform(low, high)
- img = img.astype(np.float32)
- img *= delta
- return img
- def apply_brightness(self, img):
- low, high, prob = self.brightness
- if np.random.uniform(0., 1.) < prob:
- return img
- delta = np.random.uniform(low, high)
- img = img.astype(np.float32)
- img += delta
- return img
- def apply(self, sample, context=None):
- img = sample['image']
- if self.random_apply:
- functions = [
- self.apply_brightness, self.apply_contrast,
- self.apply_saturation, self.apply_hue
- ]
- distortions = np.random.permutation(functions)[:self.count]
- for func in distortions:
- img = func(img)
- sample['image'] = img
- return sample
- img = self.apply_brightness(img)
- mode = np.random.randint(0, 2)
- if mode:
- img = self.apply_contrast(img)
- img = self.apply_saturation(img)
- img = self.apply_hue(img)
- if not mode:
- img = self.apply_contrast(img)
- if self.random_channel:
- if np.random.randint(0, 2):
- img = img[..., np.random.permutation(3)]
- sample['image'] = img
- return sample
- @register_op
- class AutoAugment(BaseOperator):
- def __init__(self, autoaug_type="v1"):
- """
- Args:
- autoaug_type (str): autoaug type, supports v0, v1, v2, v3, test
- """
- super(AutoAugment, self).__init__()
- self.autoaug_type = autoaug_type
- def apply(self, sample, context=None):
- """
- Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172
- """
- im = sample['image']
- gt_bbox = sample['gt_bbox']
- if not isinstance(im, np.ndarray):
- raise TypeError("{}: image is not a numpy array.".format(self))
- if len(im.shape) != 3:
- raise ImageError("{}: image is not 3-dimensional.".format(self))
- if len(gt_bbox) == 0:
- return sample
- height, width, _ = im.shape
- norm_gt_bbox = np.ones_like(gt_bbox, dtype=np.float32)
- norm_gt_bbox[:, 0] = gt_bbox[:, 1] / float(height)
- norm_gt_bbox[:, 1] = gt_bbox[:, 0] / float(width)
- norm_gt_bbox[:, 2] = gt_bbox[:, 3] / float(height)
- norm_gt_bbox[:, 3] = gt_bbox[:, 2] / float(width)
- from .autoaugment_utils import distort_image_with_autoaugment
- im, norm_gt_bbox = distort_image_with_autoaugment(im, norm_gt_bbox,
- self.autoaug_type)
- gt_bbox[:, 0] = norm_gt_bbox[:, 1] * float(width)
- gt_bbox[:, 1] = norm_gt_bbox[:, 0] * float(height)
- gt_bbox[:, 2] = norm_gt_bbox[:, 3] * float(width)
- gt_bbox[:, 3] = norm_gt_bbox[:, 2] * float(height)
- sample['image'] = im
- sample['gt_bbox'] = gt_bbox
- return sample
- @register_op
- class RandomFlip(BaseOperator):
- def __init__(self, prob=0.5):
- """
- Args:
- prob (float): the probability of flipping image
- """
- super(RandomFlip, self).__init__()
- self.prob = prob
- if not (isinstance(self.prob, float)):
- raise TypeError("{}: input type is invalid.".format(self))
- def apply_segm(self, segms, height, width):
- def _flip_poly(poly, width):
- flipped_poly = np.array(poly)
- flipped_poly[0::2] = width - np.array(poly[0::2])
- return flipped_poly.tolist()
- def _flip_rle(rle, height, width):
- if 'counts' in rle and type(rle['counts']) == list:
- rle = mask_util.frPyObjects(rle, height, width)
- mask = mask_util.decode(rle)
- mask = mask[:, ::-1]
- rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
- return rle
- flipped_segms = []
- for segm in segms:
- if is_poly(segm):
- # Polygon format
- flipped_segms.append(
- [_flip_poly(poly, width) for poly in segm])
- else:
- # RLE format
- import pycocotools.mask as mask_util
- flipped_segms.append(_flip_rle(segm, height, width))
- return flipped_segms
- def apply_keypoint(self, gt_keypoint, width):
- for i in range(gt_keypoint.shape[1]):
- if i % 2 == 0:
- old_x = gt_keypoint[:, i].copy()
- gt_keypoint[:, i] = width - old_x
- return gt_keypoint
- def apply_image(self, image):
- return image[:, ::-1, :]
- def apply_bbox(self, bbox, width):
- oldx1 = bbox[:, 0].copy()
- oldx2 = bbox[:, 2].copy()
- bbox[:, 0] = width - oldx2
- bbox[:, 2] = width - oldx1
- return bbox
- def apply_rbox(self, bbox, width):
- oldx1 = bbox[:, 0].copy()
- oldx2 = bbox[:, 2].copy()
- oldx3 = bbox[:, 4].copy()
- oldx4 = bbox[:, 6].copy()
- bbox[:, 0] = width - oldx1
- bbox[:, 2] = width - oldx2
- bbox[:, 4] = width - oldx3
- bbox[:, 6] = width - oldx4
- bbox = [bbox_utils.get_best_begin_point_single(e) for e in bbox]
- return bbox
- def apply(self, sample, context=None):
- """Filp the image and bounding box.
- Operators:
- 1. Flip the image numpy.
- 2. Transform the bboxes' x coordinates.
- (Must judge whether the coordinates are normalized!)
- 3. Transform the segmentations' x coordinates.
- (Must judge whether the coordinates are normalized!)
- Output:
- sample: the image, bounding box and segmentation part
- in sample are flipped.
- """
- if np.random.uniform(0, 1) < self.prob:
- im = sample['image']
- height, width = im.shape[:2]
- im = self.apply_image(im)
- if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
- sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], width)
- if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
- sample['gt_poly'] = self.apply_segm(sample['gt_poly'], height,
- width)
- if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
- sample['gt_keypoint'] = self.apply_keypoint(
- sample['gt_keypoint'], width)
- if 'semantic' in sample and sample['semantic'] is not None:
- sample['semantic'] = sample['semantic'][:, ::-1]
- if 'gt_segm' in sample and sample['gt_segm'].any():
- sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
- if 'gt_rbox2poly' in sample and sample['gt_rbox2poly'].any():
- sample['gt_rbox2poly'] = self.apply_rbox(
- sample['gt_rbox2poly'], width)
- sample['flipped'] = True
- sample['image'] = im
- return sample
- @register_op
- class Resize(BaseOperator):
- def __init__(self, target_size, keep_ratio, interp=cv2.INTER_LINEAR):
- """
- Resize the image to target size. If keep_ratio is True,
- resize the long side to the maximum of target_size;
- if keep_ratio is False, resize the image to target_size (h, w).
- Args:
- target_size (int|list): image target size
- keep_ratio (bool): whether to keep the aspect ratio, default True
- interp (int): the interpolation method
- """
- super(Resize, self).__init__()
- self.keep_ratio = keep_ratio
- self.interp = interp
- if not isinstance(target_size, (Integral, Sequence)):
- raise TypeError(
- "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
- format(type(target_size)))
- if isinstance(target_size, Integral):
- target_size = [target_size, target_size]
- self.target_size = target_size
- def apply_image(self, image, scale):
- im_scale_x, im_scale_y = scale
- return cv2.resize(
- image,
- None,
- None,
- fx=im_scale_x,
- fy=im_scale_y,
- interpolation=self.interp)
- def apply_bbox(self, bbox, scale, size):
- im_scale_x, im_scale_y = scale
- resize_w, resize_h = size
- bbox[:, 0::2] *= im_scale_x
- bbox[:, 1::2] *= im_scale_y
- bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, resize_w)
- bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h)
- return bbox
- def apply_segm(self, segms, im_size, scale):
- def _resize_poly(poly, im_scale_x, im_scale_y):
- resized_poly = np.array(poly).astype('float32')
- resized_poly[0::2] *= im_scale_x
- resized_poly[1::2] *= im_scale_y
- return resized_poly.tolist()
- def _resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y):
- if 'counts' in rle and type(rle['counts']) == list:
- rle = mask_util.frPyObjects(rle, im_h, im_w)
- mask = mask_util.decode(rle)
- mask = cv2.resize(
- mask,
- None,
- None,
- fx=im_scale_x,
- fy=im_scale_y,
- interpolation=self.interp)
- rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
- return rle
- im_h, im_w = im_size
- im_scale_x, im_scale_y = scale
- resized_segms = []
- for segm in segms:
- if is_poly(segm):
- # Polygon format
- resized_segms.append([
- _resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
- ])
- else:
- # RLE format
- import pycocotools.mask as mask_util
- resized_segms.append(
- _resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
- return resized_segms
- def apply(self, sample, context=None):
- """ Resize the image numpy.
- """
- im = sample['image']
- if not isinstance(im, np.ndarray):
- raise TypeError("{}: image type is not numpy.".format(self))
- if len(im.shape) != 3:
- raise ImageError('{}: image is not 3-dimensional.'.format(self))
- # apply image
- im_shape = im.shape
- if self.keep_ratio:
- im_size_min = np.min(im_shape[0:2])
- im_size_max = np.max(im_shape[0:2])
- target_size_min = np.min(self.target_size)
- target_size_max = np.max(self.target_size)
- im_scale = min(target_size_min / im_size_min,
- target_size_max / im_size_max)
- resize_h = im_scale * float(im_shape[0])
- resize_w = im_scale * float(im_shape[1])
- im_scale_x = im_scale
- im_scale_y = im_scale
- else:
- resize_h, resize_w = self.target_size
- im_scale_y = resize_h / im_shape[0]
- im_scale_x = resize_w / im_shape[1]
- im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
- sample['image'] = im
- sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32)
- if 'scale_factor' in sample:
- scale_factor = sample['scale_factor']
- sample['scale_factor'] = np.asarray(
- [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
- dtype=np.float32)
- else:
- sample['scale_factor'] = np.asarray(
- [im_scale_y, im_scale_x], dtype=np.float32)
- # apply bbox
- if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
- sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'],
- [im_scale_x, im_scale_y],
- [resize_w, resize_h])
- # apply rbox
- if 'gt_rbox2poly' in sample:
- if np.array(sample['gt_rbox2poly']).shape[1] != 8:
- logger.warning(
- "gt_rbox2poly should have 8 columns, but actually has {}".
- format(np.array(sample['gt_rbox2poly']).shape[1]))
- sample['gt_rbox2poly'] = self.apply_bbox(sample['gt_rbox2poly'],
- [im_scale_x, im_scale_y],
- [resize_w, resize_h])
- # apply polygon
- if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
- sample['gt_poly'] = self.apply_segm(
- sample['gt_poly'], im_shape[:2], [im_scale_x, im_scale_y])
- # apply semantic
- if 'semantic' in sample and sample['semantic'] is not None:
- semantic = sample['semantic']
- semantic = cv2.resize(
- semantic.astype('float32'),
- None,
- None,
- fx=im_scale_x,
- fy=im_scale_y,
- interpolation=self.interp)
- semantic = np.asarray(semantic).astype('int32')
- semantic = np.expand_dims(semantic, 0)
- sample['semantic'] = semantic
- # apply gt_segm
- if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
- masks = [
- cv2.resize(
- gt_segm,
- None,
- None,
- fx=im_scale_x,
- fy=im_scale_y,
- interpolation=cv2.INTER_NEAREST)
- for gt_segm in sample['gt_segm']
- ]
- sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
- return sample
- @register_op
- class MultiscaleTestResize(BaseOperator):
- def __init__(self,
- origin_target_size=[800, 1333],
- target_size=[],
- interp=cv2.INTER_LINEAR,
- use_flip=True):
- """
- Rescale the image to each size in target_size, capped at max_size.
- Args:
- origin_target_size (list): origin target size of image
- target_size (list): A list of target sizes of image.
- interp (int): the interpolation method.
- use_flip (bool): whether use flip augmentation.
- """
- super(MultiscaleTestResize, self).__init__()
- self.interp = interp
- self.use_flip = use_flip
- if not isinstance(target_size, Sequence):
- raise TypeError(
- "Type of target_size is invalid. Must be List or Tuple, now is {}".
- format(type(target_size)))
- self.target_size = target_size
- if not isinstance(origin_target_size, Sequence):
- raise TypeError(
- "Type of origin_target_size is invalid. Must be List or Tuple, now is {}".
- format(type(origin_target_size)))
- self.origin_target_size = origin_target_size
- def apply(self, sample, context=None):
- """ Resize the image numpy for multi-scale test.
- """
- samples = []
- resizer = Resize(
- self.origin_target_size, keep_ratio=True, interp=self.interp)
- samples.append(resizer(sample.copy(), context))
- if self.use_flip:
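- # prob 1.1 > 1, so the flip is always applied for the flipped view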
- flipper = RandomFlip(1.1)
- samples.append(flipper(sample.copy(), context=context))
- for size in self.target_size:
- resizer = Resize(size, keep_ratio=True, interp=self.interp)
- samples.append(resizer(sample.copy(), context))
- return samples
- @register_op
- class RandomResize(BaseOperator):
- def __init__(self,
- target_size,
- keep_ratio=True,
- interp=cv2.INTER_LINEAR,
- random_size=True,
- random_interp=False):
- """
- Randomly resize the image: the target size and/or interpolation method may be chosen at random.
- Args:
- target_size (int, list, tuple): image target size; if random_size is True, must be a list or tuple
- keep_ratio (bool): whether to keep the aspect ratio, default True
- interp (int): the interpolation method
- random_size (bool): whether random select target size of image
- random_interp (bool): whether random select interpolation method
- """
- super(RandomResize, self).__init__()
- self.keep_ratio = keep_ratio
- self.interp = interp
- self.interps = [
- cv2.INTER_NEAREST,
- cv2.INTER_LINEAR,
- cv2.INTER_AREA,
- cv2.INTER_CUBIC,
- cv2.INTER_LANCZOS4,
- ]
- assert isinstance(target_size, (
- Integral, Sequence)), "target_size must be Integer, List or Tuple"
- if random_size and not isinstance(target_size, Sequence):
- raise TypeError(
- "Type of target_size is invalid when random_size is True. Must be List or Tuple, now is {}".
- format(type(target_size)))
- self.target_size = target_size
- self.random_size = random_size
- self.random_interp = random_interp
- def apply(self, sample, context=None):
- """ Resize the image numpy.
- """
- if self.random_size:
- target_size = random.choice(self.target_size)
- else:
- target_size = self.target_size
- if self.random_interp:
- interp = random.choice(self.interps)
- else:
- interp = self.interp
- resizer = Resize(target_size, self.keep_ratio, interp)
- return resizer(sample, context=context)
- @register_op
- class RandomExpand(BaseOperator):
- """Random expand the canvas.
- Args:
- ratio (float): maximum expansion ratio.
- prob (float): probability to expand.
- fill_value (list): color value used to fill the canvas. in RGB order.
- """
- def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)):
- super(RandomExpand, self).__init__()
- assert ratio > 1.01, "expand ratio must be larger than 1.01"
- self.ratio = ratio
- self.prob = prob
- assert isinstance(fill_value, (Number, Sequence)), \
- "fill value must be either float or sequence"
- if isinstance(fill_value, Number):
- fill_value = (fill_value, ) * 3
- if not isinstance(fill_value, tuple):
- fill_value = tuple(fill_value)
- self.fill_value = fill_value
- def apply(self, sample, context=None):
- if np.random.uniform(0., 1.) < self.prob:
- return sample
- im = sample['image']
- height, width = im.shape[:2]
- ratio = np.random.uniform(1., self.ratio)
- h = int(height * ratio)
- w = int(width * ratio)
- if not h > height or not w > width:
- return sample
- y = np.random.randint(0, h - height)
- x = np.random.randint(0, w - width)
- offsets, size = [x, y], [h, w]
- pad = Pad(size,
- pad_mode=-1,
- offsets=offsets,
- fill_value=self.fill_value)
- return pad(sample, context=context)
- @register_op
- class CropWithSampling(BaseOperator):
- def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True):
- """
- Args:
- batch_sampler (list): Multiple sets of different
- parameters for cropping.
- satisfy_all (bool): whether all boxes must satisfy.
- e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
- [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
- [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
- [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
- [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
- [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
- [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
- [max sample, max trial, min scale, max scale,
- min aspect ratio, max aspect ratio,
- min overlap, max overlap]
- avoid_no_bbox (bool): whether to avoid the
- situation where no box appears in the crop.
- """
- super(CropWithSampling, self).__init__()
- self.batch_sampler = batch_sampler
- self.satisfy_all = satisfy_all
- self.avoid_no_bbox = avoid_no_bbox
- def apply(self, sample, context):
- """
- Crop the image and modify bounding box.
- Operators:
- 1. Scale the image width and height.
- 2. Crop the image according to a random sample.
- 3. Rescale the bounding box.
- 4. Determine if the new bbox is satisfied in the new image.
- Returns:
- sample: the image and bounding boxes are replaced.
- """
- assert 'image' in sample, "image data not found"
- im = sample['image']
- gt_bbox = sample['gt_bbox']
- gt_class = sample['gt_class']
- im_height, im_width = im.shape[:2]
- gt_score = None
- if 'gt_score' in sample:
- gt_score = sample['gt_score']
- sampled_bbox = []
- gt_bbox = gt_bbox.tolist()
- for sampler in self.batch_sampler:
- found = 0
- for i in range(sampler[1]):
- if found >= sampler[0]:
- break
- sample_bbox = generate_sample_bbox(sampler)
- if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox,
- self.satisfy_all):
- sampled_bbox.append(sample_bbox)
- found = found + 1
- im = np.array(im)
- while sampled_bbox:
- idx = int(np.random.uniform(0, len(sampled_bbox)))
- sample_bbox = sampled_bbox.pop(idx)
- sample_bbox = clip_bbox(sample_bbox)
- crop_bbox, crop_class, crop_score = \
- filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score)
- if self.avoid_no_bbox:
- if len(crop_bbox) < 1:
- continue
- xmin = int(sample_bbox[0] * im_width)
- xmax = int(sample_bbox[2] * im_width)
- ymin = int(sample_bbox[1] * im_height)
- ymax = int(sample_bbox[3] * im_height)
- im = im[ymin:ymax, xmin:xmax]
- sample['image'] = im
- sample['gt_bbox'] = crop_bbox
- sample['gt_class'] = crop_class
- sample['gt_score'] = crop_score
- return sample
- return sample
- @register_op
- class CropWithDataAchorSampling(BaseOperator):
- def __init__(self,
- batch_sampler,
- anchor_sampler=None,
- target_size=None,
- das_anchor_scales=[16, 32, 64, 128],
- sampling_prob=0.5,
- min_size=8.,
- avoid_no_bbox=True):
- """
- Args:
- anchor_sampler (list): anchor_sampling sets of different
- parameters for cropping.
- batch_sampler (list): Multiple sets of different
- parameters for cropping.
- e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
- [[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
- [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
- [max sample, max trial, min scale, max scale,
- min aspect ratio, max aspect ratio,
- min overlap, max overlap, min coverage, max coverage]
- target_size (int): target image size.
- das_anchor_scales (list[float]): a list of anchor scales in data
- anchor sampling.
- min_size (float): minimum size of sampled bbox.
- avoid_no_bbox (bool): whether to avoid the
- situation where no box appears in the crop.
- """
- super(CropWithDataAchorSampling, self).__init__()
- self.anchor_sampler = anchor_sampler
- self.batch_sampler = batch_sampler
- self.target_size = target_size
- self.sampling_prob = sampling_prob
- self.min_size = min_size
- self.avoid_no_bbox = avoid_no_bbox
- self.das_anchor_scales = np.array(das_anchor_scales)
- def apply(self, sample, context):
- """
- Crop the image and modify bounding box.
- Operators:
- 1. Scale the image width and height.
- 2. Crop the image according to a random sample.
- 3. Rescale the bounding box.
- 4. Determine if the new bbox is satisfied in the new image.
- Returns:
- sample: the image and bounding boxes are replaced.
- """
- assert 'image' in sample, "image data not found"
- im = sample['image']
- gt_bbox = sample['gt_bbox']
- gt_class = sample['gt_class']
- image_height, image_width = im.shape[:2]
- gt_bbox[:, 0] /= image_width
- gt_bbox[:, 1] /= image_height
- gt_bbox[:, 2] /= image_width
- gt_bbox[:, 3] /= image_height
- gt_score = None
- if 'gt_score' in sample:
- gt_score = sample['gt_score']
- sampled_bbox = []
- gt_bbox = gt_bbox.tolist()
- prob = np.random.uniform(0., 1.)
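- # with probability (1 - sampling_prob) run data-anchor sampling;
- # otherwise fall back to the square bbox batch samplers below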
- if prob > self.sampling_prob: # anchor sampling
- assert self.anchor_sampler
- for sampler in self.anchor_sampler:
- found = 0
- for i in range(sampler[1]):
- if found >= sampler[0]:
- break
- sample_bbox = data_anchor_sampling(
- gt_bbox, image_width, image_height,
- self.das_anchor_scales, self.target_size)
- if sample_bbox == 0:
- break
- if satisfy_sample_constraint_coverage(sampler, sample_bbox,
- gt_bbox):
- sampled_bbox.append(sample_bbox)
- found = found + 1
- im = np.array(im)
- while sampled_bbox:
- idx = int(np.random.uniform(0, len(sampled_bbox)))
- sample_bbox = sampled_bbox.pop(idx)
- if 'gt_keypoint' in sample.keys():
- keypoints = (sample['gt_keypoint'],
- sample['keypoint_ignore'])
- crop_bbox, crop_class, crop_score, gt_keypoints = \
- filter_and_process(sample_bbox, gt_bbox, gt_class,
- scores=gt_score,
- keypoints=keypoints)
- else:
- crop_bbox, crop_class, crop_score = filter_and_process(
- sample_bbox, gt_bbox, gt_class, scores=gt_score)
- crop_bbox, crop_class, crop_score = bbox_area_sampling(
- crop_bbox, crop_class, crop_score, self.target_size,
- self.min_size)
- if self.avoid_no_bbox:
- if len(crop_bbox) < 1:
- continue
- im = crop_image_sampling(im, sample_bbox, image_width,
- image_height, self.target_size)
- height, width = im.shape[:2]
- crop_bbox[:, 0] *= width
- crop_bbox[:, 1] *= height
- crop_bbox[:, 2] *= width
- crop_bbox[:, 3] *= height
- sample['image'] = im
- sample['gt_bbox'] = crop_bbox
- sample['gt_class'] = crop_class
- if 'gt_score' in sample:
- sample['gt_score'] = crop_score
- if 'gt_keypoint' in sample.keys():
- sample['gt_keypoint'] = gt_keypoints[0]
- sample['keypoint_ignore'] = gt_keypoints[1]
- return sample
- return sample
- else:
- for sampler in self.batch_sampler:
- found = 0
- for i in range(sampler[1]):
- if found >= sampler[0]:
- break
- sample_bbox = generate_sample_bbox_square(
- sampler, image_width, image_height)
- if satisfy_sample_constraint_coverage(sampler, sample_bbox,
- gt_bbox):
- sampled_bbox.append(sample_bbox)
- found = found + 1
- im = np.array(im)
- while sampled_bbox:
- idx = int(np.random.uniform(0, len(sampled_bbox)))
- sample_bbox = sampled_bbox.pop(idx)
- sample_bbox = clip_bbox(sample_bbox)
- if 'gt_keypoint' in sample.keys():
- keypoints = (sample['gt_keypoint'],
- sample['keypoint_ignore'])
- crop_bbox, crop_class, crop_score, gt_keypoints = \
- filter_and_process(sample_bbox, gt_bbox, gt_class,
- scores=gt_score,
- keypoints=keypoints)
- else:
- crop_bbox, crop_class, crop_score = filter_and_process(
- sample_bbox, gt_bbox, gt_class, scores=gt_score)
- # sampling bbox according the bbox area
- crop_bbox, crop_class, crop_score = bbox_area_sampling(
- crop_bbox, crop_class, crop_score, self.target_size,
- self.min_size)
- if self.avoid_no_bbox:
- if len(crop_bbox) < 1:
- continue
- xmin = int(sample_bbox[0] * image_width)
- xmax = int(sample_bbox[2] * image_width)
- ymin = int(sample_bbox[1] * image_height)
- ymax = int(sample_bbox[3] * image_height)
- im = im[ymin:ymax, xmin:xmax]
- height, width = im.shape[:2]
- crop_bbox[:, 0] *= width
- crop_bbox[:, 1] *= height
- crop_bbox[:, 2] *= width
- crop_bbox[:, 3] *= height
- sample['image'] = im
- sample['gt_bbox'] = crop_bbox
- sample['gt_class'] = crop_class
- if 'gt_score' in sample:
- sample['gt_score'] = crop_score
- if 'gt_keypoint' in sample.keys():
- sample['gt_keypoint'] = gt_keypoints[0]
- sample['keypoint_ignore'] = gt_keypoints[1]
- return sample
- return sample
- @register_op
- class RandomCrop(BaseOperator):
- """Random crop image and bboxes.
- Args:
- aspect_ratio (list): aspect ratio of cropped region.
- in [min, max] format.
- thresholds (list): IoU thresholds to decide a valid bbox crop.
- scaling (list): ratio between a cropped region and the original image.
- in [min, max] format.
- num_attempts (int): number of tries before giving up.
- allow_no_crop (bool): allow returning the sample without actually cropping.
- cover_all_box (bool): ensure all bboxes are covered in the final crop.
- is_mask_crop(bool): whether crop the segmentation.
- """
- def __init__(self,
- aspect_ratio=[.5, 2.],
- thresholds=[.0, .1, .3, .5, .7, .9],
- scaling=[.3, 1.],
- num_attempts=50,
- allow_no_crop=True,
- cover_all_box=False,
- is_mask_crop=False):
- super(RandomCrop, self).__init__()
- self.aspect_ratio = aspect_ratio
- self.thresholds = thresholds
- self.scaling = scaling
- self.num_attempts = num_attempts
- self.allow_no_crop = allow_no_crop
- self.cover_all_box = cover_all_box
- self.is_mask_crop = is_mask_crop
- def crop_segms(self, segms, valid_ids, crop, height, width):
- def _crop_poly(segm, crop):
- xmin, ymin, xmax, ymax = crop
- crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]
- crop_p = np.array(crop_coord).reshape(4, 2)
- crop_p = Polygon(crop_p)
- crop_segm = list()
- for poly in segm:
- poly = np.array(poly).reshape(len(poly) // 2, 2)
- polygon = Polygon(poly)
- if not polygon.is_valid:
- exterior = polygon.exterior
- multi_lines = exterior.intersection(exterior)
- polygons = shapely.ops.polygonize(multi_lines)
- polygon = MultiPolygon(polygons)
- multi_polygon = list()
- if isinstance(polygon, MultiPolygon):
- multi_polygon = copy.deepcopy(polygon)
- else:
- multi_polygon.append(copy.deepcopy(polygon))
- for per_polygon in multi_polygon:
- inter = per_polygon.intersection(crop_p)
- if not inter:
- continue
- if isinstance(inter, (MultiPolygon, GeometryCollection)):
- for part in inter:
- if not isinstance(part, Polygon):
- continue
- part = np.squeeze(
- np.array(part.exterior.coords[:-1]).reshape(
- 1, -1))
- part[0::2] -= xmin
- part[1::2] -= ymin
- crop_segm.append(part.tolist())
- elif isinstance(inter, Polygon):
- crop_poly = np.squeeze(
- np.array(inter.exterior.coords[:-1]).reshape(1,
- -1))
- crop_poly[0::2] -= xmin
- crop_poly[1::2] -= ymin
- crop_segm.append(crop_poly.tolist())
- else:
- continue
- return crop_segm
- def _crop_rle(rle, crop, height, width):
- if 'counts' in rle and type(rle['counts']) == list:
- rle = mask_util.frPyObjects(rle, height, width)
- mask = mask_util.decode(rle)
- mask = mask[crop[1]:crop[3], crop[0]:crop[2]]
- rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
- return rle
- crop_segms = []
- for id in valid_ids:
- segm = segms[id]
- if is_poly(segm):
- import copy
- import logging
- import shapely.ops
- from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
- logging.getLogger("shapely").setLevel(logging.WARNING)
- # Polygon format
- crop_segms.append(_crop_poly(segm, crop))
- else:
- # RLE format
- import pycocotools.mask as mask_util
- crop_segms.append(_crop_rle(segm, crop, height, width))
- return crop_segms
- def apply(self, sample, context=None):
- if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
- return sample
- h, w = sample['image'].shape[:2]
- gt_bbox = sample['gt_bbox']
- # NOTE Original method attempts to generate one candidate for each
- # threshold then randomly sample one from the resulting list.
- # Here a short circuit approach is taken, i.e., randomly choose a
- # threshold and attempt to find a valid crop, and simply return the
- # first one found.
- # The probability is not exactly the same, kinda resembling the
- # "Monty Hall" problem. Actually carrying out the attempts will affect
- # observability (just like opening doors in the "Monty Hall" game).
- thresholds = list(self.thresholds)
- if self.allow_no_crop:
- thresholds.append('no_crop')
- np.random.shuffle(thresholds)
- for thresh in thresholds:
- if thresh == 'no_crop':
- return sample
- found = False
- for i in range(self.num_attempts):
- scale = np.random.uniform(*self.scaling)
- if self.aspect_ratio is not None:
- min_ar, max_ar = self.aspect_ratio
- aspect_ratio = np.random.uniform(
- max(min_ar, scale**2), min(max_ar, scale**-2))
- h_scale = scale / np.sqrt(aspect_ratio)
- w_scale = scale * np.sqrt(aspect_ratio)
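- # h_scale * w_scale == scale**2; clamping aspect_ratio to
- # [scale**2, scale**-2] keeps both factors <= 1 so the crop fits
- # inside the image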
- else:
- h_scale = np.random.uniform(*self.scaling)
- w_scale = np.random.uniform(*self.scaling)
- crop_h = h * h_scale
- crop_w = w * w_scale
- if self.aspect_ratio is None:
- if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
- continue
- crop_h = int(crop_h)
- crop_w = int(crop_w)
- crop_y = np.random.randint(0, h - crop_h)
- crop_x = np.random.randint(0, w - crop_w)
- crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
- iou = self._iou_matrix(
- gt_bbox, np.array(
- [crop_box], dtype=np.float32))
- if iou.max() < thresh:
- continue
- if self.cover_all_box and iou.min() < thresh:
- continue
- cropped_box, valid_ids = self._crop_box_with_center_constraint(
- gt_bbox, np.array(
- crop_box, dtype=np.float32))
- if valid_ids.size > 0:
- found = True
- break
- if found:
- if self.is_mask_crop and 'gt_poly' in sample and len(sample[
- 'gt_poly']) > 0:
- crop_polys = self.crop_segms(
- sample['gt_poly'],
- valid_ids,
- np.array(
- crop_box, dtype=np.int64),
- h,
- w)
- if [] in crop_polys:
- delete_id = list()
- valid_polys = list()
- for id, crop_poly in enumerate(crop_polys):
- if crop_poly == []:
- delete_id.append(id)
- else:
- valid_polys.append(crop_poly)
- valid_ids = np.delete(valid_ids, delete_id)
- if len(valid_polys) == 0:
- return sample
- sample['gt_poly'] = valid_polys
- else:
- sample['gt_poly'] = crop_polys
- if 'gt_segm' in sample:
- sample['gt_segm'] = self._crop_segm(sample['gt_segm'],
- crop_box)
- sample['gt_segm'] = np.take(
- sample['gt_segm'], valid_ids, axis=0)
- sample['image'] = self._crop_image(sample['image'], crop_box)
- sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
- sample['gt_class'] = np.take(
- sample['gt_class'], valid_ids, axis=0)
- if 'gt_score' in sample:
- sample['gt_score'] = np.take(
- sample['gt_score'], valid_ids, axis=0)
- if 'is_crowd' in sample:
- sample['is_crowd'] = np.take(
- sample['is_crowd'], valid_ids, axis=0)
- return sample
- return sample
- def _iou_matrix(self, a, b):
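- # pairwise IoU between gt boxes a (N, 4) and crop boxes b (M, 4),
- # both in xyxy format; returns an (N, M) matrix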
- tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
- br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
- area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
- area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
- area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
- area_o = (area_a[:, np.newaxis] + area_b - area_i)
- return area_i / (area_o + 1e-10)
- def _crop_box_with_center_constraint(self, box, crop):
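- # clip boxes to the crop window and shift them into crop coordinates;
- # a box stays valid only if its center lies inside the crop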
- cropped_box = box.copy()
- cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
- cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
- cropped_box[:, :2] -= crop[:2]
- cropped_box[:, 2:] -= crop[:2]
- centers = (box[:, :2] + box[:, 2:]) / 2
- valid = np.logical_and(crop[:2] <= centers,
- centers < crop[2:]).all(axis=1)
- valid = np.logical_and(
- valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
- return cropped_box, np.where(valid)[0]
- def _crop_image(self, img, crop):
- x1, y1, x2, y2 = crop
- return img[y1:y2, x1:x2, :]
- def _crop_segm(self, segm, crop):
- x1, y1, x2, y2 = crop
- return segm[:, y1:y2, x1:x2]
- @register_op
- class RandomScaledCrop(BaseOperator):
- """Resize image and bbox based on long side (with optional random scaling),
- then crop or pad image to target size.
- Args:
- target_dim (int): target size.
- scale_range (list): random scale range.
- interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
- """
- def __init__(self,
- target_dim=512,
- scale_range=[.1, 2.],
- interp=cv2.INTER_LINEAR):
- super(RandomScaledCrop, self).__init__()
- self.target_dim = target_dim
- self.scale_range = scale_range
- self.interp = interp
- def apply(self, sample, context=None):
- img = sample['image']
- h, w = img.shape[:2]
- random_scale = np.random.uniform(*self.scale_range)
- dim = self.target_dim
- random_dim = int(dim * random_scale)
- dim_max = max(h, w)
- scale = random_dim / dim_max
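- # resize so the long side equals random_dim, then crop/pad the
- # result onto a dim x dim canvas below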
- resize_w = int(round(w * scale))
- resize_h = int(round(h * scale))
- offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
- offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
- img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
- img = np.array(img)
- canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
- canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
- offset_y:offset_y + dim, offset_x:offset_x + dim, :]
- sample['image'] = canvas
- sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32)
- scale_factor = sample['scale_factor']
- sample['scale_factor'] = np.asarray(
- [scale_factor[0] * scale, scale_factor[1] * scale],
- dtype=np.float32)
- if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
- scale_array = np.array([scale, scale] * 2, dtype=np.float32)
- shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
- boxes = sample['gt_bbox'] * scale_array - shift_array
- boxes = np.clip(boxes, 0, dim - 1)
- # filter boxes with no area
- area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
- valid = (area > 1.).nonzero()[0]
- sample['gt_bbox'] = boxes[valid]
- sample['gt_class'] = sample['gt_class'][valid]
- return sample
- @register_op
- class Cutmix(BaseOperator):
- def __init__(self, alpha=1.5, beta=1.5):
- """
- CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, see https://arxiv.org/abs/1905.04899
- Cutmix image and gt_bbox/gt_score
- Args:
- alpha (float): alpha parameter of the Beta distribution
- beta (float): beta parameter of the Beta distribution
- """
- super(Cutmix, self).__init__()
- self.alpha = alpha
- self.beta = beta
- if self.alpha <= 0.0:
- raise ValueError("alpha shold be positive in {}".format(self))
- if self.beta <= 0.0:
- raise ValueError("beta shold be positive in {}".format(self))
- def apply_image(self, img1, img2, factor):
- """ _rand_bbox """
- h = max(img1.shape[0], img2.shape[0])
- w = max(img1.shape[1], img2.shape[1])
- cut_rat = np.sqrt(1. - factor)
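- # the pasted region covers (1 - factor) of the area, so each side
- # is scaled by sqrt(1 - factor)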
- cut_w = int(w * cut_rat)
- cut_h = int(h * cut_rat)
- # uniform
- cx = np.random.randint(w)
- cy = np.random.randint(h)
- bbx1 = np.clip(cx - cut_w // 2, 0, w - 1)
- bby1 = np.clip(cy - cut_h // 2, 0, h - 1)
- bbx2 = np.clip(cx + cut_w // 2, 0, w - 1)
- bby2 = np.clip(cy + cut_h // 2, 0, h - 1)
- img_1_pad = np.zeros((h, w, img1.shape[2]), 'float32')
- img_1_pad[:img1.shape[0], :img1.shape[1], :] = \
- img1.astype('float32')
- img_2_pad = np.zeros((h, w, img2.shape[2]), 'float32')
- img_2_pad[:img2.shape[0], :img2.shape[1], :] = \
- img2.astype('float32')
- img_1_pad[bby1:bby2, bbx1:bbx2, :] = img_2_pad[bby1:bby2, bbx1:bbx2, :]
- return img_1_pad
- def __call__(self, sample, context=None):
- if not isinstance(sample, Sequence):
- return sample
- assert len(sample) == 2, 'cutmix needs two samples'
- factor = np.random.beta(self.alpha, self.beta)
- factor = max(0.0, min(1.0, factor))
- if factor >= 1.0:
- return sample[0]
- if factor <= 0.0:
- return sample[1]
- img1 = sample[0]['image']
- img2 = sample[1]['image']
- img = self.apply_image(img1, img2, factor)
- gt_bbox1 = sample[0]['gt_bbox']
- gt_bbox2 = sample[1]['gt_bbox']
- gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
- gt_class1 = sample[0]['gt_class']
- gt_class2 = sample[1]['gt_class']
- gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
- gt_score1 = np.ones_like(sample[0]['gt_class'])
- gt_score2 = np.ones_like(sample[1]['gt_class'])
- gt_score = np.concatenate(
- (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
- result = copy.deepcopy(sample[0])
- result['image'] = img
- result['gt_bbox'] = gt_bbox
- result['gt_score'] = gt_score
- result['gt_class'] = gt_class
- if 'is_crowd' in sample[0]:
- is_crowd1 = sample[0]['is_crowd']
- is_crowd2 = sample[1]['is_crowd']
- is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
- result['is_crowd'] = is_crowd
- if 'difficult' in sample[0]:
- is_difficult1 = sample[0]['difficult']
- is_difficult2 = sample[1]['difficult']
- is_difficult = np.concatenate(
- (is_difficult1, is_difficult2), axis=0)
- result['difficult'] = is_difficult
- return result
- @register_op
- class Mixup(BaseOperator):
- def __init__(self, alpha=1.5, beta=1.5):
- """ Mixup image and gt_bbbox/gt_score
- Args:
- alpha (float): alpha parameter of beta distribute
- beta (float): beta parameter of beta distribute
- """
- super(Mixup, self).__init__()
- self.alpha = alpha
- self.beta = beta
- if self.alpha <= 0.0:
- raise ValueError("alpha shold be positive in {}".format(self))
- if self.beta <= 0.0:
- raise ValueError("beta shold be positive in {}".format(self))
- def apply_image(self, img1, img2, factor):
- h = max(img1.shape[0], img2.shape[0])
- w = max(img1.shape[1], img2.shape[1])
- img = np.zeros((h, w, img1.shape[2]), 'float32')
- img[:img1.shape[0], :img1.shape[1], :] = \
- img1.astype('float32') * factor
- img[:img2.shape[0], :img2.shape[1], :] += \
- img2.astype('float32') * (1.0 - factor)
- return img.astype('uint8')
- def __call__(self, sample, context=None):
- if not isinstance(sample, Sequence):
- return sample
- assert len(sample) == 2, 'mixup needs two samples'
- factor = np.random.beta(self.alpha, self.beta)
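- # factor ~ Beta(alpha, beta); it weights image 1 vs image 2, and the
- # gt_score of each source is scaled by its mixing weight below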
- factor = max(0.0, min(1.0, factor))
- if factor >= 1.0:
- return sample[0]
- if factor <= 0.0:
- return sample[1]
- im = self.apply_image(sample[0]['image'], sample[1]['image'], factor)
- result = copy.deepcopy(sample[0])
- result['image'] = im
- # apply bbox and score
- if 'gt_bbox' in sample[0]:
- gt_bbox1 = sample[0]['gt_bbox']
- gt_bbox2 = sample[1]['gt_bbox']
- gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
- result['gt_bbox'] = gt_bbox
- if 'gt_class' in sample[0]:
- gt_class1 = sample[0]['gt_class']
- gt_class2 = sample[1]['gt_class']
- gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
- result['gt_class'] = gt_class
- gt_score1 = np.ones_like(sample[0]['gt_class'])
- gt_score2 = np.ones_like(sample[1]['gt_class'])
- gt_score = np.concatenate(
- (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
- result['gt_score'] = gt_score
- if 'is_crowd' in sample[0]:
- is_crowd1 = sample[0]['is_crowd']
- is_crowd2 = sample[1]['is_crowd']
- is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
- result['is_crowd'] = is_crowd
- if 'difficult' in sample[0]:
- is_difficult1 = sample[0]['difficult']
- is_difficult2 = sample[1]['difficult']
- is_difficult = np.concatenate(
- (is_difficult1, is_difficult2), axis=0)
- result['difficult'] = is_difficult
- if 'gt_ide' in sample[0]:
- gt_ide1 = sample[0]['gt_ide']
- gt_ide2 = sample[1]['gt_ide']
- gt_ide = np.concatenate((gt_ide1, gt_ide2), axis=0)
- result['gt_ide'] = gt_ide
- return result
- @register_op
- class NormalizeBox(BaseOperator):
- """Transform the bounding box's coornidates to [0,1]."""
- def __init__(self):
- super(NormalizeBox, self).__init__()
- def apply(self, sample, context):
- im = sample['image']
- gt_bbox = sample['gt_bbox']
- height, width, _ = im.shape
- for i in range(gt_bbox.shape[0]):
- gt_bbox[i][0] = gt_bbox[i][0] / width
- gt_bbox[i][1] = gt_bbox[i][1] / height
- gt_bbox[i][2] = gt_bbox[i][2] / width
- gt_bbox[i][3] = gt_bbox[i][3] / height
- sample['gt_bbox'] = gt_bbox
- if 'gt_keypoint' in sample.keys():
- gt_keypoint = sample['gt_keypoint']
- for i in range(gt_keypoint.shape[1]):
- if i % 2:
- gt_keypoint[:, i] = gt_keypoint[:, i] / height
- else:
- gt_keypoint[:, i] = gt_keypoint[:, i] / width
- sample['gt_keypoint'] = gt_keypoint
- return sample
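- # Editor's note: a small hedged sketch of NormalizeBox; `_demo_normalize_box`
- # is a hypothetical helper for illustration. x coordinates are divided by the
- # image width, y coordinates by the image height.
- def _demo_normalize_box():
-     import numpy as np
-     sample = {
-         'image': np.zeros((100, 200, 3), dtype=np.uint8),
-         'gt_bbox': np.array([[20., 10., 120., 60.]], dtype=np.float32),
-     }
-     out = NormalizeBox().apply(sample, context=None)
-     # [20/200, 10/100, 120/200, 60/100]
-     assert np.allclose(out['gt_bbox'], [[0.1, 0.1, 0.6, 0.6]])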
- @register_op
- class BboxXYXY2XYWH(BaseOperator):
- """
- Convert bbox XYXY format to XYWH format.
- """
- def __init__(self):
- super(BboxXYXY2XYWH, self).__init__()
- def apply(self, sample, context=None):
- assert 'gt_bbox' in sample
- bbox = sample['gt_bbox']
- bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
- bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
- sample['gt_bbox'] = bbox
- return sample
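- # Editor's note: hedged worked example for the conversion above; note that
- # the result is center format [cx, cy, w, h], computed in place.
- # `_demo_xyxy2xywh` is a hypothetical helper added for illustration.
- def _demo_xyxy2xywh():
-     import numpy as np
-     sample = {'gt_bbox': np.array([[10., 20., 30., 60.]], dtype=np.float32)}
-     out = BboxXYXY2XYWH().apply(sample)
-     # w = 30 - 10 = 20, h = 60 - 20 = 40, cx = 10 + 20/2 = 20, cy = 20 + 40/2 = 40
-     assert np.allclose(out['gt_bbox'], [[20., 40., 20., 40.]])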
- @register_op
- class PadBox(BaseOperator):
- def __init__(self, num_max_boxes=50):
- """
- Pad zeros to bboxes if number of bboxes is less than num_max_boxes.
- Args:
- num_max_boxes (int): the max number of bboxes
- """
- self.num_max_boxes = num_max_boxes
- super(PadBox, self).__init__()
- def apply(self, sample, context=None):
- assert 'gt_bbox' in sample
- bbox = sample['gt_bbox']
- gt_num = min(self.num_max_boxes, len(bbox))
- num_max = self.num_max_boxes
- # fields = context['fields'] if context else []
- pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
- if gt_num > 0:
- pad_bbox[:gt_num, :] = bbox[:gt_num, :]
- sample['gt_bbox'] = pad_bbox
- if 'gt_class' in sample:
- pad_class = np.zeros((num_max, ), dtype=np.int32)
- if gt_num > 0:
- pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
- sample['gt_class'] = pad_class
- if 'gt_score' in sample:
- pad_score = np.zeros((num_max, ), dtype=np.float32)
- if gt_num > 0:
- pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
- sample['gt_score'] = pad_score
- # in training, ops such as ExpandImage expand the bbox and gt_class
- # but not the difficult field, so judge by its own length
- if 'difficult' in sample:
- pad_diff = np.zeros((num_max, ), dtype=np.int32)
- if gt_num > 0:
- pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
- sample['difficult'] = pad_diff
- if 'is_crowd' in sample:
- pad_crowd = np.zeros((num_max, ), dtype=np.int32)
- if gt_num > 0:
- pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
- sample['is_crowd'] = pad_crowd
- if 'gt_ide' in sample:
- pad_ide = np.zeros((num_max, ), dtype=np.int32)
- if gt_num > 0:
- pad_ide[:gt_num] = sample['gt_ide'][:gt_num, 0]
- sample['gt_ide'] = pad_ide
- return sample
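- # Editor's note: hedged sketch of PadBox; two boxes are padded to
- # num_max_boxes rows, and the (N, 1) gt_class column is flattened to
- # (num_max_boxes,). `_demo_pad_box` is a hypothetical helper.
- def _demo_pad_box():
-     import numpy as np
-     sample = {
-         'gt_bbox': np.random.rand(2, 4).astype(np.float32),
-         'gt_class': np.array([[1], [4]], dtype=np.int32),
-     }
-     out = PadBox(num_max_boxes=50).apply(sample)
-     assert out['gt_bbox'].shape == (50, 4)
-     assert out['gt_class'].shape == (50, ) and out['gt_class'][2:].sum() == 0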
- @register_op
- class DebugVisibleImage(BaseOperator):
- """
- In debug mode, visualize images according to `gt_box`.
- (Currently only supported when not cropping and flipping image.)
- """
- def __init__(self, output_dir='output/debug', is_normalized=False):
- super(DebugVisibleImage, self).__init__()
- self.is_normalized = is_normalized
- self.output_dir = output_dir
- if not os.path.isdir(output_dir):
- os.makedirs(output_dir)
- if not isinstance(self.is_normalized, bool):
- raise TypeError("{}: input type is invalid.".format(self))
- def apply(self, sample, context=None):
- image = Image.open(sample['im_file']).convert('RGB')
- out_file_name = os.path.basename(sample['im_file'])
- width = sample['w']
- height = sample['h']
- gt_bbox = sample['gt_bbox']
- gt_class = sample['gt_class']
- draw = ImageDraw.Draw(image)
- for i in range(gt_bbox.shape[0]):
- if self.is_normalized:
- gt_bbox[i][0] = gt_bbox[i][0] * width
- gt_bbox[i][1] = gt_bbox[i][1] * height
- gt_bbox[i][2] = gt_bbox[i][2] * width
- gt_bbox[i][3] = gt_bbox[i][3] * height
- xmin, ymin, xmax, ymax = gt_bbox[i]
- draw.line(
- [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
- (xmin, ymin)],
- width=2,
- fill='green')
- # draw label
- text = str(gt_class[i][0])
- # Pillow >= 10 removed ImageDraw.textsize; fall back to textbbox
- if hasattr(draw, 'textsize'):
- tw, th = draw.textsize(text)
- else:
- left, top, right, bottom = draw.textbbox((0, 0), text)
- tw, th = right - left, bottom - top
- draw.rectangle(
- [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green')
- draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
- if 'gt_keypoint' in sample.keys():
- gt_keypoint = sample['gt_keypoint']
- if self.is_normalized:
- for i in range(gt_keypoint.shape[1]):
- if i % 2:
- gt_keypoint[:, i] = gt_keypoint[:, i] * height
- else:
- gt_keypoint[:, i] = gt_keypoint[:, i] * width
- for i in range(gt_keypoint.shape[0]):
- keypoint = gt_keypoint[i]
- for j in range(int(keypoint.shape[0] / 2)):
- x1 = int(round(keypoint[2 * j]))
- y1 = int(round(keypoint[2 * j + 1]))
- draw.ellipse(
- (x1, y1, x1 + 5, y1 + 5),
- fill='green',
- outline='green')
- save_path = os.path.join(self.output_dir, out_file_name)
- image.save(save_path, quality=95)
- return sample
- @register_op
- class Pad(BaseOperator):
- def __init__(self,
- size=None,
- size_divisor=32,
- pad_mode=0,
- offsets=None,
- fill_value=(127.5, 127.5, 127.5)):
- """
- Pad image to a specified size or multiple of size_divisor.
- Args:
- size (int, Sequence): image target size; if None, pad to a multiple of size_divisor. Default None
- size_divisor (int): size divisor, default 32
- pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets;
- if 0, only pad to right and bottom; if 1, pad according to center; if 2, only pad left and top
- offsets (list): [offset_x, offset_y], specifies the offset while padding, only supported when pad_mode=-1
- fill_value (tuple): RGB value of the padded area, default (127.5, 127.5, 127.5)
- """
- super(Pad, self).__init__()
- if size is not None and not isinstance(size, (int, Sequence)):
- raise TypeError(
- "Type of size is invalid. Must be int or Sequence, "
- "now is {}".format(type(size)))
- if isinstance(size, int):
- size = [size, size]
- assert pad_mode in [
- -1, 0, 1, 2
- ], 'currently only supports four modes [-1, 0, 1, 2]'
- if pad_mode == -1:
- assert offsets, 'if pad_mode is -1, offsets should not be None'
- self.size = size
- self.size_divisor = size_divisor
- self.pad_mode = pad_mode
- self.fill_value = fill_value
- self.offsets = offsets
- def apply_segm(self, segms, offsets, im_size, size):
- import pycocotools.mask as mask_util
- def _expand_poly(poly, x, y):
- expanded_poly = np.array(poly)
- expanded_poly[0::2] += x
- expanded_poly[1::2] += y
- return expanded_poly.tolist()
- def _expand_rle(rle, x, y, height, width, h, w):
- if 'counts' in rle and isinstance(rle['counts'], list):
- rle = mask_util.frPyObjects(rle, height, width)
- mask = mask_util.decode(rle)
- expanded_mask = np.full((h, w), 0).astype(mask.dtype)
- expanded_mask[y:y + height, x:x + width] = mask
- rle = mask_util.encode(
- np.array(
- expanded_mask, order='F', dtype=np.uint8))
- return rle
- x, y = offsets
- height, width = im_size
- h, w = size
- expanded_segms = []
- for segm in segms:
- if is_poly(segm):
- # Polygon format
- expanded_segms.append(
- [_expand_poly(poly, x, y) for poly in segm])
- else:
- # RLE format
- expanded_segms.append(
- _expand_rle(segm, x, y, height, width, h, w))
- return expanded_segms
- def apply_bbox(self, bbox, offsets):
- return bbox + np.array(offsets * 2, dtype=np.float32)
- def apply_keypoint(self, keypoints, offsets):
- n = len(keypoints[0]) // 2
- return keypoints + np.array(offsets * n, dtype=np.float32)
- def apply_image(self, image, offsets, im_size, size):
- x, y = offsets
- im_h, im_w = im_size
- h, w = size
- canvas = np.ones((h, w, 3), dtype=np.float32)
- canvas *= np.array(self.fill_value, dtype=np.float32)
- canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
- return canvas
- def apply(self, sample, context=None):
- im = sample['image']
- im_h, im_w = im.shape[:2]
- if self.size:
- h, w = self.size
- assert (
- im_h <= h and im_w <= w
- ), '(h, w) of target size should be no smaller than (im_h, im_w)'
- else:
- h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
- w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
- if h == im_h and w == im_w:
- return sample
- if self.pad_mode == -1:
- offset_x, offset_y = self.offsets
- elif self.pad_mode == 0:
- offset_y, offset_x = 0, 0
- elif self.pad_mode == 1:
- offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
- else:
- offset_y, offset_x = h - im_h, w - im_w
- offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
- sample['image'] = self.apply_image(im, offsets, im_size, size)
- if self.pad_mode == 0:
- return sample
- if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
- sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
- if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
- sample['gt_poly'] = self.apply_segm(sample['gt_poly'], offsets,
- im_size, size)
- if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
- sample['gt_keypoint'] = self.apply_keypoint(sample['gt_keypoint'],
- offsets)
- return sample
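- # Editor's note: hedged sketch of Pad with size_divisor rounding; a 100x200
- # image is padded (bottom/right, pad_mode=0) to 128x224, since
- # ceil(100/32)*32 = 128 and ceil(200/32)*32 = 224. `_demo_pad` is a
- # hypothetical helper added for illustration.
- def _demo_pad():
-     import numpy as np
-     sample = {'image': np.zeros((100, 200, 3), dtype=np.uint8)}
-     out = Pad(size=None, size_divisor=32, pad_mode=0).apply(sample)
-     assert out['image'].shape == (128, 224, 3)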
- @register_op
- class Poly2Mask(BaseOperator):
- """
- gt poly to mask annotations
- """
- def __init__(self):
- super(Poly2Mask, self).__init__()
- import pycocotools.mask as maskUtils
- self.maskutils = maskUtils
- def _poly2mask(self, mask_ann, img_h, img_w):
- if isinstance(mask_ann, list):
- # polygon -- a single object might consist of multiple parts
- # we merge all parts into one mask rle code
- rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
- rle = self.maskutils.merge(rles)
- elif isinstance(mask_ann['counts'], list):
- # uncompressed RLE
- rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
- else:
- # rle
- rle = mask_ann
- mask = self.maskutils.decode(rle)
- return mask
- def apply(self, sample, context=None):
- assert 'gt_poly' in sample
- im_h = sample['h']
- im_w = sample['w']
- masks = [
- self._poly2mask(gt_poly, im_h, im_w)
- for gt_poly in sample['gt_poly']
- ]
- sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
- return sample
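- # Editor's note: hedged sketch of Poly2Mask (requires pycocotools). A 2x2
- # axis-aligned square polygon rasterizes to a small binary mask region
- # (about 4 foreground pixels). `_demo_poly2mask` is a hypothetical helper.
- def _demo_poly2mask():
-     sample = {
-         'h': 8,
-         'w': 8,
-         'gt_poly': [[[2., 2., 4., 2., 4., 4., 2., 4.]]],
-     }
-     out = Poly2Mask().apply(sample)
-     assert out['gt_segm'].shape == (1, 8, 8)
-     assert out['gt_segm'].sum() > 0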
- @register_op
- class Rbox2Poly(BaseOperator):
- """
- Convert rbbox format to poly format.
- """
- def __init__(self):
- super(Rbox2Poly, self).__init__()
- def apply(self, sample, context=None):
- assert 'gt_rbox' in sample
- assert sample['gt_rbox'].shape[1] == 5
- rrects = sample['gt_rbox']
- x_ctr = rrects[:, 0]
- y_ctr = rrects[:, 1]
- width = rrects[:, 2]
- height = rrects[:, 3]
- x1 = x_ctr - width / 2.0
- y1 = y_ctr - height / 2.0
- x2 = x_ctr + width / 2.0
- y2 = y_ctr + height / 2.0
- sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1)
- polys = bbox_utils.rbox2poly_np(rrects)
- sample['gt_rbox2poly'] = polys
- return sample
- @register_op
- class AugmentHSV(BaseOperator):
- def __init__(self, fraction=0.50, is_bgr=False):
- """
- Augment the SV channel of image data.
- Args:
- fraction (float): the fraction for augment
- is_bgr (bool): whether the image is BGR mode
- """
- super(AugmentHSV, self).__init__()
- self.fraction = fraction
- self.is_bgr = is_bgr
- def apply(self, sample, context=None):
- img = sample['image']
- if self.is_bgr:
- img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
- else:
- img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
- S = img_hsv[:, :, 1].astype(np.float32)
- V = img_hsv[:, :, 2].astype(np.float32)
- a = (random.random() * 2 - 1) * self.fraction + 1
- S *= a
- if a > 1:
- np.clip(S, a_min=0, a_max=255, out=S)
- a = (random.random() * 2 - 1) * self.fraction + 1
- V *= a
- if a > 1:
- np.clip(V, a_min=0, a_max=255, out=V)
- img_hsv[:, :, 1] = S.astype(np.uint8)
- img_hsv[:, :, 2] = V.astype(np.uint8)
- if self.is_bgr:
- cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
- else:
- cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB, dst=img)
- sample['image'] = img
- return sample
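- # Editor's note: hedged sketch of the AugmentHSV gain range; with
- # fraction=0.5 the random gain a lies in [0.5, 1.5], so S and V are scaled
- # by up to +/-50% (clipped to 255 when a > 1). `_demo_augment_hsv` is a
- # hypothetical helper added for illustration.
- def _demo_augment_hsv():
-     import numpy as np
-     sample = {'image': np.random.randint(0, 255, (16, 16, 3), dtype=np.uint8)}
-     out = AugmentHSV(fraction=0.5, is_bgr=False).apply(sample)
-     assert out['image'].shape == (16, 16, 3) and out['image'].dtype == np.uint8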
- @register_op
- class Norm2PixelBbox(BaseOperator):
- """
- Transform the bounding box's coornidates which is in [0,1] to pixels.
- """
- def __init__(self):
- super(Norm2PixelBbox, self).__init__()
- def apply(self, sample, context=None):
- assert 'gt_bbox' in sample
- bbox = sample['gt_bbox']
- height, width = sample['image'].shape[:2]
- bbox[:, 0::2] = bbox[:, 0::2] * width
- bbox[:, 1::2] = bbox[:, 1::2] * height
- sample['gt_bbox'] = bbox
- return sample
- @register_op
- class BboxCXCYWH2XYXY(BaseOperator):
- """
- Convert bbox CXCYWH format to XYXY format.
- [center_x, center_y, width, height] -> [x0, y0, x1, y1]
- """
- def __init__(self):
- super(BboxCXCYWH2XYXY, self).__init__()
- def apply(self, sample, context=None):
- assert 'gt_bbox' in sample
- bbox0 = sample['gt_bbox']
- bbox = bbox0.copy()
- bbox[:, :2] = bbox0[:, :2] - bbox0[:, 2:4] / 2.
- bbox[:, 2:4] = bbox0[:, :2] + bbox0[:, 2:4] / 2.
- sample['gt_bbox'] = bbox
- return sample
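- # Editor's note: hedged round-trip sketch; BboxCXCYWH2XYXY inverts the
- # BboxXYXY2XYWH conversion defined earlier in this file.
- # `_demo_bbox_roundtrip` is a hypothetical helper added for illustration.
- def _demo_bbox_roundtrip():
-     import numpy as np
-     xyxy = np.array([[10., 20., 30., 60.]], dtype=np.float32)
-     sample = {'gt_bbox': xyxy.copy()}
-     sample = BboxXYXY2XYWH().apply(sample)      # -> [[20., 40., 20., 40.]]
-     sample = BboxCXCYWH2XYXY().apply(sample)    # back to corner format
-     assert np.allclose(sample['gt_bbox'], xyxy)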