|
|
@@ -46,7 +46,7 @@ class Compose(SegTransform):
|
|
|
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, transforms, vdl_save_dir=None):
|
|
|
+ def __init__(self, transforms):
|
|
|
if not isinstance(transforms, list):
|
|
|
raise TypeError('The transforms must be a list!')
|
|
|
if len(transforms) < 1:
|
|
|
@@ -62,24 +62,8 @@ class Compose(SegTransform):
|
|
|
raise Exception(
|
|
|
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
|
|
|
)
|
|
|
- self.images_writer = None
|
|
|
-
|
|
|
- def set_vdl(self, vdl_save_dir=None):
|
|
|
- # 对数据预处理结果在VisualDL中可视化
|
|
|
- self.images_writer = None
|
|
|
- if vdl_save_dir is not None:
|
|
|
- if not osp.isdir(vdl_save_dir):
|
|
|
- if osp.exists(vdl_save_dir):
|
|
|
- os.remove(vdl_save_dir)
|
|
|
- os.makedirs(vdl_save_dir)
|
|
|
- from visualdl import LogWriter
|
|
|
- vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
|
|
|
- self.images_writer = LogWriter(vdl_images_dir)
|
|
|
-
|
|
|
- def release_vdl(self):
|
|
|
- self.images_writer = None
|
|
|
-
|
|
|
- def __call__(self, im, im_info=None, label=None, step=0):
|
|
|
+
|
|
|
+ def __call__(self, im, im_info=None, label=None, images_writer=None, step=0):
|
|
|
"""
|
|
|
Args:
|
|
|
im (str/np.ndarray): 图像路径/图像np.ndarray数据。
|
|
|
@@ -92,7 +76,6 @@ class Compose(SegTransform):
|
|
|
Returns:
|
|
|
tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
|
|
|
"""
|
|
|
- im_file = str(step)
|
|
|
if im_info is None:
|
|
|
im_info = list()
|
|
|
if isinstance(im, np.ndarray):
|
|
|
@@ -102,7 +85,6 @@ class Compose(SegTransform):
|
|
|
format(len(im.shape)))
|
|
|
else:
|
|
|
try:
|
|
|
- im_file = im
|
|
|
im = cv2.imread(im).astype('float32')
|
|
|
except:
|
|
|
raise ValueError('Can\'t read The image file {}!'.format(im))
|
|
|
@@ -111,10 +93,10 @@ class Compose(SegTransform):
|
|
|
if label is not None:
|
|
|
if not isinstance(label, np.ndarray):
|
|
|
label = np.asarray(Image.open(label))
|
|
|
- if self.images_writer is not None:
|
|
|
- self.images_writer.add_image(tag='0. origin image',
|
|
|
- img=im,
|
|
|
- step=step)
|
|
|
+ if images_writer is not None:
|
|
|
+ images_writer.add_image(tag='0. origin image',
|
|
|
+ img=im,
|
|
|
+ step=step)
|
|
|
op_id = 1
|
|
|
for op in self.transforms:
|
|
|
if isinstance(op, SegTransform):
|
|
|
@@ -130,9 +112,9 @@ class Compose(SegTransform):
|
|
|
outputs = (im, im_info, label)
|
|
|
else:
|
|
|
outputs = (im, im_info)
|
|
|
- if self.images_writer is not None:
|
|
|
+ if images_writer is not None:
|
|
|
tag = str(op_id) + '. ' + op.__class__.__name__
|
|
|
- self.images_writer.add_image(tag=tag,
|
|
|
+ images_writer.add_image(tag=tag,
|
|
|
img=im,
|
|
|
step=step)
|
|
|
op_id += 1
|