sunyanfang01 5 år sedan
förälder
incheckning
de059b2063

+ 1 - 0
paddlex/__init__.py

@@ -48,6 +48,7 @@ if hub.version.hub_version < '1.6.2':
 env_info = get_environ_info()
 load_model = cv.models.load_model
 datasets = cv.datasets
+transforms = cv.transforms
 
 log_level = 2
 

+ 2 - 0
paddlex/cv/transforms/__init__.py

@@ -15,3 +15,5 @@
 from . import cls_transforms
 from . import det_transforms
 from . import seg_transforms
+from . import visualize
+visualize = visualize.visualize

+ 8 - 26
paddlex/cv/transforms/cls_transforms.py

@@ -58,24 +58,8 @@ class Compose(ClsTransform):
                     raise Exception(
                         "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
-        self.images_writer = None
-                    
-    def set_vdl(self, vdl_save_dir=None):
-        # 对数据预处理结果在VisualDL中可视化 
-        self.images_writer = None
-        if vdl_save_dir is not None:
-            if not osp.isdir(vdl_save_dir):
-                if osp.exists(vdl_save_dir):
-                    os.remove(vdl_save_dir)
-                os.makedirs(vdl_save_dir)
-            from visualdl import LogWriter
-            vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
-            self.images_writer = LogWriter(vdl_images_dir)
-            
-    def release_vdl(self):
-        self.images_writer = None
-
-    def __call__(self, im, label=None, step=0):
+
+    def __call__(self, im, label=None, images_writer=None, step=0):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -84,7 +68,6 @@ class Compose(ClsTransform):
             tuple: 根据网络所需字段所组成的tuple;
                 字段由transforms中的最后一个数据预处理操作决定。
         """
-        im_file = str(step)
         if isinstance(im, np.ndarray):
             if len(im.shape) != 3:
                 raise Exception(
@@ -92,15 +75,14 @@ class Compose(ClsTransform):
                     format(len(im.shape)))
         else:
             try:
-                im_file = im
                 im = cv2.imread(im).astype('float32')
             except:
                 raise TypeError('Can\'t read The image file {}!'.format(im))
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
-        if self.images_writer is not None:
-            self.images_writer.add_image(tag='0. origin image',
-                                        img=im,
-                                        step=step)
+        if images_writer is not None:
+            images_writer.add_image(tag='0. origin image',
+                                    img=im,
+                                    step=step)
         op_id = 1
         for op in self.transforms:
             if isinstance(op, ClsTransform):
@@ -115,9 +97,9 @@ class Compose(ClsTransform):
                 outputs = (im, )
                 if label is not None:
                     outputs = (im, label)
-            if self.images_writer is not None:
+            if images_writer is not None:
                 tag = str(op_id) + '. ' + op.__class__.__name__
-                self.images_writer.add_image(tag=tag,
+                images_writer.add_image(tag=tag,
                                         img=im,
                                         step=step)
             op_id += 1

+ 8 - 28
paddlex/cv/transforms/det_transforms.py

@@ -51,7 +51,7 @@ class Compose(DetTransform):
         ValueError: 数据长度不匹配。
     """
 
-    def __init__(self, transforms, vdl_save_dir=None):
+    def __init__(self, transforms):
         if not isinstance(transforms, list):
             raise TypeError('The transforms must be a list!')
         if len(transforms) < 1:
@@ -70,24 +70,8 @@ class Compose(DetTransform):
                     raise Exception(
                         "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
-        self.images_writer = None
-                    
-    def set_vdl(self, vdl_save_dir=None):
-        # 对数据预处理结果在VisualDL中可视化 
-        self.images_writer = None
-        if vdl_save_dir is not None:
-            if not osp.isdir(vdl_save_dir):
-                if osp.exists(vdl_save_dir):
-                    os.remove(vdl_save_dir)
-                os.makedirs(vdl_save_dir)
-            from visualdl import LogWriter
-            vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
-            self.images_writer = LogWriter(vdl_images_dir)
-            
-    def release_vdl(self):
-        self.images_writer = None
 
-    def __call__(self, im, im_info=None, label_info=None, step=0):
+    def __call__(self, im, im_info=None, label_info=None, images_writer=None, step=0):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -151,19 +135,15 @@ class Compose(DetTransform):
             else:
                 return (im, im_info, label_info)
             
-        if isinstance(im, str):
-            im_file = im
-        else:
-            im_file = str(step)
         outputs = decode_image(im, im_info, label_info)
         im = outputs[0]
         im_info = outputs[1]
         if len(outputs) == 3:
             label_info = outputs[2]
-        if self.images_writer is not None:
-            self.images_writer.add_image(tag='0. origin image',
-                                        img=im,
-                                        step=step)
+        if images_writer is not None:
+            images_writer.add_image(tag='0. origin image',
+                                    img=im,
+                                    step=step)
         op_id = 1
         for op in self.transforms:
             if im is None:
@@ -177,9 +157,9 @@ class Compose(DetTransform):
                     outputs = (im, im_info, label_info)
                 else:
                     outputs = (im, im_info)
-            if self.images_writer is not None:
+            if images_writer is not None:
                 tag = str(op_id) + '. ' + op.__class__.__name__
-                self.images_writer.add_image(tag=tag,
+                images_writer.add_image(tag=tag,
                                         img=im,
                                         step=step)
             op_id += 1

+ 9 - 27
paddlex/cv/transforms/seg_transforms.py

@@ -46,7 +46,7 @@ class Compose(SegTransform):
 
     """
 
-    def __init__(self, transforms, vdl_save_dir=None):
+    def __init__(self, transforms):
         if not isinstance(transforms, list):
             raise TypeError('The transforms must be a list!')
         if len(transforms) < 1:
@@ -62,24 +62,8 @@ class Compose(SegTransform):
                     raise Exception(
                         "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
-        self.images_writer = None
-                    
-    def set_vdl(self, vdl_save_dir=None):
-        # 对数据预处理结果在VisualDL中可视化 
-        self.images_writer = None
-        if vdl_save_dir is not None:
-            if not osp.isdir(vdl_save_dir):
-                if osp.exists(vdl_save_dir):
-                    os.remove(vdl_save_dir)
-                os.makedirs(vdl_save_dir)
-            from visualdl import LogWriter
-            vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
-            self.images_writer = LogWriter(vdl_images_dir)
-            
-    def release_vdl(self):
-        self.images_writer = None
-
-    def __call__(self, im, im_info=None, label=None, step=0):
+       
+    def __call__(self, im, im_info=None, label=None, images_writer=None, step=0):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -92,7 +76,6 @@ class Compose(SegTransform):
         Returns:
             tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
         """
-        im_file = str(step)
         if im_info is None:
             im_info = list()
         if isinstance(im, np.ndarray):
@@ -102,7 +85,6 @@ class Compose(SegTransform):
                     format(len(im.shape)))
         else:
             try:
-                im_file = im
                 im = cv2.imread(im).astype('float32')
             except:
                 raise ValueError('Can\'t read The image file {}!'.format(im))
@@ -111,10 +93,10 @@ class Compose(SegTransform):
         if label is not None:
             if not isinstance(label, np.ndarray):
                 label = np.asarray(Image.open(label))
-        if self.images_writer is not None:
-            self.images_writer.add_image(tag='0. origin image',
-                                        img=im,
-                                        step=step)
+        if images_writer is not None:
+            images_writer.add_image(tag='0. origin image',
+                                    img=im,
+                                    step=step)
         op_id = 1
         for op in self.transforms:
             if isinstance(op, SegTransform):
@@ -130,9 +112,9 @@ class Compose(SegTransform):
                     outputs = (im, im_info, label)
                 else:
                     outputs = (im, im_info)
-            if self.images_writer is not None:
+            if images_writer is not None:
                 tag = str(op_id) + '. ' + op.__class__.__name__
-                self.images_writer.add_image(tag=tag,
+                images_writer.add_image(tag=tag,
                                         img=im,
                                         step=step)
             op_id += 1

+ 39 - 0
paddlex/cv/transforms/visualize.py

@@ -0,0 +1,39 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from .cls_transforms import ClsTransform
+from .det_transforms import DetTransform
+from .seg_transforms import SegTransform
+
+def visualize(dataset, index=0, steps=3, save_dir='vdl_output'):
+    transforms = dataset.transforms
+    if not osp.isdir(save_dir):
+        if osp.exists(save_dir):
+            os.remove(save_dir)
+        os.makedirs(save_dir)
+    for i, data in enumerate(dataset.iterator()):
+        if i == index:
+            break
+    from visualdl import LogWriter
+    vdl_save_dir = osp.join(save_dir, 'image_transforms')
+    images_writer = LogWriter(vdl_save_dir)
+    data.append(images_writer)
+    for s in range(steps):
+        if s != 0:
+            data.pop()
+        data.append(s)
+        transforms(*data)
+        

+ 0 - 13
tutorials/train/classification/mobilenetv2.py

@@ -34,19 +34,6 @@ eval_dataset = pdx.datasets.ImageNet(
     label_list='vegetables_cls/labels.txt',
     transforms=eval_transforms)
 
-# 可使用VisualDL查看数据预处理的中间结果
-# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
-# 浏览器打开 https://0.0.0.0:8001即可
-# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
-train_transforms.set_vdl(vdl_save_dir='vdl_output')
-for step, data in enumerate(train_dataset.iterator()):
-    data.append(step)
-    train_transforms(*data)
-    if step == 5:
-        break
-train_transforms.release_vdl()
-
-
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标
 # VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001

+ 0 - 12
tutorials/train/detection/yolov3_darknet53.py

@@ -38,18 +38,6 @@ eval_dataset = pdx.datasets.VOCDetection(
     label_list='insect_det/labels.txt',
     transforms=eval_transforms)
 
-# 可使用VisualDL查看数据预处理的中间结果
-# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
-# 浏览器打开 https://0.0.0.0:8001即可
-# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
-train_transforms.set_vdl(vdl_save_dir='vdl_output')
-for step, data in enumerate(train_dataset.iterator()):
-    data.append(step)
-    train_transforms(*data)
-    if step == 5:
-        break
-train_transforms.release_vdl()
-
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标
 # VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001

+ 0 - 12
tutorials/train/segmentation/deeplabv3p.py

@@ -33,18 +33,6 @@ eval_dataset = pdx.datasets.SegDataset(
     label_list='optic_disc_seg/labels.txt',
     transforms=eval_transforms)
 
-# 可使用VisualDL查看数据预处理的中间结果
-# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
-# 浏览器打开 https://0.0.0.0:8001即可
-# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
-train_transforms.vdl_save_dir = 'vdl_output'
-for step, data in enumerate(train_dataset.iterator()):
-    data.append(step)
-    train_transforms(*data)
-    if step == 5:
-        break
-train_transforms.vdl_save_dir = None
-
 # 初始化模型,并进行训练
 # 可使用VisualDL查看训练指标
 # VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001