check_dataset.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. from PIL import Image
  18. from .....utils.deps import function_requires_deps, is_dep_available
  19. from .....utils.errors import DatasetFileNotFoundError
  20. from .....utils.file_interface import custom_open
  21. from .....utils.logging import info
  22. from .utils.visualizer import visualize
  23. if is_dep_available("opencv-contrib-python"):
  24. import cv2
  25. @function_requires_deps("opencv-contrib-python")
  26. def check_dataset(dataset_dir, output, sample_num=10):
  27. """check dataset"""
  28. dataset_dir = osp.abspath(dataset_dir)
  29. if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
  30. raise DatasetFileNotFoundError(file_path=dataset_dir)
  31. vis_save_dir = osp.join(output, "demo_img")
  32. if not osp.exists(vis_save_dir):
  33. os.makedirs(vis_save_dir)
  34. split_tags = ["train", "val"]
  35. attrs = dict()
  36. class_ids = set()
  37. for tag in split_tags:
  38. mapping_file = osp.join(dataset_dir, f"{tag}.txt")
  39. if not osp.exists(mapping_file):
  40. info(f"The mapping file ({mapping_file}) doesn't exist, ignored.")
  41. info(
  42. "If you are using MVTec_AD dataset, add args below in your training commands:"
  43. )
  44. info("-o CheckDataset.convert.enable=True")
  45. info("-o CheckDataset.convert.src_dataset_type=MVTec_AD")
  46. continue
  47. with custom_open(mapping_file, "r") as fp:
  48. lines = filter(None, (line.strip() for line in fp.readlines()))
  49. for i, line in enumerate(lines):
  50. img_file, ann_file = line.split(" ")
  51. img_file = osp.join(dataset_dir, img_file)
  52. ann_file = osp.join(dataset_dir, ann_file)
  53. assert osp.exists(img_file), FileNotFoundError(
  54. f"{img_file} not exist, please check!"
  55. )
  56. assert osp.exists(ann_file), FileNotFoundError(
  57. f"{ann_file} not exist, please check!"
  58. )
  59. img = np.array(cv2.imread(img_file), "uint8")
  60. ann = np.array(cv2.imread(ann_file), "uint8")[:, :, 0]
  61. assert img.shape[:2] == ann.shape, ValueError(
  62. f"The shape of {img_file}:{img.shape[:2]} and "
  63. f"{ann_file}:{ann.shape} must be the same!"
  64. )
  65. if tag == "val":
  66. class_ids = class_ids | set(ann.reshape([-1]).tolist())
  67. if i < sample_num:
  68. vis_img = visualize(img, ann)
  69. vis_img = Image.fromarray(vis_img)
  70. vis_save_path = osp.join(vis_save_dir, osp.basename(img_file))
  71. vis_img.save(vis_save_path)
  72. vis_save_path = osp.join(
  73. "check_dataset", os.path.relpath(vis_save_path, output)
  74. )
  75. if f"{tag}_sample_paths" not in attrs:
  76. attrs[f"{tag}_sample_paths"] = [vis_save_path]
  77. else:
  78. attrs[f"{tag}_sample_paths"].append(vis_save_path)
  79. if f"{tag}_samples" not in attrs:
  80. attrs[f"{tag}_samples"] = i + 1
  81. if 255 in class_ids:
  82. class_ids.remove(255)
  83. attrs["num_classes"] = len(class_ids)
  84. return attrs