# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import os
import os.path as osp
import shutil

import cv2
import numpy as np
from PIL import Image, ImageDraw

from .....utils.file_interface import custom_open
from .....utils import logging
from .....utils.logging import info


def convert_dataset(dataset_type, input_dir):
    """Convert a dataset in place to the PaddleX official format.

    Args:
        dataset_type: Name of the source format; "LabelMe" and "MVTec_AD"
            are supported.
        input_dir: Root directory of the dataset to convert.

    Returns:
        The value returned by the format-specific converter (``input_dir``).

    Raises:
        NotImplementedError: If ``dataset_type`` is not a supported format.
    """
    if dataset_type == "LabelMe":
        return convert_labelme_dataset(input_dir)
    elif dataset_type == "MVTec_AD":
        return convert_mvtec_dataset(input_dir)
    else:
        raise NotImplementedError(dataset_type)


def convert_labelme_dataset(input_dir):
    """Convert a LabelMe-annotated dataset to the PaddleX official format.

    For each split ("train", "val") the file ``<split>_anno_list.txt`` is read
    from ``input_dir``; every non-blank line is a LabelMe JSON path relative
    to ``input_dir``. For every JSON file whose image exists and can be
    decoded, a palettized PNG label map is written to ``annotations/``, the
    image is copied to ``images/`` and the pair is listed in ``<split>.txt``.
    ``class_name.txt`` and ``class_name_to_id.txt`` are written at the end.

    Args:
        input_dir: Root directory of the dataset; outputs are written into it.

    Returns:
        ``input_dir``.

    Raises:
        ValueError: If more than 256 categories are found.
    """
    bg_name = "_background_"
    ignore_name = "__ignore__"

    # prepare dir
    output_img_dir = osp.join(input_dir, "images")
    output_annot_dir = osp.join(input_dir, "annotations")
    os.makedirs(output_img_dir, exist_ok=True)
    os.makedirs(output_annot_dir, exist_ok=True)

    # The palette does not depend on the split, so build it once.
    color_map = get_color_map_list(256)

    # collect class_names and set class_name_to_id
    # class_names is append-only across splits, so ids assigned while
    # processing "train" stay valid when "val" introduces new classes.
    class_names = []
    class_name_to_id = {}
    split_tags = ["train", "val"]
    for tag in split_tags:
        mapping_file = osp.join(input_dir, f"{tag}_anno_list.txt")
        with open(mapping_file, "r") as f:
            # Skip blank lines: they would otherwise resolve to input_dir
            # itself and fail when opened as a JSON file.
            label_files = [
                osp.join(input_dir, line.strip()) for line in f if line.strip()
            ]

        for label_file in label_files:
            with custom_open(label_file, "r") as fp:
                data = json.load(fp)
            for shape in data["shapes"]:
                cls_name = shape["label"]
                if cls_name not in class_names:
                    class_names.append(cls_name)

        # Reserved labels get fixed ids and are kept out of class_names.
        if ignore_name in class_names:
            class_name_to_id[ignore_name] = 255
            class_names.remove(ignore_name)
        if bg_name in class_names:
            class_names.remove(bg_name)
        class_name_to_id[bg_name] = 0
        for i, name in enumerate(class_names):
            class_name_to_id[name] = i + 1

        # NOTE(review): label maps are stored as uint8 and 255 is reserved for
        # ignore, so ids above 254 look unsafe -- confirm whether this limit
        # should be tighter than 256.
        if len(class_names) > 256:
            raise ValueError(
                f"There are {len(class_names)} categories in the annotation file, "
                f"exceeding 256, Not compliant with paddlex official format!"
            )

        # create annotated images and copy origin images
        img_file_list = []
        label_file_list = []
        for label_file in label_files:
            filename = osp.splitext(osp.basename(label_file))[0]
            annotated_img_path = osp.join(output_annot_dir, filename + ".png")
            with custom_open(label_file, "r") as f:
                data = json.load(f)
            img_path = osp.join(osp.dirname(label_file), data["imagePath"])
            if not osp.exists(img_path):
                logging.info("%s is not existed, skip this image" % img_path)
                continue

            # cv2.imread returns None instead of raising when decoding fails.
            img = cv2.imread(img_path)
            if img is None:
                logging.info("%s can not be decoded, skip this image" % img_path)
                continue

            img_file_list.append(f"images/{osp.basename(img_path)}")
            label_file_list.append(
                f"annotations/{osp.basename(annotated_img_path)}"
            )

            lbl = shape2label(
                img_size=img.shape,
                shapes=data["shapes"],
                class_name_mapping=class_name_to_id,
            )
            lbl_pil = Image.fromarray(lbl.astype(np.uint8), mode="P")
            lbl_pil.putpalette(color_map)
            lbl_pil.save(annotated_img_path)

            shutil.copy(img_path, output_img_dir)

        with custom_open(osp.join(input_dir, f"{tag}.txt"), "w") as fp:
            for img_path, lbl_path in zip(img_file_list, label_file_list):
                fp.write(f"{img_path} {lbl_path}\n")

    with custom_open(osp.join(input_dir, "class_name.txt"), "w") as fp:
        for name in class_names:
            fp.write(f"{name}{os.linesep}")
    with custom_open(osp.join(input_dir, "class_name_to_id.txt"), "w") as fp:
        for key, val in class_name_to_id.items():
            fp.write(f"{val}: {key}{os.linesep}")

    return input_dir


