check_dataset.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. from PIL import Image, ImageOps
  18. from .....utils.errors import DatasetFileNotFoundError
  19. from .....utils.file_interface import custom_open
  20. from .....utils.logging import info
  21. from .utils.visualizer import visualize
  22. def check_dataset(dataset_dir, output, sample_num=10):
  23. """check dataset"""
  24. dataset_dir = osp.abspath(dataset_dir)
  25. if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
  26. raise DatasetFileNotFoundError(file_path=dataset_dir)
  27. vis_save_dir = osp.join(output, "demo_img")
  28. if not osp.exists(vis_save_dir):
  29. os.makedirs(vis_save_dir)
  30. split_tags = ["train", "val"]
  31. attrs = dict()
  32. class_ids = set()
  33. for tag in split_tags:
  34. mapping_file = osp.join(dataset_dir, f"{tag}.txt")
  35. if not osp.exists(mapping_file):
  36. info(f"The mapping file ({mapping_file}) doesn't exist, ignored.")
  37. continue
  38. with custom_open(mapping_file, "r") as fp:
  39. lines = filter(None, (line.strip() for line in fp.readlines()))
  40. for i, line in enumerate(lines):
  41. img_file, ann_file = line.split(" ")
  42. img_file = osp.join(dataset_dir, img_file)
  43. ann_file = osp.join(dataset_dir, ann_file)
  44. assert osp.exists(img_file), FileNotFoundError(
  45. f"{img_file} not exist, please check!"
  46. )
  47. assert osp.exists(ann_file), FileNotFoundError(
  48. f"{ann_file} not exist, please check!"
  49. )
  50. img = np.array(ImageOps.exif_transpose(Image.open(img_file)), "uint8")
  51. ann = np.array(ImageOps.exif_transpose(Image.open(ann_file)), "uint8")
  52. assert img.shape[:2] == ann.shape, ValueError(
  53. f"The shape of {img_file}:{img.shape[:2]} and "
  54. f"{ann_file}:{ann.shape} must be the same!"
  55. )
  56. class_ids = class_ids | set(ann.reshape([-1]).tolist())
  57. if i < sample_num:
  58. vis_img = visualize(img, ann)
  59. vis_img = Image.fromarray(vis_img)
  60. vis_save_path = osp.join(vis_save_dir, osp.basename(img_file))
  61. vis_img.save(vis_save_path)
  62. vis_save_path = osp.join(
  63. "check_dataset", os.path.relpath(vis_save_path, output)
  64. )
  65. if f"{tag}_sample_paths" not in attrs:
  66. attrs[f"{tag}_sample_paths"] = [vis_save_path]
  67. else:
  68. attrs[f"{tag}_sample_paths"].append(vis_save_path)
  69. if f"{tag}_samples" not in attrs:
  70. attrs[f"{tag}_samples"] = i + 1
  71. if 255 in class_ids:
  72. class_ids.remove(255)
  73. attrs["num_classes"] = len(class_ids)
  74. return attrs