mamingjie-China 5 years ago
parent
commit
7420b0f4ac

+ 0 - 2
paddlex/command.py

@@ -200,10 +200,8 @@ def main():
             logging.error("The value of split is not correct.")
         if not osp.exists(save_dir):
             logging.error("The path of saved split information doesn't exist.")
-        print(11111111111111)
         pdx.tools.split.dataset_split(dataset_dir, dataset_form, val_value,
                                       test_value, save_dir)
-        print(222222222)
 
 
 if __name__ == "__main__":

+ 0 - 0
paddlex/tools/dataset_split/__init__.py


+ 69 - 0
paddlex/tools/dataset_split/coco_split.py

@@ -0,0 +1,69 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+import json
+from .utils import MyEncoder
+
+
+def split_coco_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "annotations.json")):
+        raise ValueError("\'annotations.json\' is not found in {}!".format(
+            dataset_dir))
+    try:
+        from pycocotools.coco import COCO
+    except ImportError:
+        print(
+            "pycocotools is not installed, follow this doc to install pycocotools: https://paddlex.readthedocs.io/zh_CN/develop/install.html#pycocotools"
+        )
+        return
+
+    annotation_file = osp.join(dataset_dir, "annotations.json")
+    coco = COCO(annotation_file)
+    img_ids = coco.getImgIds()
+    cat_ids = coco.getCatIds()
+    anno_ids = coco.getAnnIds()
+
+    val_num = int(len(img_ids) * val_percent)
+    test_num = int(len(img_ids) * test_percent)
+    train_num = len(img_ids) - val_num - test_num
+
+    random.shuffle(img_ids)
+    train_files_ids = img_ids[:train_num]
+    val_files_ids = img_ids[train_num:train_num + val_num]
+    test_files_ids = img_ids[train_num + val_num:]
+
+    # Pair each id list with its output file instead of comparing the lists
+    # by value, and close each file via a context manager.
+    for img_id_list, json_name in [(train_files_ids, 'train.json'),
+                                   (val_files_ids, 'val.json'),
+                                   (test_files_ids, 'test.json')]:
+        if json_name == 'test.json' and not len(test_files_ids):
+            continue
+        img_anno_ids = coco.getAnnIds(imgIds=img_id_list, iscrowd=0)
+        imgs = coco.loadImgs(img_id_list)
+        instances = coco.loadAnns(img_anno_ids)
+        categories = coco.loadCats(cat_ids)
+        img_dict = {
+            "annotations": instances,
+            "images": imgs,
+            "categories": categories
+        }
+        with open(osp.join(save_dir, json_name), mode='w') as json_file:
+            json.dump(img_dict, json_file, cls=MyEncoder)
+
+    return train_num, val_num, test_num
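
A minimal usage sketch for the new COCO splitter ("MyCOCODataset" is a placeholder path; the function expects an annotations.json inside dataset_dir, a writable save_dir, and returns None when pycocotools is missing):

    import random

    from paddlex.tools.dataset_split.coco_split import split_coco_dataset

    random.seed(0)  # the splitter shuffles via the global random module
    # hold out 20% of the images for validation and 10% for test
    counts = split_coco_dataset("MyCOCODataset", 0.2, 0.1, "MyCOCODataset")
    if counts is not None:
        train_num, val_num, test_num = counts

Because the image ids are shuffled with Python's global random module, seeding beforehand is the only way to make the split reproducible.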

+ 74 - 0
paddlex/tools/dataset_split/imagenet_split.py

@@ -0,0 +1,74 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+from .utils import list_files, is_pic
+
+
+def split_imagenet_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    all_files = list_files(dataset_dir)
+    label_list = list()
+    train_image_anno_list = list()
+    val_image_anno_list = list()
+    test_image_anno_list = list()
+    for file in all_files:
+        if not is_pic(file):
+            continue
+        label, image_name = osp.split(file)
+        if label not in label_list:
+            label_list.append(label)
+    label_list = sorted(label_list)
+
+    for i in range(len(label_list)):
+        image_list = list_files(osp.join(dataset_dir, label_list[i]))
+        image_anno_list = list()
+        for img in image_list:
+            image_anno_list.append([osp.join(label_list[i], img), i])
+        random.shuffle(image_anno_list)
+        image_num = len(image_anno_list)
+        val_num = int(image_num * val_percent)
+        test_num = int(image_num * test_percent)
+        train_num = image_num - val_num - test_num
+
+        train_image_anno_list += image_anno_list[:train_num]
+        val_image_anno_list += image_anno_list[train_num:train_num + val_num]
+        test_image_anno_list += image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file, label = x
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file, label = x
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file, label = x
+                f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
+        for l in sorted(label_list):
+            f.write('{}\n'.format(l))
+
+    return len(train_image_anno_list), len(val_image_anno_list), len(
+        test_image_anno_list)
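
A sketch of driving the ImageNet-style splitter (directory names are placeholders; the code expects one subdirectory per class, each holding that class's images):

    from paddlex.tools.dataset_split.imagenet_split import split_imagenet_dataset

    # hypothetical layout: MyClsDataset/cat/1.jpg, MyClsDataset/dog/2.jpg, ...
    train_num, val_num, test_num = split_imagenet_dataset(
        "MyClsDataset", 0.2, 0.1, "MyClsDataset")

Note that the val/test fractions are applied per class, so each class keeps roughly the requested proportions in every list file.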

+ 93 - 0
paddlex/tools/dataset_split/seg_split.py

@@ -0,0 +1,93 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+from .utils import list_files, is_pic, replace_ext, read_seg_ann
+
+
+def split_seg_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
+        raise ValueError("\'JPEGImages\' is not found in {}!".format(
+            dataset_dir))
+    if not osp.exists(osp.join(dataset_dir, "Annotations")):
+        raise ValueError("\'Annotations\' is not found in {}!".format(
+            dataset_dir))
+
+    all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
+
+    image_anno_list = list()
+    label_list = list()
+    for image_file in all_image_files:
+        if not is_pic(image_file):
+            continue
+        anno_name = replace_ext(image_file, "png")
+        if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+            image_anno_list.append([image_file, anno_name])
+        else:
+            anno_name = replace_ext(image_file, "PNG")
+            if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+                image_anno_list.append([image_file, anno_name])
+
+    if not osp.exists(osp.join(dataset_dir, "labels.txt")):
+        for image_anno in image_anno_list:
+            labels = read_seg_ann(
+                osp.join(dataset_dir, "Annotations", image_anno[1]))
+            for i in labels:
+                if i not in label_list:
+                    label_list.append(i)
+        # If the maximum label value is larger than the number of classes,
+        # fill in the missing labels
+        if len(label_list) != max(label_list) + 1:
+            label_list = [i for i in range(max(label_list) + 1)]
+
+    random.shuffle(image_anno_list)
+    image_num = len(image_anno_list)
+    val_num = int(image_num * val_percent)
+    test_num = int(image_num * test_percent)
+    train_num = image_num - val_num - test_num
+
+    train_image_anno_list = image_anno_list[:train_num]
+    val_image_anno_list = image_anno_list[train_num:train_num + val_num]
+    test_image_anno_list = image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file = osp.join("JPEGImages", x[0])
+                label = osp.join("Annotations", x[1])
+                f.write('{} {}\n'.format(file, label))
+    if len(label_list):
+        with open(
+                osp.join(save_dir, 'labels.txt'), mode='w',
+                encoding='utf-8') as f:
+            for l in sorted(label_list):
+                f.write('{}\n'.format(l))
+
+    return train_num, val_num, test_num
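
The segmentation splitter follows the same pattern (paths are placeholders; dataset_dir must contain a JPEGImages directory with the images and an Annotations directory with same-named .png masks, and pixel value 255 is treated as the ignore label when labels.txt has to be inferred):

    from paddlex.tools.dataset_split.seg_split import split_seg_dataset

    train_num, val_num, test_num = split_seg_dataset(
        "MySegDataset", 0.2, 0.1, "MySegDataset")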

+ 102 - 0
paddlex/tools/dataset_split/utils.py

@@ -0,0 +1,102 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from PIL import Image
+import numpy as np
+import json
+
+
+class MyEncoder(json.JSONEncoder):
+    # Convert numpy types so the dict can be serialized to a JSON file
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(MyEncoder, self).default(obj)
+
+
+def list_files(dirname):
+    """ 列出目录下所有文件(包括所属的一级子目录下文件)
+
+    Args:
+        dirname: 目录路径
+    """
+
+    def filter_file(f):
+        if f.startswith('.'):
+            return True
+        return False
+
+    all_files = list()
+    dirs = list()
+    for f in os.listdir(dirname):
+        if filter_file(f):
+            continue
+        if osp.isdir(osp.join(dirname, f)):
+            dirs.append(f)
+        else:
+            all_files.append(f)
+    for d in dirs:
+        for f in os.listdir(osp.join(dirname, d)):
+            if filter_file(f):
+                continue
+            if osp.isdir(osp.join(dirname, d, f)):
+                continue
+            all_files.append(osp.join(d, f))
+    return all_files
+
+
+def is_pic(filename):
+    """ 判断文件是否为图片格式
+
+    Args:
+        filename: 文件路径
+    """
+    suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
+    suffix = filename.strip().split('.')[-1]
+    if suffix not in suffixes:
+        return False
+    return True
+
+
+def replace_ext(filename, new_ext):
+    """ 替换文件后缀
+
+    Args:
+        filename: 文件路径
+        new_ext: 需要替换的新的后缀
+    """
+    items = filename.split(".")
+    items[-1] = new_ext
+    new_filename = ".".join(items)
+    return new_filename
+
+
+def read_seg_ann(pngfile):
+    """ 解析语义分割的标注png图片
+
+    Args:
+        pngfile: 包含标注信息的png图片路径
+    """
+    grt = np.asarray(Image.open(pngfile))
+    labels = list(np.unique(grt))
+    if 255 in labels:
+        labels.remove(255)
+    return labels
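
The helpers above can be exercised on their own; a small sketch with illustrative values:

    import json
    import numpy as np
    from paddlex.tools.dataset_split.utils import (MyEncoder, is_pic,
                                                   replace_ext)

    # MyEncoder converts numpy scalars/arrays that json cannot serialize
    json.dumps({"id": np.int64(3), "area": np.float32(12.5)}, cls=MyEncoder)
    # -> '{"id": 3, "area": 12.5}'

    is_pic("001.jpg")              # -> True
    replace_ext("001.jpg", "png")  # -> '001.png'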

+ 88 - 0
paddlex/tools/dataset_split/voc_split.py

@@ -0,0 +1,88 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import random
+import xml.etree.ElementTree as ET
+from .utils import list_files, is_pic, replace_ext
+
+
+def split_voc_dataset(dataset_dir, val_percent, test_percent, save_dir):
+    if not osp.exists(osp.join(dataset_dir, "JPEGImages")):
+        raise ValueError("\'JPEGImages\' is not found in {}!".format(
+            dataset_dir))
+    if not osp.exists(osp.join(dataset_dir, "Annotations")):
+        raise ValueError("\'Annotations\' is not found in {}!".format(
+            dataset_dir))
+
+    all_image_files = list_files(osp.join(dataset_dir, "JPEGImages"))
+
+    image_anno_list = list()
+    label_list = list()
+    for image_file in all_image_files:
+        if not is_pic(image_file):
+            continue
+        anno_name = replace_ext(image_file, "xml")
+        if osp.exists(osp.join(dataset_dir, "Annotations", anno_name)):
+            image_anno_list.append([image_file, anno_name])
+            try:
+                tree = ET.parse(
+                    osp.join(dataset_dir, "Annotations", anno_name))
+            except ET.ParseError:
+                raise Exception(
+                    "{} is not a well-formed xml file, please check the "
+                    "annotation file.".format(
+                        osp.join(dataset_dir, "Annotations", anno_name)))
+            objs = tree.findall("object")
+            for obj in objs:
+                cname = obj.find('name').text
+                if cname not in label_list:
+                    label_list.append(cname)
+
+    random.shuffle(image_anno_list)
+    image_num = len(image_anno_list)
+    val_num = int(image_num * val_percent)
+    test_num = int(image_num * test_percent)
+    train_num = image_num - val_num - test_num
+
+    train_image_anno_list = image_anno_list[:train_num]
+    val_image_anno_list = image_anno_list[train_num:train_num + val_num]
+    test_image_anno_list = image_anno_list[train_num + val_num:]
+
+    with open(
+            osp.join(save_dir, 'train_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in train_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'val_list.txt'), mode='w',
+            encoding='utf-8') as f:
+        for x in val_image_anno_list:
+            file = osp.join("JPEGImages", x[0])
+            label = osp.join("Annotations", x[1])
+            f.write('{} {}\n'.format(file, label))
+    if len(test_image_anno_list):
+        with open(
+                osp.join(save_dir, 'test_list.txt'), mode='w',
+                encoding='utf-8') as f:
+            for x in test_image_anno_list:
+                file = osp.join("JPEGImages", x[0])
+                label = osp.join("Annotations", x[1])
+                f.write('{} {}\n'.format(file, label))
+    with open(
+            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
+        for l in sorted(label_list):
+            f.write('{}\n'.format(l))
+
+    return train_num, val_num, test_num
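
And a sketch for the VOC splitter (again with a placeholder dataset root; it expects JPEGImages/ plus Annotations/ with same-named .xml files, and derives labels.txt from the <object><name> fields):

    from paddlex.tools.dataset_split.voc_split import split_voc_dataset

    train_num, val_num, test_num = split_voc_dataset(
        "MyVOCDataset", 0.2, 0.1, "MyVOCDataset")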

+ 22 - 1
paddlex/tools/split.py

@@ -14,7 +14,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .dataset_split.coco_split import split_coco_dataset
+from .dataset_split.voc_split import split_voc_dataset
+from .dataset_split.imagenet_split import split_imagenet_dataset
+from .dataset_split.seg_split import split_seg_dataset
+
 
 def dataset_split(dataset_dir, dataset_form, val_value, test_value, save_dir):
     print(dataset_dir, dataset_form, val_value, test_value, save_dir)
-    print(12345)
+    if dataset_form == "coco":
+        train_num, val_num, test_num = split_coco_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "voc":
+        train_num, val_num, test_num = split_voc_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "seg":
+        train_num, val_num, test_num = split_seg_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    elif dataset_form == "imagenet":
+        train_num, val_num, test_num = split_imagenet_dataset(
+            dataset_dir, val_value, test_value, save_dir)
+    print("Dataset Split Done.")
+    print("Train samples: {}".format(train_num))
+    print("Eval samples: {}".format(val_num))
+    print("Test samples: {}".format(test_num))
+    print("Split file saved in {}".format(save_dir))