| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import json
- import os
- import os.path as osp
- import shutil
- import numpy as np
- from PIL import Image, ImageDraw
- from .....utils import logging
- from .....utils.deps import function_requires_deps, is_dep_available
- from .....utils.file_interface import custom_open
- from .....utils.logging import info
- if is_dep_available("opencv-contrib-python"):
- import cv2
def convert_dataset(dataset_type, input_dir):
    """Convert the dataset at ``input_dir`` to the PaddleX official format.

    Args:
        dataset_type: source annotation format, either ``"LabelMe"`` or
            ``"MVTec_AD"``.
        input_dir: root directory of the source dataset.

    Returns:
        Whatever the selected converter returns.

    Raises:
        NotImplementedError: for any other ``dataset_type``; the offending value
            is the exception argument.
    """
    if dataset_type not in ("LabelMe", "MVTec_AD"):
        raise NotImplementedError(dataset_type)
    converter = (
        convert_labelme_dataset if dataset_type == "LabelMe" else convert_mvtec_dataset
    )
    return converter(input_dir)
@function_requires_deps("opencv-contrib-python")
def convert_labelme_dataset(input_dir):
    """Convert a LabelMe semantic-segmentation dataset to the PaddleX official format.

    ``input_dir`` must contain ``train_anno_list.txt`` and ``val_anno_list.txt``.
    Each lists LabelMe JSON files, one path (relative to ``input_dir``) per line.
    Everything is written inside ``input_dir``:

    * ``images/``: copies of the images referenced by each JSON ``imagePath``;
    * ``annotations/``: one palette PNG label map per JSON file;
    * ``train.txt`` / ``val.txt``: ``images/<img> annotations/<png>`` lines;
    * ``class_name.txt``: foreground class names, one per line;
    * ``class_name_to_id.txt``: ``<id>: <name>`` lines.

    Class ids are shared by both splits: ``_background_`` is 0, ``__ignore__``
    (if used) is 255, and every other label gets 1, 2, ... in order of first
    appearance. JSON files whose image is missing or unreadable are logged and
    skipped.

    Args:
        input_dir: root directory of the LabelMe dataset.

    Returns:
        str: ``input_dir``, the root of the converted dataset.

    Raises:
        ValueError: if there are more foreground classes than fit in the uint8
            label maps.
    """
    bg_name = "_background_"
    ignore_name = "__ignore__"
    split_tags = ["train", "val"]

    # prepare dir
    output_img_dir = osp.join(input_dir, "images")
    output_annot_dir = osp.join(input_dir, "annotations")
    os.makedirs(output_img_dir, exist_ok=True)
    os.makedirs(output_annot_dir, exist_ok=True)

    # Pass 1: load every split and collect class names. Ids must be identical
    # for train and val, so all labels have to be known before any map is drawn.
    split_annotations = {}
    class_names = []
    for tag in split_tags:
        annotations = []
        for label_file in _read_labelme_anno_list(input_dir, tag):
            with custom_open(label_file, "r") as fp:
                data = json.load(fp)
            annotations.append((label_file, data))
            for shape in data["shapes"]:
                if shape["label"] not in class_names:
                    class_names.append(shape["label"])
        split_annotations[tag] = annotations

    class_name_to_id = {}
    if ignore_name in class_names:
        class_name_to_id[ignore_name] = 255
        class_names.remove(ignore_name)
    if bg_name in class_names:
        class_names.remove(bg_name)
    class_name_to_id[bg_name] = 0
    for i, name in enumerate(class_names):
        class_name_to_id[name] = i + 1

    # Label maps are stored as uint8, so ids above 255 would silently wrap
    # around, and 255 itself is already taken when "__ignore__" is present.
    max_classes = 254 if ignore_name in class_name_to_id else 255
    if len(class_names) > max_classes:
        raise ValueError(
            f"There are {len(class_names)} categories in the annotation file, "
            f"exceeding {max_classes}, Not compliant with paddlex official format!"
        )

    # Pass 2: draw label maps, copy images and write one list file per split.
    color_map = get_color_map_list(256)
    for tag in split_tags:
        _convert_labelme_split(
            input_dir, tag, split_annotations[tag], class_name_to_id, color_map
        )

    # Use "\n" like the split lists; with a text-mode file, os.linesep would
    # turn into "\r\r\n" on Windows.
    with custom_open(osp.join(input_dir, "class_name.txt"), "w") as fp:
        for name in class_names:
            fp.write(f"{name}\n")
    with custom_open(osp.join(input_dir, "class_name_to_id.txt"), "w") as fp:
        for key, val in class_name_to_id.items():
            fp.write(f"{val}: {key}\n")
    return input_dir


