check_dataset.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. from PIL import Image, ImageOps
  18. import cv2
  19. from .utils.visualizer import visualize
  20. from .....utils.errors import DatasetFileNotFoundError
  21. from .....utils.file_interface import custom_open
  22. from .....utils.logging import info
  23. def check_dataset(dataset_dir, output, sample_num=10):
  24. """check dataset"""
  25. dataset_dir = osp.abspath(dataset_dir)
  26. if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
  27. raise DatasetFileNotFoundError(file_path=dataset_dir)
  28. vis_save_dir = osp.join(output, "demo_img")
  29. if not osp.exists(vis_save_dir):
  30. os.makedirs(vis_save_dir)
  31. split_tags = ["train", "val"]
  32. attrs = dict()
  33. class_ids = set()
  34. for tag in split_tags:
  35. mapping_file = osp.join(dataset_dir, f"{tag}.txt")
  36. if not osp.exists(mapping_file):
  37. info(f"The mapping file ({mapping_file}) doesn't exist, ignored.")
  38. info(
  39. "If you are using MVTec_AD dataset, add args below in your training commands:"
  40. )
  41. info("-o CheckDataset.convert.enable=True")
  42. info("-o CheckDataset.convert.src_dataset_type=MVTec_AD")
  43. continue
  44. with custom_open(mapping_file, "r") as fp:
  45. lines = filter(None, (line.strip() for line in fp.readlines()))
  46. for i, line in enumerate(lines):
  47. img_file, ann_file = line.split(" ")
  48. img_file = osp.join(dataset_dir, img_file)
  49. ann_file = osp.join(dataset_dir, ann_file)
  50. assert osp.exists(img_file), FileNotFoundError(
  51. f"{img_file} not exist, please check!"
  52. )
  53. assert osp.exists(ann_file), FileNotFoundError(
  54. f"{ann_file} not exist, please check!"
  55. )
  56. img = np.array(cv2.imread(img_file), "uint8")
  57. ann = np.array(cv2.imread(ann_file), "uint8")[:, :, 0]
  58. assert img.shape[:2] == ann.shape, ValueError(
  59. f"The shape of {img_file}:{img.shape[:2]} and "
  60. f"{ann_file}:{ann.shape} must be the same!"
  61. )
  62. if tag == "val":
  63. class_ids = class_ids | set(ann.reshape([-1]).tolist())
  64. if i < sample_num:
  65. vis_img = visualize(img, ann)
  66. vis_img = Image.fromarray(vis_img)
  67. vis_save_path = osp.join(vis_save_dir, osp.basename(img_file))
  68. vis_img.save(vis_save_path)
  69. vis_save_path = osp.join(
  70. "check_dataset", os.path.relpath(vis_save_path, output)
  71. )
  72. if f"{tag}_sample_paths" not in attrs:
  73. attrs[f"{tag}_sample_paths"] = [vis_save_path]
  74. else:
  75. attrs[f"{tag}_sample_paths"].append(vis_save_path)
  76. if f"{tag}_samples" not in attrs:
  77. attrs[f"{tag}_samples"] = i + 1
  78. if 255 in class_ids:
  79. class_ids.remove(255)
  80. attrs["num_classes"] = len(class_ids)
  81. return attrs