# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform

import numpy as np

from .....utils.deps import function_requires_deps, is_dep_available
from .....utils.file_interface import custom_open
from .....utils.fonts import PINGFANG_FONT

if is_dep_available("matplotlib"):
    import matplotlib.pyplot as plt
    from matplotlib import font_manager


def _count_images_and_labels(anno_path):
    """Return ``(num_images, num_labels)`` for one annotation file.

    Each non-blank line is split on single spaces; the second field is parsed
    as an integer label. Blank lines are ignored so that a trailing empty line
    does not raise ``IndexError``.
    """
    with custom_open(anno_path, "r") as f:
        raw_lines = f.readlines()
    records = [line.strip("\n").split(" ") for line in raw_lines if line.strip()]
    labels = {int(record[1]) for record in records}
    return len(records), len(labels)


@function_requires_deps("matplotlib")
def deep_analyse(dataset_path, output, dataset_type="ShiTuRec"):
    """Plot per-split image and class counts of a dataset as a bar chart.

    Reads ``train.txt``, ``gallery.txt`` and ``query.txt`` from
    ``dataset_path``, counts the entries and the distinct integer labels in
    each, and saves a grouped bar chart to ``<output>/histogram.png``.

    Args:
        dataset_path: Directory containing the three annotation files.
        output: Directory the histogram image is written into.
        dataset_type: Not used by this function; kept for interface
            compatibility.

    Returns:
        dict: ``{"histogram": os.path.join("check_dataset", "histogram.png")}``.
        Note that this is a fixed relative path, not a path under ``output``.
    """
    tags = ["train", "gallery", "query"]
    tags_info = dict()
    for tag in tags:
        anno_path = os.path.join(dataset_path, f"{tag}.txt")
        num_images, num_labels = _count_images_and_labels(anno_path)
        tags_info[tag] = {
            "num_images": num_images,
            "num_labels": num_labels,
        }

    categories = list(tags_info.keys())
    num_images = [tags_info[category]["num_images"] for category in categories]
    num_labels = [tags_info[category]["num_labels"] for category in categories]

    # On Windows the font is selected through rcParams; elsewhere the bundled
    # PingFang font file is passed explicitly to each text element.
    os_system = platform.system().lower()
    if os_system == "windows":
        plt.rcParams["font.sans-serif"] = "FangSong"
        font = None
    else:
        font = font_manager.FontProperties(fname=PINGFANG_FONT.path, size=10)

    x = np.arange(len(categories))  # tick positions, one per split
    width = 0.35  # width of each bar

    fig, ax = plt.subplots()
    try:
        image_bars = ax.bar(x - width / 2, num_images, width, label="Num Images")
        label_bars = ax.bar(x + width / 2, num_labels, width, label="Num Classes")

        # Axis labels and title.
        ax.set_xlabel("集合", fontproperties=font)
        ax.set_ylabel("数量", fontproperties=font)
        ax.set_title("不同集合的图片和类别数量", fontproperties=font)

        # Text properties may only accompany tick *labels*: passing them to
        # set_xticks() without labels is rejected by recent Matplotlib, so the
        # positions and the labels are set in two separate calls.
        ax.set_xticks(x)
        if font is None:
            ax.set_xticklabels(categories)
        else:
            ax.set_xticklabels(categories, fontproperties=font)
        ax.legend()

        # Write each bar's height just above it.
        for rect in (*image_bars, *label_bars):
            height = rect.get_height()
            ax.annotate(
                "{}".format(height),
                xy=(rect.get_x() + rect.get_width() / 2, height),
                xytext=(0, 3),  # 3 points vertical offset
                textcoords="offset points",
                ha="center",
                va="bottom",
            )

        fig.tight_layout()
        file_path = os.path.join(output, "histogram.png")
        fig.savefig(file_path, dpi=300)
    finally:
        # pyplot keeps every figure it creates alive until it is closed.
        plt.close(fig)

    return {"histogram": os.path.join("check_dataset", "histogram.png")}