def _read_labelme_anno_list(input_dir, tag):
    """Return the paths of the JSON files listed in ``{tag}_anno_list.txt``."""
    mapping_file = osp.join(input_dir, f"{tag}_anno_list.txt")
    with custom_open(mapping_file, "r") as f:
        lines = [line.rstrip("\r\n") for line in f.readlines()]
    # Blank lines (e.g. a trailing empty line) would otherwise resolve to
    # ``input_dir`` itself and fail when loaded as JSON.
    return [osp.join(input_dir, line) for line in lines if line.strip()]


def _convert_labelme_split(input_dir, tag, annotations, class_name_to_id, color_map):
    """Write label maps, copy images and write ``{tag}.txt`` for one split.

    ``annotations`` is a list of ``(label_file, parsed_json)`` pairs.
    """
    output_img_dir = osp.join(input_dir, "images")
    output_annot_dir = osp.join(input_dir, "annotations")
    list_lines = []
    for label_file, data in annotations:
        img_path = osp.join(osp.dirname(label_file), data["imagePath"])
        if not osp.exists(img_path):
            logging.info("%s is not existed, skip this image" % img_path)
            continue
        img = cv2.imread(img_path)
        # cv2.imread returns None instead of raising for unreadable files.
        if img is None:
            logging.info("%s can not be read, skip this image" % img_path)
            continue

        img_name = osp.basename(img_path)
        label_img_name = osp.splitext(osp.basename(label_file))[0] + ".png"
        lbl = shape2label(
            img_size=img.shape,
            shapes=data["shapes"],
            class_name_mapping=class_name_to_id,
        )
        lbl_pil = Image.fromarray(lbl.astype(np.uint8), mode="P")
        lbl_pil.putpalette(color_map)
        lbl_pil.save(osp.join(output_annot_dir, label_img_name))

        # The source images may already live in ``images/``; shutil.copy raises
        # SameFileError when source and destination are the same file.
        dst_img_path = osp.join(output_img_dir, img_name)
        if not (osp.exists(dst_img_path) and osp.samefile(img_path, dst_img_path)):
            shutil.copy(img_path, dst_img_path)

        list_lines.append(f"images/{img_name} annotations/{label_img_name}\n")

    with custom_open(osp.join(input_dir, f"{tag}.txt"), "w") as fp:
        fp.writelines(list_lines)
def get_color_map_list(num_classes):
    """Build a flat ``[r, g, b, r, g, b, ...]`` palette for ``num_classes`` labels.

    Each colour is derived from an index by bit interleaving: bits 0, 3, 6, ...
    of the index feed the red channel, bits 1, 4, 7, ... the green channel and
    bits 2, 5, 8, ... the blue channel, each placed from the channel's most
    significant bit downwards. Indices 1..num_classes are used (index 0, which
    would be black, is skipped), so the result holds ``3 * num_classes`` values.
    """
    palette = []
    for index in range(1, num_classes + 1):
        red = green = blue = 0
        shift = 7
        bits = index
        while bits:
            red |= (bits & 1) << shift
            green |= ((bits >> 1) & 1) << shift
            blue |= ((bits >> 2) & 1) << shift
            shift -= 1
            bits >>= 3
        palette.extend((red, green, blue))
    return palette
def shape2label(img_size, shapes, class_name_mapping):
    """Rasterize LabelMe shapes into an integer label map.

    Args:
        img_size: image shape; only its first two entries (rows, cols) are used.
        shapes: iterable of dicts with ``"points"`` (polygon vertices) and
            ``"label"`` (class name).
        class_name_mapping: dict from class name to integer class id.

    Returns:
        np.ndarray: int32 array of shape ``img_size[:2]``. Pixels covered by no
        shape stay 0; where shapes overlap, the later shape wins.
    """
    rows_cols = img_size[:2]
    label_map = np.zeros(rows_cols, dtype=np.int32)
    for shape in shapes:
        vertices, name = shape["points"], shape["label"]
        region = polygon2mask(rows_cols, vertices)
        label_map[region] = class_name_mapping[name]
    return label_map
def polygon2mask(img_size, points):
    """Rasterize one polygon into a boolean mask.

    Args:
        img_size: shape whose first two entries are the mask's (rows, cols).
        points: sequence of ``(x, y)`` vertices, e.g. a LabelMe shape's
            ``"points"``.

    Returns:
        np.ndarray: bool array of shape ``img_size[:2]``; True inside the polygon
        and on its outline.

    Raises:
        ValueError: if fewer than three vertices are given.
    """
    points_list = [tuple(point) for point in points]
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, and the old form raised AssertionError, not ValueError.
    if len(points_list) <= 2:
        raise ValueError("Polygon must have points more than 2")
    label_mask = Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8))
    image_draw = ImageDraw.Draw(label_mask)
    image_draw.polygon(xy=points_list, outline=1, fill=1)
    return np.array(label_mask, dtype=bool)
