# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
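
# Dataset analysis helpers for the dataset checker: `simple_analyse` counts the
# samples in each split, and `deep_analyse` plots a histogram of label lengths.
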
import math
import os
import platform
from collections import defaultdict

import numpy as np

from .....utils.deps import function_requires_deps, is_dep_available
from .....utils.file_interface import custom_open
from .....utils.fonts import PINGFANG_FONT
from .....utils.logging import warning
if is_dep_available("opencv-contrib-python"):
    import cv2
if is_dep_available("matplotlib"):
    import matplotlib.pyplot as plt
    from matplotlib import font_manager
    from matplotlib.backends.backend_agg import FigureCanvasAgg


def simple_analyse(dataset_path, images_dict):
    """
    Analyse the dataset splits and collect their sample counts and image paths.

    Args:
        dataset_path (str): dataset root path containing train.txt/val.txt/test.txt
        images_dict (dict): mapping from split name ("train", "val", "test") to image paths

    Returns:
        tuple: status message, followed by the sample count and image paths of the
            train, val and test splits.
    """
    tags = ["train", "val", "test"]
    sample_cnts = defaultdict(int)
    img_paths = defaultdict(list)
    res = [None] * 6
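
    # Each split is described by a list file "<split>.txt" with one sample per line;
    # train.txt and val.txt are required, test.txt is optional.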
    for tag in tags:
        file_list = os.path.join(dataset_path, f"{tag}.txt")
        if not os.path.exists(file_list):
            if tag in ("train", "val"):
                # "The dataset does not conform to the specification;
                # please run dataset validation first."
                res.insert(0, "数据集不符合规范,请先通过数据校准")
                return res
            else:
                continue
        else:
            with custom_open(file_list, "r") as f:
                all_lines = f.readlines()

            # Each line corresponds to a sample
            sample_cnts[tag] = len(all_lines)
            img_paths[tag] = images_dict[tag]

    # "Data analysis completed."
    return (
        "完成数据分析",
        sample_cnts[tags[0]],
        sample_cnts[tags[1]],
        sample_cnts[tags[2]],
        img_paths[tags[0]],
        img_paths[tags[1]],
        img_paths[tags[2]],
    )
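

# Example usage of `simple_analyse` (illustrative only; the dataset path and the exact
# layout of `images_dict` below are assumptions, not values fixed by this module):
#
#   msg, n_train, n_val, n_test, train_imgs, val_imgs, test_imgs = simple_analyse(
#       "datasets/my_text_rec_dataset",
#       {"train": [...], "val": [...], "test": [...]},
#   )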


@function_requires_deps("matplotlib", "opencv-contrib-python")
def deep_analyse(dataset_path, output, datatype="MSTextRecDataset"):
    """Analyse the label-length distribution of the dataset and save a histogram."""
    tags = ["train", "val"]
    x_max = []
    classes_max = []
    for tag in tags:
        # Use fresh bins for every split so that bins computed for "train" do not
        # leak into the "val" counts.
        labels_cnt = {}
        anno_path = os.path.join(dataset_path, f"{tag}.txt")
        str_nums = []
        with custom_open(anno_path, "r") as f:
            lines = f.readlines()
        for line in lines:
            line = line.strip().split("\t")
            if len(line) != 2:
                warning(f"Skipping malformed annotation line: {line}.")
                continue
            # Record the length of the label (the second tab-separated field).
            str_nums.append(len(line[1]))
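
        # Bucket the collected label lengths into fixed-width intervals so they can be
        # drawn as a histogram.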
        if datatype == "LaTeXOCRDataset":
            # LaTeX formulas can be much longer than plain text lines, so use a
            # larger length cap and a wider bin.
            max_length = min(768, max(str_nums))
            interval = 20
        else:
            max_length = min(100, max(str_nums))
            interval = 5
        start = 0
        # The "+ 1" lets the last (possibly partial) interval reach max_length, so no
        # lengths between the previous stop and max_length are dropped.
        for i in range(1, math.ceil(max_length / interval) + 1):
            stop = i * interval
            num_str = sum(start < n <= stop for n in str_nums)
            labels_cnt[f"{start}-{stop}"] = num_str
            start = stop
        # Lengths beyond the cap go into a single overflow bin.
        if sum(max_length < n for n in str_nums) != 0:
            labels_cnt[f"> {max_length}"] = sum(max_length < n for n in str_nums)
        if tag == "train":
            cnts_train = [cat_ids for cat_name, cat_ids in labels_cnt.items()]
            x_train = np.arange(len(cnts_train))
            if len(x_train) > len(x_max):
                x_max = x_train
                classes_max = [cat_name for cat_name, cat_ids in labels_cnt.items()]
        elif tag == "val":
            cnts_val = [cat_ids for cat_name, cat_ids in labels_cnt.items()]
            x_val = np.arange(len(cnts_val))
            if len(x_val) > len(x_max):
                x_max = x_val
                classes_max = [cat_name for cat_name, cat_ids in labels_cnt.items()]
    width = 0.3

    # Draw a grouped bar chart of the label-length bins for train and val.
    os_system = platform.system().lower()
    if os_system == "windows":
        # Use a font that can render the Chinese axis labels on Windows.
        plt.rcParams["font.sans-serif"] = "FangSong"
    else:
        font = font_manager.FontProperties(fname=PINGFANG_FONT.path, size=15)
    if datatype == "LaTeXOCRDataset":
        fig, ax = plt.subplots(figsize=(15, 9), dpi=120)
        xlabel_name = "公式长度区间"  # "formula length interval"
    else:
        fig, ax = plt.subplots(figsize=(10, 5), dpi=120)
        xlabel_name = "文本字长度区间"  # "text length interval"
    ax.bar(x_train, cnts_train, width=width, label="train")
    ax.bar(x_val + width, cnts_val, width=width, label="val")
    plt.xticks(x_max + width / 2, classes_max, rotation=90)
    plt.legend(prop={"size": 18})
    ax.set_xlabel(
        xlabel_name,
        fontproperties=None if os_system == "windows" else font,
        fontsize=12,
    )
    ax.set_ylabel(
        "图片数量",  # "number of images"
        fontproperties=None if os_system == "windows" else font,
        fontsize=12,
    )
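
    # Render the figure off-screen with the Agg backend and copy its RGB pixel buffer
    # into a numpy array so it can be written to disk with OpenCV.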
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    fig_width, fig_height = fig.get_size_inches() * fig.get_dpi()
    pie_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
        int(fig_height), int(fig_width), 3
    )
    fig1_path = os.path.join(output, "histogram.png")
    # The buffer is RGB while OpenCV expects BGR, so convert before writing.
    cv2.imwrite(fig1_path, cv2.cvtColor(pie_array, cv2.COLOR_RGB2BGR))
    return {"histogram": os.path.join("check_dataset", "histogram.png")}
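

# Example usage of `deep_analyse` (illustrative only; the paths below are assumptions):
#
#   deep_analyse("datasets/my_text_rec_dataset", "output/check_dataset")
#   # writes output/check_dataset/histogram.png and returns
#   # {"histogram": "check_dataset/histogram.png"}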