def get_color_map_list(num_classes):
    """Build a flat RGB palette for label images.

    Colors are generated for ``num_classes + 1`` labels by spreading the bits
    of each label index over the high bits of the R, G and B channels; the
    first (all-zero) entry is then dropped.

    Args:
        num_classes: Number of palette entries to return.

    Returns:
        A flat list ``[r, g, b, r, g, b, ...]`` of length ``3 * num_classes``.
    """
    num_classes += 1
    color_map = num_classes * [0, 0, 0]
    for i in range(num_classes):
        j = 0
        lab = i
        # Consume the label three bits at a time: one bit per channel,
        # written from the most significant bit (7) downwards.
        while lab:
            color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
            color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
            color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
            j += 1
            lab >>= 3
    color_map = color_map[3:]
    return color_map


def shape2label(img_size, shapes, class_name_mapping):
    """Rasterize annotated shapes into a class-id label matrix.

    Shapes are painted in list order, so later shapes overwrite earlier ones
    where they overlap. Pixels not covered by any shape stay 0.

    Args:
        img_size: Image shape; only the first two entries are used as the
            shape of the label matrix.
        shapes: Iterable of dicts with a "points" list and a "label" name.
        class_name_mapping: Mapping from label name to integer class id.

    Returns:
        An ``int32`` array of shape ``img_size[:2]`` holding class ids.

    Raises:
        KeyError: If a shape's label is missing from ``class_name_mapping``.
        ValueError: If a shape has fewer than three points.
    """
    label = np.zeros(img_size[:2], dtype=np.int32)
    for shape in shapes:
        points = shape["points"]
        class_name = shape["label"]
        label_mask = polygon2mask(img_size[:2], points)
        label[label_mask] = class_name_mapping[class_name]
    return label


def polygon2mask(img_size, points):
    """Convert polygon vertices into a boolean mask.

    Args:
        img_size: Shape of the mask; only the first two entries are used.
        points: Sequence of ``(x, y)`` vertices; at least three are required.

    Returns:
        A boolean array of shape ``img_size[:2]`` that is True on the polygon
        outline and interior.

    Raises:
        ValueError: If fewer than three points are given.
    """
    points_list = [tuple(point) for point in points]
    # Raise explicitly: an assert would be stripped under ``python -O`` and
    # would surface as AssertionError rather than the intended ValueError.
    if len(points_list) <= 2:
        raise ValueError("Polygon must have points more than 2")
    label_mask = Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8))
    image_draw = ImageDraw.Draw(label_mask)
    image_draw.polygon(xy=points_list, outline=1, fill=1)
    return np.array(label_mask, dtype=bool)