def save_item_to_txt(items, file_path):
    """Append the string ``items`` to the text file ``file_path``, creating it if needed.

    I/O failures (``OSError``) are printed instead of raised. Any other error,
    e.g. passing a non-``str`` ``items``, propagates to the caller.
    """
    try:
        # The ``with`` block closes the file; no explicit close() is needed.
        with open(file_path, "a") as file:
            file.write(items)
    except OSError as e:
        print(f"Saving_error: {e}")
def save_training_txt(cls_root, mode, cat):
    """Append the image/label pairs of one MVTec AD category directory to a split list.

    * ``mode == "train"``: each image in ``<cls_root>/train/<cat>`` goes to
      ``train.txt``, paired with itself.
    * ``mode == "test"`` and ``cat != "good"``: each image in
      ``<cls_root>/test/<cat>`` goes to ``val.txt``, paired with
      ``<cls_root>/ground_truth/<cat>/<stem>_mask.png``.
    * anything else (e.g. ``test/good``) writes nothing.

    Images are listed in sorted filename order, and all lines for the category
    are appended in a single write.
    """
    img_dir = os.path.join(cls_root, mode, cat)
    imgs = sorted(os.listdir(img_dir))
    if mode == "train":
        list_name = "train.txt"
        img_paths = [os.path.join(img_dir, img) for img in imgs]
        lines = [f"{path} {path}\n" for path in img_paths]
    elif mode == "test" and cat != "good":
        list_name = "val.txt"
        mask_dir = os.path.join(cls_root, "ground_truth", cat)
        lines = [
            # splitext keeps dotted stems intact ("a.b.png" -> "a.b_mask.png").
            f"{os.path.join(img_dir, img)} "
            f"{os.path.join(mask_dir, os.path.splitext(img)[0] + '_mask.png')}\n"
            for img in imgs
        ]
    else:
        return
    if lines:
        save_item_to_txt("".join(lines), os.path.join(cls_root, list_name))
def check_old_txt(cls_pth, mode):
    """Delete a previously generated split list so it can be rebuilt from scratch.

    The split lists are written in append mode, so a stale file would otherwise
    accumulate duplicate lines. ``mode == "train"`` targets ``train.txt``; any
    other mode targets ``val.txt``. Nothing happens if the file does not exist.
    """
    if mode == "train":
        set_name = "train.txt"
    else:
        set_name = "val.txt"
    stale_list = os.path.join(cls_pth, set_name)
    if not os.path.exists(stale_list):
        return
    os.remove(stale_list)
def convert_mvtec_dataset(input_dir):
    """Generate ``train.txt`` / ``val.txt`` for a single MVTec AD category directory.

    The basename of ``input_dir`` must be one of the 15 category names below,
    and the directory must hold ``train/`` and ``test/`` sub-directories.
    ``test/<defect>`` images are paired with masks under ``ground_truth/``.
    Existing list files are deleted and rebuilt.

    Args:
        input_dir: path of the category directory (a trailing separator is OK).

    Returns:
        str: ``input_dir``, the same value ``convert_labelme_dataset`` returns.

    Raises:
        ValueError: if the basename of ``input_dir`` is not a known category.
    """
    classes = [
        "bottle",
        "cable",
        "capsule",
        "hazelnut",
        "metal_nut",
        "pill",
        "screw",
        "toothbrush",
        "transistor",
        "zipper",
        "carpet",
        "grid",
        "leather",
        "tile",
        "wood",
    ]
    # normpath drops a trailing separator, so ".../bottle/" is read as "bottle"
    # instead of "".
    clas = os.path.basename(os.path.normpath(input_dir))
    # Explicit raise: the previous ``assert`` vanished under ``python -O`` and
    # carried no message (info() returns None).
    if clas not in classes:
        raise ValueError(
            f"Make sure your class: '{clas}' in your dataset root in\n {classes}"
        )
    cls_root = input_dir
    for mode in ("train", "test"):
        check_old_txt(cls_root, mode)
        mode_dir = os.path.join(cls_root, mode)
        # Skip stray files (e.g. .DS_Store) that os.listdir would choke on later,
        # and sort so the list files have a deterministic line order.
        cats = sorted(
            cat
            for cat in os.listdir(mode_dir)
            if os.path.isdir(os.path.join(mode_dir, cat))
        )
        for cat in cats:
            save_training_txt(cls_root, mode, cat)
    info(f"Add train.txt/val.txt successfully for {input_dir}")
    return input_dir
|