فهرست منبع

add transforms vdl

sunyanfang01 5 سال پیش
والد
کامیت
4cbf8369c4

+ 33 - 2
paddlex/cv/transforms/cls_transforms.py

@@ -15,6 +15,7 @@
 from .ops import *
 from .imgaug_support import execute_imgaug
 import random
+import os
 import os.path as osp
 import numpy as np
 from PIL import Image, ImageEnhance
@@ -57,8 +58,24 @@ class Compose(ClsTransform):
                     raise Exception(
                         "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
-
-    def __call__(self, im, label=None):
+        self.images_writer = None
+                    
+    def set_vdl(self, vdl_save_dir=None):
+        # 对数据预处理结果在VisualDL中可视化 
+        self.images_writer = None
+        if vdl_save_dir is not None:
+            if not osp.isdir(vdl_save_dir):
+                if osp.exists(vdl_save_dir):
+                    os.remove(vdl_save_dir)
+                os.makedirs(vdl_save_dir)
+            from visualdl import LogWriter
+            vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
+            self.images_writer = LogWriter(vdl_images_dir)
+            
+    def release_vdl(self):
+        self.images_writer = None
+
+    def __call__(self, im, label=None, step=0):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -67,6 +84,7 @@ class Compose(ClsTransform):
             tuple: 根据网络所需字段所组成的tuple;
                 字段由transforms中的最后一个数据预处理操作决定。
         """
+        im_file = str(step)
         if isinstance(im, np.ndarray):
             if len(im.shape) != 3:
                 raise Exception(
@@ -74,10 +92,16 @@ class Compose(ClsTransform):
                     format(len(im.shape)))
         else:
             try:
+                im_file = im
                 im = cv2.imread(im).astype('float32')
             except:
                 raise TypeError('Can\'t read The image file {}!'.format(im))
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        if self.images_writer is not None:
+            self.images_writer.add_image(tag='0. origin image',
+                                        img=im,
+                                        step=step)
+        op_id = 1
         for op in self.transforms:
             if isinstance(op, ClsTransform):
                 outputs = op(im, label)
@@ -91,6 +115,12 @@ class Compose(ClsTransform):
                 outputs = (im, )
                 if label is not None:
                     outputs = (im, label)
+            if self.images_writer is not None:
+                tag = str(op_id) + '. ' + op.__class__.__name__
+                self.images_writer.add_image(tag=tag,
+                                        img=im,
+                                        step=step)
+            op_id += 1
         return outputs
 
     def add_augmenters(self, augmenters):
@@ -434,6 +464,7 @@ class RandomDistort(ClsTransform):
             params['im'] = im
             if np.random.uniform(0, 1) < prob:
                 im = ops[id](**params)
+        im = im.astype('float32')
         if label is None:
             return (im, )
         else:

+ 37 - 4
paddlex/cv/transforms/det_transforms.py

@@ -18,6 +18,7 @@ except Exception:
     from collections import Sequence
 
 import random
+import os
 import os.path as osp
 import numpy as np
 
@@ -50,7 +51,7 @@ class Compose(DetTransform):
         ValueError: 数据长度不匹配。
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms, vdl_save_dir=None):
         if not isinstance(transforms, list):
             raise TypeError('The transforms must be a list!')
         if len(transforms) < 1:
@@ -69,8 +70,24 @@ class Compose(DetTransform):
                     raise Exception(
                         "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
-
-    def __call__(self, im, im_info=None, label_info=None):
+        self.images_writer = None
+                    
+    def set_vdl(self, vdl_save_dir=None):
+        # 对数据预处理结果在VisualDL中可视化 
+        self.images_writer = None
+        if vdl_save_dir is not None:
+            if not osp.isdir(vdl_save_dir):
+                if osp.exists(vdl_save_dir):
+                    os.remove(vdl_save_dir)
+                os.makedirs(vdl_save_dir)
+            from visualdl import LogWriter
+            vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
+            self.images_writer = LogWriter(vdl_images_dir)
+            
+    def release_vdl(self):
+        self.images_writer = None
+
+    def __call__(self, im, im_info=None, label_info=None, step=0):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -133,12 +150,21 @@ class Compose(DetTransform):
                 return (im, im_info)
             else:
                 return (im, im_info, label_info)
-
+            
+        if isinstance(im, str):
+            im_file = im
+        else:
+            im_file = str(step)
         outputs = decode_image(im, im_info, label_info)
         im = outputs[0]
         im_info = outputs[1]
         if len(outputs) == 3:
             label_info = outputs[2]
+        if self.images_writer is not None:
+            self.images_writer.add_image(tag='0. origin image',
+                                        img=im,
+                                        step=step)
+        op_id = 1
         for op in self.transforms:
             if im is None:
                 return None
@@ -151,6 +177,12 @@ class Compose(DetTransform):
                     outputs = (im, im_info, label_info)
                 else:
                     outputs = (im, im_info)
+            if self.images_writer is not None:
+                tag = str(op_id) + '. ' + op.__class__.__name__
+                self.images_writer.add_image(tag=tag,
+                                        img=im,
+                                        step=step)
+            op_id += 1
         return outputs
 
     def add_augmenters(self, augmenters):
@@ -621,6 +653,7 @@ class RandomDistort(DetTransform):
 
             if np.random.uniform(0, 1) < prob:
                 im = ops[id](**params)
+        im = im.astype('float32')
         if label_info is None:
             return (im, im_info)
         else:

+ 34 - 4
paddlex/cv/transforms/seg_transforms.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from .ops import *
 from .imgaug_support import execute_imgaug
 import random
@@ -45,7 +46,7 @@ class Compose(SegTransform):
 
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms, vdl_save_dir=None):
         if not isinstance(transforms, list):
             raise TypeError('The transforms must be a list!')
         if len(transforms) < 1:
@@ -61,8 +62,24 @@ class Compose(SegTransform):
                     raise Exception(
                         "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
-
-    def __call__(self, im, im_info=None, label=None):
+        self.images_writer = None
+                    
+    def set_vdl(self, vdl_save_dir=None):
+        # 对数据预处理结果在VisualDL中可视化 
+        self.images_writer = None
+        if vdl_save_dir is not None:
+            if not osp.isdir(vdl_save_dir):
+                if osp.exists(vdl_save_dir):
+                    os.remove(vdl_save_dir)
+                os.makedirs(vdl_save_dir)
+            from visualdl import LogWriter
+            vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
+            self.images_writer = LogWriter(vdl_images_dir)
+            
+    def release_vdl(self):
+        self.images_writer = None
+
+    def __call__(self, im, im_info=None, label=None, step=0):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -75,7 +92,7 @@ class Compose(SegTransform):
         Returns:
             tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
         """
-
+        im_file = str(step)
         if im_info is None:
             im_info = list()
         if isinstance(im, np.ndarray):
@@ -85,6 +102,7 @@ class Compose(SegTransform):
                     format(len(im.shape)))
         else:
             try:
