analyse_dataset.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from collections import defaultdict
  16. from .....utils.file_interface import custom_open
  17. def simple_analyse(dataset_path):
  18. """
  19. Analyse the dataset samples by return image path and label path
  20. Args:
  21. dataset_path (str): dataset path
  22. Returns:
  23. tuple: tuple of sample number, image path and label path for train, val and text subdataset.
  24. """
  25. tags = ["train", "val", "test"]
  26. sample_cnts = defaultdict(int)
  27. defaultdict(list)
  28. res = [None] * 6
  29. for tag in tags:
  30. file_list = os.path.join(dataset_path, f"{tag}.txt")
  31. if not os.path.exists(file_list):
  32. if tag in ("train", "val"):
  33. res.insert(0, "数据集不符合规范,请先通过数据校准")
  34. return res
  35. else:
  36. continue
  37. else:
  38. with custom_open(file_list, "r") as f:
  39. all_lines = f.readlines()
  40. # Each line corresponds to a sample
  41. sample_cnts[tag] = len(all_lines)
  42. # img_paths[tag] = images_dict[tag]
  43. return f"训练数据样本数: {sample_cnts[tags[0]]}\t评估数据样本数: {sample_cnts[tags[1]]}"
  44. def deep_analyse(dataset_path, output=None):
  45. """class analysis for dataset"""
  46. return simple_analyse(dataset_path)