|
|
@@ -13,6 +13,7 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
+import os
|
|
|
from .ops import *
|
|
|
from .imgaug_support import execute_imgaug
|
|
|
import random
|
|
|
@@ -45,7 +46,7 @@ class Compose(SegTransform):
|
|
|
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, transforms):
|
|
|
+ def __init__(self, transforms, vdl_save_dir=None):
|
|
|
if not isinstance(transforms, list):
|
|
|
raise TypeError('The transforms must be a list!')
|
|
|
if len(transforms) < 1:
|
|
|
@@ -61,8 +62,24 @@ class Compose(SegTransform):
|
|
|
raise Exception(
|
|
|
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
|
|
|
)
|
|
|
-
|
|
|
- def __call__(self, im, im_info=None, label=None):
|
|
|
+ self.images_writer = None
|
|
|
+
|
|
|
+ def set_vdl(self, vdl_save_dir=None):
|
|
|
+ # 对数据预处理结果在VisualDL中可视化
|
|
|
+ self.images_writer = None
|
|
|
+ if vdl_save_dir is not None:
|
|
|
+ if not osp.isdir(vdl_save_dir):
|
|
|
+ if osp.exists(vdl_save_dir):
|
|
|
+ os.remove(vdl_save_dir)
|
|
|
+ os.makedirs(vdl_save_dir)
|
|
|
+ from visualdl import LogWriter
|
|
|
+ vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
|
|
|
+ self.images_writer = LogWriter(vdl_images_dir)
|
|
|
+
|
|
|
+ def release_vdl(self):
|
|
|
+ self.images_writer = None
|
|
|
+
|
|
|
+ def __call__(self, im, im_info=None, label=None, step=0):
|
|
|
"""
|
|
|
Args:
|
|
|
im (str/np.ndarray): 图像路径/图像np.ndarray数据。
|
|
|
@@ -75,7 +92,7 @@ class Compose(SegTransform):
|
|
|
Returns:
|
|
|
tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
|
|
|
"""
|
|
|
-
|
|
|
+ im_file = str(step)
|
|
|
if im_info is None:
|
|
|
im_info = list()
|
|
|
if isinstance(im, np.ndarray):
|
|
|
@@ -85,6 +102,7 @@ class Compose(SegTransform):
|
|
|
format(len(im.shape)))
|
|
|
else:
|
|
|
try:
|
|
|
+ im_file = im
|
|
|
im = cv2.imread(im).astype('float32')
|
|
|
except:
|
|
|
raise ValueError('Can\'t read The image file {}!'.format(im))
|
|
|
@@ -93,6 +111,11 @@ class Compose(SegTransform):
|
|
|
if label is not None:
|
|
|
if not isinstance(label, np.ndarray):
|
|
|
label = np.asarray(Image.open(label))
|
|
|
+ if self.images_writer is not None:
|
|
|
+ self.images_writer.add_image(tag='0. origin image',
|
|
|
+ img=im,
|
|
|
+ step=step)
|
|
|
+ op_id = 1
|
|
|
for op in self.transforms:
|
|
|
if isinstance(op, SegTransform):
|
|
|
outputs = op(im, im_info, label)
|
|
|
@@ -107,6 +130,12 @@ class Compose(SegTransform):
|
|
|
outputs = (im, im_info, label)
|
|
|
else:
|
|
|
outputs = (im, im_info)
|
|
|
+ if self.images_writer is not None:
|
|
|
+ tag = str(op_id) + '. ' + op.__class__.__name__
|
|
|
+ self.images_writer.add_image(tag=tag,
|
|
|
+ img=im,
|
|
|
+ step=step)
|
|
|
+ op_id += 1
|
|
|
return outputs
|
|
|
|
|
|
def add_augmenters(self, augmenters):
|
|
|
@@ -1053,6 +1082,7 @@ class RandomDistort(SegTransform):
|
|
|
params['im'] = im
|
|
|
if np.random.uniform(0, 1) < prob:
|
|
|
im = ops[id](**params)
|
|
|
+ im = im.astype('float32')
|
|
|
if label is None:
|
|
|
return (im, im_info)
|
|
|
else:
|