analyse_dataset.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import numpy as np
  17. from PIL import Image, ImageOps
  18. from .....utils.deps import function_requires_deps, is_dep_available
  19. from .....utils.file_interface import custom_open
  20. if is_dep_available("matplotlib"):
  21. import matplotlib.pyplot as plt
  22. @function_requires_deps("matplotlib")
  23. def anaylse_dataset(dataset_dir, output):
  24. """class analysis for dataset"""
  25. split_tags = ["train", "val"]
  26. label2count = {tag: dict() for tag in split_tags}
  27. for tag in split_tags:
  28. mapping_file = osp.join(dataset_dir, f"{tag}.txt")
  29. with custom_open(mapping_file, "r") as fp:
  30. lines = filter(None, (line.strip() for line in fp.readlines()))
  31. for i, line in enumerate(lines):
  32. _, ann_file = line.split(" ")
  33. ann_file = osp.join(dataset_dir, ann_file)
  34. ann = np.array(ImageOps.exif_transpose(Image.open(ann_file)), "uint8")
  35. for idx in set(ann.reshape([-1]).tolist()):
  36. if idx == 255:
  37. continue
  38. if idx not in label2count[tag]:
  39. label2count[tag][idx] = 1
  40. else:
  41. label2count[tag][idx] += 1
  42. if label2count[tag].get(0, None) is None:
  43. label2count[tag][0] = 0
  44. train_label_idx = np.array(list(label2count["train"].keys()))
  45. val_label_idx = np.array(list(label2count["val"].keys()))
  46. label_idx = np.array(list(set(train_label_idx) | set(val_label_idx)))
  47. x = np.arange(len(label_idx))
  48. train_list = []
  49. val_list = []
  50. for i in range(len(label_idx)):
  51. train_list.append(label2count["train"].get(i, 0))
  52. val_list.append(label2count["val"].get(i, 0))
  53. fig, ax = plt.subplots(figsize=(max(8, int(len(label_idx) / 5)), 5), dpi=120)
  54. width = (0.5,)
  55. ax.bar(x, train_list, width=width, label="train")
  56. ax.bar(x + width, val_list, width=width, label="val")
  57. plt.xticks(x + 0.25, label_idx)
  58. ax.set_xlabel("Label Index")
  59. ax.set_ylabel("Sample Counts")
  60. plt.legend()
  61. fig.tight_layout()
  62. fig_path = os.path.join(output, "histogram.png")
  63. fig.savefig(fig_path)
  64. return {"histogram": os.path.join("check_dataset", "histogram.png")}