analyse_dataset.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import math
  17. import platform
  18. from pathlib import Path
  19. from collections import defaultdict
  20. from PIL import Image
  21. import cv2
  22. import numpy as np
  23. import matplotlib.pyplot as plt
  24. from matplotlib.backends.backend_agg import FigureCanvasAgg
  25. from matplotlib import font_manager
  26. from .....utils.file_interface import custom_open
  27. from .....utils.logging import warning
  28. from .....utils.fonts import PINGFANG_FONT_FILE_PATH
  29. def simple_analyse(dataset_path, images_dict):
  30. """
  31. Analyse the dataset samples by return image path and label path
  32. Args:
  33. dataset_path (str): dataset path
  34. ds_meta (dict): dataset meta
  35. images_dict (dict): train, val and test image path
  36. Returns:
  37. tuple: tuple of sample number, image path and label path for train, val and text subdataset.
  38. """
  39. tags = ['train', 'val', 'test']
  40. sample_cnts = defaultdict(int)
  41. img_paths = defaultdict(list)
  42. res = [None] * 6
  43. for tag in tags:
  44. file_list = os.path.join(dataset_path, f'{tag}.txt')
  45. if not os.path.exists(file_list):
  46. if tag in ('train', 'val'):
  47. res.insert(0, "数据集不符合规范,请先通过数据校准")
  48. return res
  49. else:
  50. continue
  51. else:
  52. with custom_open(file_list, 'r') as f:
  53. all_lines = f.readlines()
  54. # Each line corresponds to a sample
  55. sample_cnts[tag] = len(all_lines)
  56. img_paths[tag] = images_dict[tag]
  57. return ("完成数据分析", sample_cnts[tags[0]], sample_cnts[tags[1]],
  58. sample_cnts[tags[2]], img_paths[tags[0]], img_paths[tags[1]],
  59. img_paths[tags[2]])
  60. def deep_analyse(dataset_path, output_dir):
  61. """class analysis for dataset"""
  62. tags = ['train', 'val']
  63. all_instances = 0
  64. labels_cnt = {}
  65. x_max = []
  66. classes_max = []
  67. for tag in tags:
  68. image_path = os.path.join(dataset_path, f'{tag}.txt')
  69. str_nums = []
  70. with custom_open(image_path, 'r') as f:
  71. lines = f.readlines()
  72. for line in lines:
  73. line = line.strip().split("\t")
  74. if len(line) != 2:
  75. warning(f"Error in {line}.")
  76. continue
  77. str_nums.append(len(line[1]))
  78. max_length = min(100, max(str_nums))
  79. start = 0
  80. for i in range(1, math.ceil((max_length / 5))):
  81. stop = i * 5
  82. num_str = sum(start < i <= stop for i in str_nums)
  83. labels_cnt[f'{start}-{stop}'] = num_str
  84. start = stop
  85. if sum(max_length < i for i in str_nums) != 0:
  86. labels_cnt[f'> {max_length}'] = sum(max_length < i
  87. for i in str_nums)
  88. if tag == 'train':
  89. cnts_train = [cat_ids for cat_name, cat_ids in labels_cnt.items()]
  90. x_train = np.arange(len(cnts_train))
  91. if len(x_train) > len(x_max):
  92. x_max = x_train
  93. classes_max = [
  94. cat_name for cat_name, cat_ids in labels_cnt.items()
  95. ]
  96. elif tag == 'val':
  97. cnts_val = [cat_ids for cat_name, cat_ids in labels_cnt.items()]
  98. x_val = np.arange(len(cnts_val))
  99. if len(x_val) > len(x_max):
  100. x_max = x_val
  101. classes_max = [
  102. cat_name for cat_name, cat_ids in labels_cnt.items()
  103. ]
  104. width = 0.3
  105. # bar
  106. os_system = platform.system().lower()
  107. if os_system == "windows":
  108. plt.rcParams['font.sans-serif'] = 'FangSong'
  109. else:
  110. font = font_manager.FontProperties(
  111. fname=PINGFANG_FONT_FILE_PATH, size=15)
  112. fig, ax = plt.subplots(figsize=(10, 5), dpi=120)
  113. ax.bar(x_train, cnts_train, width=0.3, label='train')
  114. ax.bar(x_val + width, cnts_val, width=0.3, label='val')
  115. plt.xticks(x_max + width / 2, classes_max, rotation=90)
  116. ax.set_xlabel(
  117. '文本字长度区间',
  118. fontproperties=None if os_system == "windows" else font,
  119. fontsize=12)
  120. ax.set_ylabel(
  121. '图片数量',
  122. fontproperties=None if os_system == "windows" else font,
  123. fontsize=12)
  124. canvas = FigureCanvasAgg(fig)
  125. canvas.draw()
  126. width, height = fig.get_size_inches() * fig.get_dpi()
  127. pie_array = np.frombuffer(
  128. canvas.tostring_rgb(), dtype='uint8').reshape(
  129. int(height), int(width), 3)
  130. fig1_path = os.path.join(output_dir, "histogram.png")
  131. cv2.imwrite(fig1_path, pie_array)
  132. return {"histogram": os.path.join("check_dataset", "histogram.png")}