# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import math
import platform
from pathlib import Path
from collections import defaultdict

from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib import font_manager

from .....utils.file_interface import custom_open
from .....utils.logging import warning
from .....utils.fonts import PINGFANG_FONT_FILE_PATH


def simple_analyse(dataset_path, images_dict):
    """
    Analyse the dataset by counting the samples listed in each split file and
    collecting the image paths of every split.

    Args:
        dataset_path (str): root directory of the dataset, containing
            train.txt / val.txt / test.txt.
        images_dict (dict): image paths of the train, val and test splits.

    Returns:
        tuple: a status message followed by the sample counts and the image
            paths of the train, val and test splits.
    """
    tags = ['train', 'val', 'test']
    sample_cnts = defaultdict(int)
    img_paths = defaultdict(list)
    res = [None] * 6
    for tag in tags:
        file_list = os.path.join(dataset_path, f'{tag}.txt')
        if not os.path.exists(file_list):
            if tag in ('train', 'val'):
                # "The dataset does not conform to the specification; please run
                # the dataset check first."
                res.insert(0, "数据集不符合规范,请先通过数据校准")
                return res
            else:
                # The test split is optional.
                continue
        else:
            with custom_open(file_list, 'r') as f:
                all_lines = f.readlines()
            # Each line corresponds to one sample.
            sample_cnts[tag] = len(all_lines)
            img_paths[tag] = images_dict[tag]
    # "Data analysis finished."
    return ("完成数据分析", sample_cnts[tags[0]], sample_cnts[tags[1]],
            sample_cnts[tags[2]], img_paths[tags[0]], img_paths[tags[1]],
            img_paths[tags[2]])
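

# Note on the annotation format assumed by this module (inferred from the
# parsing in deep_analyse below, not an official specification): every line of
# {train,val,test}.txt is expected to hold two tab-separated fields,
#
#     <image path>\t<label text>
#
# for example a hypothetical line "images/word_001.png\tJOINT". simple_analyse()
# only counts lines, while deep_analyse() measures the length of the label field.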


def deep_analyse(dataset_path, output, datatype="MSTextRecDataset"):
    """Analyse the distribution of label lengths in the train and val splits
    and render the result as a histogram image."""
    tags = ['train', 'val']
    x_max = []
    classes_max = []
    for tag in tags:
        anno_path = os.path.join(dataset_path, f'{tag}.txt')
        str_nums = []
        with custom_open(anno_path, 'r') as f:
            lines = f.readlines()
        for line in lines:
            line = line.strip().split("\t")
            if len(line) != 2:
                warning(f"Error in {line}.")
                continue
            # Record the label length of this sample.
            str_nums.append(len(line[1]))
        if datatype == "LaTeXOCRDataset":
            # Formula labels can be very long; cap the histogram range at 768.
            max_length = min(768, max(str_nums))
            interval = 20
        else:
            # Plain-text labels; cap the histogram range at 100.
            max_length = min(100, max(str_nums))
            interval = 5
        # Reset the bucket counts for every split so that buckets left over from
        # the previous split do not leak into the current one.
        labels_cnt = {}
        start = 0
        # The "+ 1" lets the final bucket reach max_length instead of stopping
        # one interval short of it; the last stop is clipped to max_length.
        for i in range(1, math.ceil(max_length / interval) + 1):
            stop = min(i * interval, max_length)
            labels_cnt[f'{start}-{stop}'] = sum(start < n <= stop for n in str_nums)
            start = stop
        # Samples longer than max_length go into an open-ended bucket.
        num_longer = sum(n > max_length for n in str_nums)
        if num_longer != 0:
            labels_cnt[f'> {max_length}'] = num_longer
        if tag == 'train':
            cnts_train = list(labels_cnt.values())
            x_train = np.arange(len(cnts_train))
            if len(x_train) > len(x_max):
                x_max = x_train
                classes_max = list(labels_cnt.keys())
        elif tag == 'val':
            cnts_val = list(labels_cnt.values())
            x_val = np.arange(len(cnts_val))
            if len(x_val) > len(x_max):
                x_max = x_val
                classes_max = list(labels_cnt.keys())
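    # (Illustration with hypothetical numbers.) After both splits have been
    # processed, labels_cnt maps length buckets to sample counts, e.g.
    # {'0-5': 120, '5-10': 340, '> 100': 3}; cnts_train / cnts_val hold the
    # per-bucket counts and classes_max the bucket labels of the split with the
    # most buckets. These drive the grouped bar chart below.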
    width = 0.3
    # Render the grouped bar chart.
    os_system = platform.system().lower()
    if os_system == "windows":
        # On Windows fall back to a built-in CJK font; elsewhere use the bundled
        # PingFang font so the Chinese axis labels render correctly.
        plt.rcParams['font.sans-serif'] = 'FangSong'
    else:
        font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH, size=15)
    if datatype == "LaTeXOCRDataset":
        fig, ax = plt.subplots(figsize=(15, 9), dpi=120)
        xlabel_name = '公式长度区间'  # "formula length interval"
    else:
        fig, ax = plt.subplots(figsize=(10, 5), dpi=120)
        xlabel_name = '文本字长度区间'  # "text length interval"
    ax.bar(x_train, cnts_train, width=width, label='train')
    ax.bar(x_val + width, cnts_val, width=width, label='val')
    plt.xticks(x_max + width / 2, classes_max, rotation=90)
    plt.legend(prop={'size': 18})
    ax.set_xlabel(
        xlabel_name,
        fontproperties=None if os_system == "windows" else font,
        fontsize=12)
    ax.set_ylabel(
        '图片数量',  # "number of images"
        fontproperties=None if os_system == "windows" else font,
        fontsize=12)
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    fig_width, fig_height = fig.get_size_inches() * fig.get_dpi()
    bar_array = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(
        int(fig_height), int(fig_width), 3)
    fig1_path = os.path.join(output, "histogram.png")
    # The canvas buffer is RGB while cv2.imwrite expects BGR, so convert first.
    cv2.imwrite(fig1_path, cv2.cvtColor(bar_array, cv2.COLOR_RGB2BGR))
    # The returned path is relative to the dataset-check output directory.
    return {"histogram": os.path.join("check_dataset", "histogram.png")}