+                im_file = im
                 im = cv2.imread(im).astype('float32')
             except:
                 raise ValueError('Can\'t read The image file {}!'.format(im))
@@ -93,6 +111,11 @@ class Compose(SegTransform):
         if label is not None:
             if not isinstance(label, np.ndarray):
                 label = np.asarray(Image.open(label))
+        if self.images_writer is not None:
+            self.images_writer.add_image(tag='0. origin image',
+                                        img=im,
+                                        step=step)
+        op_id = 1
         for op in self.transforms:
             if isinstance(op, SegTransform):
                 outputs = op(im, im_info, label)
@@ -107,6 +130,12 @@ class Compose(SegTransform):
                     outputs = (im, im_info, label)
                 else:
                     outputs = (im, im_info)
+            if self.images_writer is not None:
+                tag = str(op_id) + '. ' + op.__class__.__name__
+                self.images_writer.add_image(tag=tag,
+                                        img=im,
+                                        step=step)
+            op_id += 1
         return outputs
 
     def add_augmenters(self, augmenters):
@@ -1053,6 +1082,7 @@ class RandomDistort(SegTransform):
             params['im'] = im
             if np.random.uniform(0, 1) < prob:
                 im = ops[id](**params)
+        im = im.astype('float32')
         if label is None:
             return (im, im_info)
         else:

+ 13 - 0
tutorials/train/classification/mobilenetv2.py

@@ -34,6 +34,19 @@ eval_dataset = pdx.datasets.ImageNet(
     label_list='vegetables_cls/labels.txt',
     transforms=eval_transforms)
 
+# 可使用VisualDL查看数据预处理的中间结果
+# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
+# 浏览器打开 https://0.0.0.0:8001即可
+# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
+train_transforms.set_vdl(vdl_save_dir='vdl_output')
+for step, data in enumerate(train_dataset.iterator()):
+    data.append(step)
+    train_transforms(*data)
+    if step == 5:
+        break
+train_transforms.release_vdl()
+
+
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标
 # VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001

+ 24 - 0
tutorials/train/detection/yolov3_darknet53.py

@@ -38,6 +38,30 @@ eval_dataset = pdx.datasets.VOCDetection(
     label_list='insect_det/labels.txt',
     transforms=eval_transforms)
 
+# 可使用VisualDL查看数据预处理的中间结果
+# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
+# 浏览器打开 https://0.0.0.0:8001即可
+# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
+train_transforms.set_vdl(vdl_save_dir='vdl_output')
+for step, data in enumerate(train_dataset.iterator()):
+    data.append(step)
+    train_transforms(*data)
+    if step == 5:
+        break
+train_transforms.release_vdl()
+
+# 可使用VisualDL查看数据预处理的中间结果
+# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
+# 浏览器打开 https://0.0.0.0:8001即可
+# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
+train_transforms.vdl_save_dir = 'vdl_output'
+for step, data in enumerate(train_dataset.iterator()):
+    data.append(step)
+    train_transforms(*data)
+    if step == 5:
+        break
+train_transforms.vdl_save_dir = None
+
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标
 # VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001

+ 12 - 0
tutorials/train/segmentation/deeplabv3p.py

@@ -33,6 +33,18 @@ eval_dataset = pdx.datasets.SegDataset(
     label_list='optic_disc_seg/labels.txt',
     transforms=eval_transforms)
 
+# 可使用VisualDL查看数据预处理的中间结果
+# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
+# 浏览器打开 https://0.0.0.0:8001即可
+# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
+train_transforms.vdl_save_dir = 'vdl_output'
+for step, data in enumerate(train_dataset.iterator()):
+    data.append(step)
+    train_transforms(*data)
+    if step == 5:
+        break
+train_transforms.vdl_save_dir = None
+
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标
 # VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001