| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import os.path as osp
- import numpy as np
- from PIL import Image, ImageOps
- import cv2
- from .utils.visualizer import visualize
- from .....utils.errors import DatasetFileNotFoundError
- from .....utils.file_interface import custom_open
- from .....utils.logging import info
- def check_dataset(dataset_dir, output, sample_num=10):
- """check dataset"""
- dataset_dir = osp.abspath(dataset_dir)
- if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
- raise DatasetFileNotFoundError(file_path=dataset_dir)
- vis_save_dir = osp.join(output, "demo_img")
- if not osp.exists(vis_save_dir):
- os.makedirs(vis_save_dir)
- split_tags = ["train", "val"]
- attrs = dict()
- class_ids = set()
- for tag in split_tags:
- mapping_file = osp.join(dataset_dir, f"{tag}.txt")
- if not osp.exists(mapping_file):
- info(f"The mapping file ({mapping_file}) doesn't exist, ignored.")
- info(
- "If you are using MVTec_AD dataset, add args below in your training commands:"
- )
- info("-o CheckDataset.convert.enable=True")
- info("-o CheckDataset.convert.src_dataset_type=MVTec_AD")
- continue
- with custom_open(mapping_file, "r") as fp:
- lines = filter(None, (line.strip() for line in fp.readlines()))
- for i, line in enumerate(lines):
- img_file, ann_file = line.split(" ")
- img_file = osp.join(dataset_dir, img_file)
- ann_file = osp.join(dataset_dir, ann_file)
- assert osp.exists(img_file), FileNotFoundError(
- f"{img_file} not exist, please check!"
- )
- assert osp.exists(ann_file), FileNotFoundError(
- f"{ann_file} not exist, please check!"
- )
- img = np.array(cv2.imread(img_file), "uint8")
- ann = np.array(cv2.imread(ann_file), "uint8")[:, :, 0]
- assert img.shape[:2] == ann.shape, ValueError(
- f"The shape of {img_file}:{img.shape[:2]} and "
- f"{ann_file}:{ann.shape} must be the same!"
- )
- if tag == "val":
- class_ids = class_ids | set(ann.reshape([-1]).tolist())
- if i < sample_num:
- vis_img = visualize(img, ann)
- vis_img = Image.fromarray(vis_img)
- vis_save_path = osp.join(vis_save_dir, osp.basename(img_file))
- vis_img.save(vis_save_path)
- vis_save_path = osp.join(
- "check_dataset", os.path.relpath(vis_save_path, output)
- )
- if f"{tag}_sample_paths" not in attrs:
- attrs[f"{tag}_sample_paths"] = [vis_save_path]
- else:
- attrs[f"{tag}_sample_paths"].append(vis_save_path)
- if f"{tag}_samples" not in attrs:
- attrs[f"{tag}_samples"] = i + 1
- if 255 in class_ids:
- class_ids.remove(255)
- attrs["num_classes"] = len(class_ids)
- return attrs
|