def save_item_to_txt(items, file_path):
    """Append ``items`` to the text file at ``file_path`` (best effort).

    Failures are reported on stdout and not re-raised.

    Args:
        items: Text to append.
        file_path: Path of the file to append to; created if missing.
    """
    try:
        # The context manager closes the file; no explicit close() is needed.
        with open(file_path, "a") as file:
            file.write(items)
    except Exception as e:
        print(f"Saving_error: {e}")


def save_training_txt(cls_root, mode, cat):
    """Append the list entries of one category directory to the split file.

    For ``mode == "train"`` each image is listed as ``"<img> <img>"`` in
    ``train.txt``. For ``mode == "test"`` and any category other than "good"
    each image is listed with its ``ground_truth/<cat>/<stem>_mask.png`` in
    ``val.txt``. Any other combination writes nothing.

    Args:
        cls_root: Dataset root containing the ``train``/``test`` folders.
        mode: Split folder name, "train" or "test".
        cat: Category sub-folder inside the split folder.
    """
    cat_dir = os.path.join(cls_root, mode, cat)
    imgs = sorted(os.listdir(cat_dir))

    if mode == "train":
        list_name = "train.txt"
        lines = []
        for img in imgs:
            item = os.path.join(cat_dir, img)
            lines.append(item + " " + item + "\n")
    elif mode == "test" and cat != "good":
        list_name = "val.txt"
        lines = []
        for img in imgs:
            item1 = os.path.join(cat_dir, img)
            # splitext keeps file names with extra dots intact.
            item2 = os.path.join(
                cls_root, "ground_truth", cat, os.path.splitext(img)[0] + "_mask.png"
            )
            lines.append(item1 + " " + item2 + "\n")
    else:
        return

    # Write once per category instead of reopening the file for every image;
    # skip empty categories so no empty list file is created.
    if lines:
        save_item_to_txt("".join(lines), os.path.join(cls_root, list_name))


def check_old_txt(cls_pth, mode):
    """Remove a stale ``train.txt``/``val.txt`` so entries are not duplicated.

    Args:
        cls_pth: Dataset root that holds the list file.
        mode: "train" selects ``train.txt``; anything else selects ``val.txt``.
    """
    set_name = "train.txt" if mode == "train" else "val.txt"
    pth = os.path.join(cls_pth, set_name)
    if os.path.exists(pth):
        os.remove(pth)


def convert_mvtec_dataset(input_dir):
    """Generate ``train.txt``/``val.txt`` for one MVTec AD class directory.

    Args:
        input_dir: Directory of a single MVTec AD class; its base name must be
            one of the known class names and it must contain ``train`` and
            ``test`` sub-directories.

    Returns:
        ``input_dir``.

    Raises:
        ValueError: If the base name of ``input_dir`` is not a known class.
    """
    classes = [
        "bottle",
        "cable",
        "capsule",
        "hazelnut",
        "metal_nut",
        "pill",
        "screw",
        "toothbrush",
        "transistor",
        "zipper",
        "carpet",
        "grid",
        "leather",
        "tile",
        "wood",
    ]
    # normpath drops a trailing separator, which would otherwise make the
    # base name empty.
    clas = os.path.basename(os.path.normpath(input_dir))
    if clas not in classes:
        raise ValueError(
            f"Make sure your class: '{clas}' in your dataset root in\n {classes}"
        )

    modes = ["train", "test"]
    cls_root = input_dir
    for mode in modes:
        check_old_txt(cls_root, mode)
        mode_dir = os.path.join(cls_root, mode)
        # Sorted for a deterministic list order; stray files (e.g. .DS_Store)
        # are skipped because only directories are categories.
        for cat in sorted(os.listdir(mode_dir)):
            if not os.path.isdir(os.path.join(mode_dir, cat)):
                continue
            save_training_txt(cls_root, mode, cat)
    info(f"Add train.txt/val.txt successfully for {input_dir}")
    return input_dir