analyse_dataset.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import platform
  16. import numpy as np
  17. import pandas as pd
  18. from .....utils.deps import function_requires_deps, is_dep_available
  19. from .....utils.fonts import PINGFANG_FONT
  20. if is_dep_available("matplotlib"):
  21. import matplotlib.pyplot as plt
  22. from matplotlib import font_manager
  23. @function_requires_deps("matplotlib")
  24. def deep_analyse(dataset_dir, output, label_col="label"):
  25. """class analysis for dataset"""
  26. tags = ["train", "val"]
  27. label_unique = None
  28. for tag in tags:
  29. csv_path = os.path.abspath(os.path.join(dataset_dir, tag + ".csv"))
  30. df = pd.read_csv(csv_path)
  31. if label_col not in df.columns:
  32. raise ValueError(f"default label_col: {label_col} not in {tag} dataset")
  33. if label_unique is None:
  34. label_unique = df[label_col].unique()
  35. cls_dict = {}
  36. for label in label_unique:
  37. vis_df = df[df[label_col].isin([label])]
  38. cls_dict[label] = len(vis_df)
  39. if tag == "train":
  40. cls_train = [label_num for label_col, label_num in cls_dict.items()]
  41. elif tag == "val":
  42. cls_val = [label_num for label_col, label_num in cls_dict.items()]
  43. sorted_id = sorted(range(len(cls_train)), key=lambda k: cls_train[k], reverse=True)
  44. cls_train_sorted = sorted(cls_train, reverse=True)
  45. cls_val_sorted = [cls_val[index] for index in sorted_id]
  46. classes_sorted = [label_unique[index] for index in sorted_id]
  47. x = np.arange(len(label_unique))
  48. width = 0.5
  49. # bar
  50. os_system = platform.system().lower()
  51. if os_system == "windows":
  52. plt.rcParams["font.sans-serif"] = "FangSong"
  53. else:
  54. font = font_manager.FontProperties(fname=PINGFANG_FONT.path)
  55. fig, ax = plt.subplots(figsize=(max(8, int(len(label_unique) / 5)), 5), dpi=120)
  56. ax.bar(x, cls_train_sorted, width=0.5, label="train")
  57. ax.bar(x + width, cls_val_sorted, width=0.5, label="val")
  58. plt.xticks(
  59. x + width / 2,
  60. classes_sorted,
  61. rotation=90,
  62. fontproperties=None if os_system == "windows" else font,
  63. )
  64. ax.set_ylabel("Counts")
  65. plt.legend()
  66. fig.tight_layout()
  67. fig_path = os.path.join(output, "histogram.png")
  68. fig.savefig(fig_path)
  69. return {"histogram": os.path.join("check_dataset", "histogram.png")}