|
|
@@ -24,11 +24,20 @@ import numpy as np
|
|
|
import cv2
|
|
|
from PIL import Image, ImageEnhance
|
|
|
|
|
|
+from .imgaug_support import execute_imgaug
|
|
|
from .ops import *
|
|
|
from .box_utils import *
|
|
|
|
|
|
|
|
|
-class Compose:
|
|
|
+class DetTransform:
|
|
|
+ """检测数据处理基类
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class Compose(DetTransform):
|
|
|
"""根据数据预处理/增强列表对输入数据进行操作。
|
|
|
所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
|
|
|
|
|
|
@@ -49,8 +58,16 @@ class Compose:
|
|
|
self.transforms = transforms
|
|
|
self.use_mixup = False
|
|
|
for t in self.transforms:
|
|
|
- if t.__class__.__name__ == 'MixupImage':
|
|
|
+ if type(t).__name__ == 'MixupImage':
|
|
|
self.use_mixup = True
|
|
|
+ # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
|
|
|
+ for op in self.transforms:
|
|
|
+ if not isinstance(op, DetTransform):
|
|
|
+ import imgaug.augmenters as iaa
|
|
|
+ if not isinstance(op, iaa.Augmenter):
|
|
|
+ raise Exception(
|
|
|
+ "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
|
|
|
+ )
|
|
|
|
|
|
def __call__(self, im, im_info=None, label_info=None):
|
|
|
"""
|
|
|
@@ -58,8 +75,8 @@ class Compose:
|
|
|
im (str/np.ndarray): 图像路径/图像np.ndarray数据。
|
|
|
im_info (dict): 存储与图像相关的信息,dict中的字段如下:
|
|
|
- im_id (np.ndarray): 图像序列号,形状为(1,)。
|
|
|
- - origin_shape (np.ndarray): 图像原始大小,形状为(2,),
|
|
|
- origin_shape[0]为高,origin_shape[1]为宽。
|
|
|
+ - image_shape (np.ndarray): 图像原始大小,形状为(2,),
|
|
|
+ image_shape[0]为高,image_shape[1]为宽。
|
|
|
- mixup (list): list为[im, im_info, label_info],分别对应
|
|
|
与当前图像进行mixup的图像np.ndarray数据、图像相关信息、标注框相关信息;
|
|
|
注意,当前epoch若无需进行mixup,则无该字段。
|
|
|
@@ -84,18 +101,24 @@ class Compose:
|
|
|
def decode_image(im_file, im_info, label_info):
|
|
|
if im_info is None:
|
|
|
im_info = dict()
|
|
|
- try:
|
|
|
- im = cv2.imread(im_file).astype('float32')
|
|
|
- except:
|
|
|
- raise TypeError(
|
|
|
- 'Can\'t read The image file {}!'.format(im_file))
|
|
|
+ if isinstance(im_file, np.ndarray):
|
|
|
+ if len(im_file.shape) != 3:
|
|
|
+ raise Exception(
|
|
|
+ "im should be 3-dimensions, but now is {}-dimensions".
|
|
|
+ format(len(im_file.shape)))
|
|
|
+ im = im_file
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+ im = cv2.imread(im_file).astype('float32')
|
|
|
+ except:
|
|
|
+ raise TypeError(
|
|
|
+ 'Can\'t read The image file {}!'.format(im_file))
|
|
|
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
|
|
# make default im_info with [h, w, 1]
|
|
|
im_info['im_resize_info'] = np.array(
|
|
|
[im.shape[0], im.shape[1], 1.], dtype=np.float32)
|
|
|
- # copy augment_shape from origin_shape
|
|
|
- im_info['augment_shape'] = np.array([im.shape[0],
|
|
|
- im.shape[1]]).astype('int32')
|
|
|
+ im_info['image_shape'] = np.array([im.shape[0],
|
|
|
+ im.shape[1]]).astype('int32')
|
|
|
if not self.use_mixup:
|
|
|
if 'mixup' in im_info:
|
|
|
del im_info['mixup']
|
|
|
@@ -118,12 +141,28 @@ class Compose:
|
|
|
for op in self.transforms:
|
|
|
if im is None:
|
|
|
return None
|
|
|
- outputs = op(im, im_info, label_info)
|
|
|
- im = outputs[0]
|
|
|
+ if isinstance(op, DetTransform):
|
|
|
+ outputs = op(im, im_info, label_info)
|
|
|
+ im = outputs[0]
|
|
|
+ else:
|
|
|
+ if label_info is not None:
|
|
|
+ gt_poly = label_info.get('gt_poly', None)
|
|
|
+ gt_bbox = label_info['gt_bbox']
|
|
|
+ if gt_poly is None:
|
|
|
+ im, aug_bbox = execute_imgaug(op, im, bboxes=gt_bbox)
|
|
|
+ else:
|
|
|
+ im, aug_bbox, aug_poly = execute_imgaug(
|
|
|
+ op, im, bboxes=gt_bbox, polygons=gt_poly)
|
|
|
+ label_info['gt_poly'] = aug_poly
|
|
|
+ label_info['gt_bbox'] = aug_bbox
|
|
|
+ outputs = (im, im_info, label_info)
|
|
|
+ else:
|
|
|
+ im, = execute_imgaug(op, im)
|
|
|
+ outputs = (im, im_info)
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
-class ResizeByShort:
|
|
|
+class ResizeByShort(DetTransform):
|
|
|
"""根据图像的短边调整图像大小(resize)。
|
|
|
|
|
|
1. 获取图像的长边和短边长度。
|
|
|
@@ -195,12 +234,17 @@ class ResizeByShort:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class Padding:
|
|
|
- """将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
|
|
|
+class Padding(DetTransform):
|
|
|
+ """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
|
|
|
`coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
|
|
|
进行padding,最终输出图像为[320, 640]。
|
|
|
+ 2.或者,将图像的长和宽padding到target_size指定的shape,如输入的图像为[300,640],
|
|
|
+ a. `target_size` = 960,在图像最右和最下使用0值进行padding,最终输出
|
|
|
+ 图像为[960, 960]。
|
|
|
+ b. `target_size` = [640, 960],在图像最右和最下使用0值进行padding,最终
|
|
|
+ 输出图像为[640, 960]。
|
|
|
|
|
|
- 1. 如果coarsest_stride为1则直接返回。
|
|
|
+ 1. 如果coarsest_stride为1,target_size为None则直接返回。
|
|
|
2. 获取图像的高H、宽W。
|
|
|
3. 计算填充后图像的高H_new、宽W_new。
|
|
|
4. 构建大小为(H_new, W_new, 3)像素值为0的np.ndarray,
|
|
|
@@ -208,10 +252,26 @@ class Padding:
|
|
|
|
|
|
Args:
|
|
|
coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。
|
|
|
+ target_size (int|list|tuple): 填充后的图像长、宽,默认为None,coarset_stride优先级更高。
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ TypeError: 形参`target_size`数据类型不满足需求。
|
|
|
+ ValueError: 形参`target_size`为(list|tuple)时,长度不满足需求。
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, coarsest_stride=1):
|
|
|
+ def __init__(self, coarsest_stride=1, target_size=None):
|
|
|
self.coarsest_stride = coarsest_stride
|
|
|
+ if target_size is not None:
|
|
|
+ if not isinstance(target_size, int):
|
|
|
+ if not isinstance(target_size, tuple) and not isinstance(
|
|
|
+ target_size, list):
|
|
|
+ raise TypeError(
|
|
|
+ "Padding: Type of target_size must in (int|list|tuple)."
|
|
|
+ )
|
|
|
+ elif len(target_size) != 2:
|
|
|
+ raise ValueError(
|
|
|
+ "Padding: Length of target_size must equal 2.")
|
|
|
+ self.target_size = target_size
|
|
|
|
|
|
def __call__(self, im, im_info=None, label_info=None):
|
|
|
"""
|
|
|
@@ -228,13 +288,9 @@ class Padding:
|
|
|
Raises:
|
|
|
TypeError: 形参数据类型不满足需求。
|
|
|
ValueError: 数据长度不匹配。
|
|
|
+ ValueError: coarsest_stride,target_size需有且只有一个被指定。
|
|
|
+ ValueError: target_size小于原图的大小。
|
|
|
"""
|
|
|
-
|
|
|
- if self.coarsest_stride == 1:
|
|
|
- if label_info is None:
|
|
|
- return (im, im_info)
|
|
|
- else:
|
|
|
- return (im, im_info, label_info)
|
|
|
if im_info is None:
|
|
|
im_info = dict()
|
|
|
if not isinstance(im, np.ndarray):
|
|
|
@@ -242,11 +298,29 @@ class Padding:
|
|
|
if len(im.shape) != 3:
|
|
|
raise ValueError('Padding: image is not 3-dimensional.')
|
|
|
im_h, im_w, im_c = im.shape[:]
|
|
|
- if self.coarsest_stride > 1:
|
|
|
+
|
|
|
+ if isinstance(self.target_size, int):
|
|
|
+ padding_im_h = self.target_size
|
|
|
+ padding_im_w = self.target_size
|
|
|
+ elif isinstance(self.target_size, list) or isinstance(
|
|
|
+ self.target_size, tuple):
|
|
|
+ padding_im_w = self.target_size[0]
|
|
|
+ padding_im_h = self.target_size[1]
|
|
|
+ elif self.coarsest_stride > 0:
|
|
|
padding_im_h = int(
|
|
|
np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
|
|
|
padding_im_w = int(
|
|
|
np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ "coarsest_stridei(>1) or target_size(list|int) need setting in Padding transform"
|
|
|
+ )
|
|
|
+ pad_height = padding_im_h - im_h
|
|
|
+ pad_width = padding_im_w - im_w
|
|
|
+ if pad_height < 0 or pad_width < 0:
|
|
|
+ raise ValueError(
|
|
|
+ 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
|
|
|
+ .format(im_w, im_h, padding_im_w, padding_im_h))
|
|
|
padding_im = np.zeros((padding_im_h, padding_im_w, im_c),
|
|
|
dtype=np.float32)
|
|
|
padding_im[:im_h, :im_w, :] = im
|
|
|
@@ -256,7 +330,7 @@ class Padding:
|
|
|
return (padding_im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class Resize:
|
|
|
+class Resize(DetTransform):
|
|
|
"""调整图像大小(resize)。
|
|
|
|
|
|
- 当目标大小(target_size)类型为int时,根据插值方式,
|
|
|
@@ -335,7 +409,7 @@ class Resize:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class RandomHorizontalFlip:
|
|
|
+class RandomHorizontalFlip(DetTransform):
|
|
|
"""随机翻转图像、标注框、分割信息,模型训练时的数据增强操作。
|
|
|
|
|
|
1. 随机采样一个0-1之间的小数,当小数小于水平翻转概率时,
|
|
|
@@ -387,16 +461,13 @@ class RandomHorizontalFlip:
|
|
|
raise TypeError(
|
|
|
'Cannot do RandomHorizontalFlip! ' +
|
|
|
'Becasuse the im_info and label_info can not be None!')
|
|
|
- if 'augment_shape' not in im_info:
|
|
|
- raise TypeError('Cannot do RandomHorizontalFlip! ' + \
|
|
|
- 'Becasuse augment_shape is not in im_info!')
|
|
|
if 'gt_bbox' not in label_info:
|
|
|
raise TypeError('Cannot do RandomHorizontalFlip! ' + \
|
|
|
'Becasuse gt_bbox is not in label_info!')
|
|
|
- augment_shape = im_info['augment_shape']
|
|
|
+ image_shape = im_info['image_shape']
|
|
|
gt_bbox = label_info['gt_bbox']
|
|
|
- height = augment_shape[0]
|
|
|
- width = augment_shape[1]
|
|
|
+ height = image_shape[0]
|
|
|
+ width = image_shape[1]
|
|
|
|
|
|
if np.random.uniform(0, 1) < self.prob:
|
|
|
im = horizontal_flip(im)
|
|
|
@@ -416,7 +487,7 @@ class RandomHorizontalFlip:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class Normalize:
|
|
|
+class Normalize(DetTransform):
|
|
|
"""对图像进行标准化。
|
|
|
|
|
|
1. 归一化图像到到区间[0.0, 1.0]。
|
|
|
@@ -460,7 +531,7 @@ class Normalize:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class RandomDistort:
|
|
|
+class RandomDistort(DetTransform):
|
|
|
"""以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
|
|
|
|
|
|
1. 对变换的操作顺序进行随机化操作。
|
|
|
@@ -545,7 +616,7 @@ class RandomDistort:
|
|
|
params = params_dict[ops[id].__name__]
|
|
|
prob = prob_dict[ops[id].__name__]
|
|
|
params['im'] = im
|
|
|
-
|
|
|
+
|
|
|
if np.random.uniform(0, 1) < prob:
|
|
|
im = ops[id](**params)
|
|
|
if label_info is None:
|
|
|
@@ -554,7 +625,7 @@ class RandomDistort:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class MixupImage:
|
|
|
+class MixupImage(DetTransform):
|
|
|
"""对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform。
|
|
|
|
|
|
当label_info中不存在mixup字段时,直接返回,否则进行下述操作:
|
|
|
@@ -567,7 +638,7 @@ class MixupImage:
|
|
|
(2)拼接原图像标注框和mixup图像标注框。
|
|
|
(3)拼接原图像标注框类别和mixup图像标注框类别。
|
|
|
(4)原图像标注框混合得分乘以factor,mixup图像标注框混合得分乘以(1-factor),叠加2个结果。
|
|
|
- 3. 更新im_info中的augment_shape信息。
|
|
|
+ 3. 更新im_info中的image_shape信息。
|
|
|
|
|
|
Args:
|
|
|
alpha (float): 随机beta分布的下限。默认为1.5。
|
|
|
@@ -610,7 +681,7 @@ class MixupImage:
|
|
|
当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
|
|
|
存储与标注框相关信息的字典。
|
|
|
其中,im_info更新字段为:
|
|
|
- - augment_shape (np.ndarray): mixup后的图像高、宽二者组成的np.ndarray,形状为(2,)。
|
|
|
+ - image_shape (np.ndarray): mixup后的图像高、宽二者组成的np.ndarray,形状为(2,)。
|
|
|
im_info删除的字段:
|
|
|
- mixup (list): 与当前字段进行mixup的图像相关信息。
|
|
|
label_info更新字段为:
|
|
|
@@ -674,8 +745,8 @@ class MixupImage:
|
|
|
label_info['gt_score'] = gt_score
|
|
|
label_info['gt_class'] = gt_class
|
|
|
label_info['is_crowd'] = is_crowd
|
|
|
- im_info['augment_shape'] = np.array([im.shape[0],
|
|
|
- im.shape[1]]).astype('int32')
|
|
|
+ im_info['image_shape'] = np.array([im.shape[0],
|
|
|
+ im.shape[1]]).astype('int32')
|
|
|
im_info.pop('mixup')
|
|
|
if label_info is None:
|
|
|
return (im, im_info)
|
|
|
@@ -683,7 +754,7 @@ class MixupImage:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class RandomExpand:
|
|
|
+class RandomExpand(DetTransform):
|
|
|
"""随机扩张图像,模型训练时的数据增强操作。
|
|
|
1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
|
|
|
2. 计算扩张后图像大小。
|
|
|
@@ -721,7 +792,7 @@ class RandomExpand:
|
|
|
当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
|
|
|
存储与标注框相关信息的字典。
|
|
|
其中,im_info更新字段为:
|
|
|
- - augment_shape (np.ndarray): 扩张后的图像高、宽二者组成的np.ndarray,形状为(2,)。
|
|
|
+ - image_shape (np.ndarray): 扩张后的图像高、宽二者组成的np.ndarray,形状为(2,)。
|
|
|
label_info更新字段为:
|
|
|
- gt_bbox (np.ndarray): 随机扩张后真实标注框坐标,形状为(n, 4),
|
|
|
其中n代表真实标注框的个数。
|
|
|
@@ -734,9 +805,6 @@ class RandomExpand:
|
|
|
raise TypeError(
|
|
|
'Cannot do RandomExpand! ' +
|
|
|
'Becasuse the im_info and label_info can not be None!')
|
|
|
- if 'augment_shape' not in im_info:
|
|
|
- raise TypeError('Cannot do RandomExpand! ' + \
|
|
|
- 'Becasuse augment_shape is not in im_info!')
|
|
|
if 'gt_bbox' not in label_info or \
|
|
|
'gt_class' not in label_info:
|
|
|
raise TypeError('Cannot do RandomExpand! ' + \
|
|
|
@@ -744,9 +812,9 @@ class RandomExpand:
|
|
|
if np.random.uniform(0., 1.) < self.prob:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
- augment_shape = im_info['augment_shape']
|
|
|
- height = int(augment_shape[0])
|
|
|
- width = int(augment_shape[1])
|
|
|
+ image_shape = im_info['image_shape']
|
|
|
+ height = int(image_shape[0])
|
|
|
+ width = int(image_shape[1])
|
|
|
|
|
|
expand_ratio = np.random.uniform(1., self.ratio)
|
|
|
h = int(height * expand_ratio)
|
|
|
@@ -759,7 +827,7 @@ class RandomExpand:
|
|
|
canvas *= np.array(self.fill_value, dtype=np.float32)
|
|
|
canvas[y:y + height, x:x + width, :] = im
|
|
|
|
|
|
- im_info['augment_shape'] = np.array([h, w]).astype('int32')
|
|
|
+ im_info['image_shape'] = np.array([h, w]).astype('int32')
|
|
|
if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
|
|
|
label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
|
|
|
if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
|
|
|
@@ -768,7 +836,7 @@ class RandomExpand:
|
|
|
return (canvas, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class RandomCrop:
|
|
|
+class RandomCrop(DetTransform):
|
|
|
"""随机裁剪图像。
|
|
|
1. 若allow_no_crop为True,则在thresholds加入’no_crop’。
|
|
|
2. 随机打乱thresholds。
|
|
|
@@ -815,12 +883,14 @@ class RandomCrop:
|
|
|
tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
|
|
|
当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
|
|
|
存储与标注框相关信息的字典。
|
|
|
- 其中,label_info更新字段为:
|
|
|
- - gt_bbox (np.ndarray): 随机裁剪后真实标注框坐标,形状为(n, 4),
|
|
|
+ 其中,im_info更新字段为:
|
|
|
+ - image_shape (np.ndarray): 扩裁剪的图像高、宽二者组成的np.ndarray,形状为(2,)。
|
|
|
+ label_info更新字段为:
|
|
|
+ - gt_bbox (np.ndarray): 随机裁剪后真实标注框坐标,形状为(n, 4),
|
|
|
其中n代表真实标注框的个数。
|
|
|
- - gt_class (np.ndarray): 随机裁剪后每个真实标注框对应的类别序号,形状为(n, 1),
|
|
|
+ - gt_class (np.ndarray): 随机裁剪后每个真实标注框对应的类别序号,形状为(n, 1),
|
|
|
其中n代表真实标注框的个数。
|
|
|
- - gt_score (np.ndarray): 随机裁剪后每个真实标注框对应的混合得分,形状为(n, 1),
|
|
|
+ - gt_score (np.ndarray): 随机裁剪后每个真实标注框对应的混合得分,形状为(n, 1),
|
|
|
其中n代表真实标注框的个数。
|
|
|
|
|
|
Raises:
|
|
|
@@ -830,9 +900,6 @@ class RandomCrop:
|
|
|
raise TypeError(
|
|
|
'Cannot do RandomCrop! ' +
|
|
|
'Becasuse the im_info and label_info can not be None!')
|
|
|
- if 'augment_shape' not in im_info:
|
|
|
- raise TypeError('Cannot do RandomCrop! ' + \
|
|
|
- 'Becasuse augment_shape is not in im_info!')
|
|
|
if 'gt_bbox' not in label_info or \
|
|
|
'gt_class' not in label_info:
|
|
|
raise TypeError('Cannot do RandomCrop! ' + \
|
|
|
@@ -841,9 +908,9 @@ class RandomCrop:
|
|
|
if len(label_info['gt_bbox']) == 0:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
- augment_shape = im_info['augment_shape']
|
|
|
- w = augment_shape[1]
|
|
|
- h = augment_shape[0]
|
|
|
+ image_shape = im_info['image_shape']
|
|
|
+ w = image_shape[1]
|
|
|
+ h = image_shape[0]
|
|
|
gt_bbox = label_info['gt_bbox']
|
|
|
thresholds = list(self.thresholds)
|
|
|
if self.allow_no_crop:
|
|
|
@@ -902,7 +969,7 @@ class RandomCrop:
|
|
|
label_info['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
|
|
|
label_info['gt_class'] = np.take(
|
|
|
label_info['gt_class'], valid_ids, axis=0)
|
|
|
- im_info['augment_shape'] = np.array(
|
|
|
+ im_info['image_shape'] = np.array(
|
|
|
[crop_box[3] - crop_box[1],
|
|
|
crop_box[2] - crop_box[0]]).astype('int32')
|
|
|
if 'gt_score' in label_info:
|
|
|
@@ -917,7 +984,7 @@ class RandomCrop:
|
|
|
return (im, im_info, label_info)
|
|
|
|
|
|
|
|
|
-class ArrangeFasterRCNN:
|
|
|
+class ArrangeFasterRCNN(DetTransform):
|
|
|
"""获取FasterRCNN模型训练/验证/预测所需信息。
|
|
|
|
|
|
Args:
|
|
|
@@ -973,7 +1040,7 @@ class ArrangeFasterRCNN:
|
|
|
im_resize_info = im_info['im_resize_info']
|
|
|
im_id = im_info['im_id']
|
|
|
im_shape = np.array(
|
|
|
- (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
|
|
|
+ (im_info['image_shape'][0], im_info['image_shape'][1], 1),
|
|
|
dtype=np.float32)
|
|
|
gt_bbox = label_info['gt_bbox']
|
|
|
gt_class = label_info['gt_class']
|
|
|
@@ -986,13 +1053,13 @@ class ArrangeFasterRCNN:
|
|
|
'Becasuse the im_info can not be None!')
|
|
|
im_resize_info = im_info['im_resize_info']
|
|
|
im_shape = np.array(
|
|
|
- (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
|
|
|
+ (im_info['image_shape'][0], im_info['image_shape'][1], 1),
|
|
|
dtype=np.float32)
|
|
|
outputs = (im, im_resize_info, im_shape)
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
-class ArrangeMaskRCNN:
|
|
|
+class ArrangeMaskRCNN(DetTransform):
|
|
|
"""获取MaskRCNN模型训练/验证/预测所需信息。
|
|
|
|
|
|
Args:
|
|
|
@@ -1066,7 +1133,7 @@ class ArrangeMaskRCNN:
|
|
|
'Becasuse the im_info can not be None!')
|
|
|
im_resize_info = im_info['im_resize_info']
|
|
|
im_shape = np.array(
|
|
|
- (im_info['augment_shape'][0], im_info['augment_shape'][1], 1),
|
|
|
+ (im_info['image_shape'][0], im_info['image_shape'][1], 1),
|
|
|
dtype=np.float32)
|
|
|
if self.mode == 'eval':
|
|
|
im_id = im_info['im_id']
|
|
|
@@ -1076,7 +1143,7 @@ class ArrangeMaskRCNN:
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
-class ArrangeYOLOv3:
|
|
|
+class ArrangeYOLOv3(DetTransform):
|
|
|
"""获取YOLOv3模型训练/验证/预测所需信息。
|
|
|
|
|
|
Args:
|
|
|
@@ -1117,7 +1184,7 @@ class ArrangeYOLOv3:
|
|
|
raise TypeError(
|
|
|
'Cannot do ArrangeYolov3! ' +
|
|
|
'Becasuse the im_info and label_info can not be None!')
|
|
|
- im_shape = im_info['augment_shape']
|
|
|
+ im_shape = im_info['image_shape']
|
|
|
if len(label_info['gt_bbox']) != len(label_info['gt_class']):
|
|
|
raise ValueError("gt num mismatch: bbox and class.")
|
|
|
if len(label_info['gt_bbox']) != len(label_info['gt_score']):
|
|
|
@@ -1141,7 +1208,7 @@ class ArrangeYOLOv3:
|
|
|
raise TypeError(
|
|
|
'Cannot do ArrangeYolov3! ' +
|
|
|
'Becasuse the im_info and label_info can not be None!')
|
|
|
- im_shape = im_info['augment_shape']
|
|
|
+ im_shape = im_info['image_shape']
|
|
|
if len(label_info['gt_bbox']) != len(label_info['gt_class']):
|
|
|
raise ValueError("gt num mismatch: bbox and class.")
|
|
|
im_id = im_info['im_id']
|
|
|
@@ -1160,6 +1227,6 @@ class ArrangeYOLOv3:
|
|
|
if im_info is None:
|
|
|
raise TypeError('Cannot do ArrangeYolov3! ' +
|
|
|
'Becasuse the im_info can not be None!')
|
|
|
- im_shape = im_info['augment_shape']
|
|
|
+ im_shape = im_info['image_shape']
|
|
|
outputs = (im, im_shape)
|
|
|
return outputs
|