analyse_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import platform
from collections import defaultdict

import numpy as np

from .....utils.deps import function_requires_deps, is_dep_available
from .....utils.file_interface import custom_open
from .....utils.fonts import PINGFANG_FONT
from .....utils.logging import warning

if is_dep_available("opencv-contrib-python"):
    import cv2
if is_dep_available("matplotlib"):
    import matplotlib.pyplot as plt
    from matplotlib import font_manager
    from matplotlib.backends.backend_agg import FigureCanvasAgg


def simple_analyse(dataset_path, images_dict):
    """
    Analyse the dataset by counting the samples listed in each split file and
    collecting the corresponding image paths.

    Args:
        dataset_path (str): dataset root path containing train.txt / val.txt / test.txt
        images_dict (dict): train, val and test image paths

    Returns:
        tuple: a status message followed by the sample counts and image paths
            for the train, val and test subdatasets.
    """
    tags = ["train", "val", "test"]
    sample_cnts = defaultdict(int)
    img_paths = defaultdict(list)
    res = [None] * 6
    for tag in tags:
        file_list = os.path.join(dataset_path, f"{tag}.txt")
        if not os.path.exists(file_list):
            if tag in ("train", "val"):
                # "数据集不符合规范,请先通过数据校准" -- "The dataset does not conform to
                # the specification; please pass the dataset check first."
                res.insert(0, "数据集不符合规范,请先通过数据校准")
                return res
            else:
                continue
        else:
            with custom_open(file_list, "r") as f:
                all_lines = f.readlines()

            # Each line corresponds to a sample
            sample_cnts[tag] = len(all_lines)
            img_paths[tag] = images_dict[tag]

    # "完成数据分析" -- "Data analysis completed."
    return (
        "完成数据分析",
        sample_cnts[tags[0]],
        sample_cnts[tags[1]],
        sample_cnts[tags[2]],
        img_paths[tags[0]],
        img_paths[tags[1]],
        img_paths[tags[2]],
    )
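
# Illustrative usage of simple_analyse (the dataset path and the contents of
# images_dict below are hypothetical examples, not values used by this module):
#
#     images_dict = {"train": ["images/word_1.png"], "val": ["images/word_2.png"], "test": []}
#     msg, n_train, n_val, n_test, train_imgs, val_imgs, test_imgs = simple_analyse(
#         "./ocr_rec_dataset_examples", images_dict
#     )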


@function_requires_deps("matplotlib", "opencv-contrib-python")
def deep_analyse(dataset_path, output, datatype="MSTextRecDataset"):
    """Analyse the label-length distribution of the train/val splits and save a histogram."""
    tags = ["train", "val"]
    x_max = []
    classes_max = []
    for tag in tags:
        # Reset the bucket counts for each split so stale buckets from the
        # previous split do not leak into this one.
        labels_cnt = {}
        image_path = os.path.join(dataset_path, f"{tag}.txt")
        str_nums = []
        with custom_open(image_path, "r") as f:
            lines = f.readlines()
        for line in lines:
            # Each annotation line is "image_path\tlabel"; record the label length.
            line = line.strip().split("\t")
            if len(line) != 2:
                warning(f"Error in {line}.")
                continue
            str_nums.append(len(line[1]))
        if datatype == "LaTeXOCRDataset":
            max_length = min(768, max(str_nums))
            interval = 20
        else:
            max_length = min(100, max(str_nums))
            interval = 5
        # Bucket the label lengths into fixed-width intervals up to max_length;
        # anything longer than max_length goes into a single overflow bucket.
        start = 0
        for i in range(1, math.ceil(max_length / interval) + 1):
            stop = min(i * interval, max_length)
            num_str = sum(start < n <= stop for n in str_nums)
            labels_cnt[f"{start}-{stop}"] = num_str
            start = stop
        num_overflow = sum(max_length < n for n in str_nums)
        if num_overflow != 0:
            labels_cnt[f"> {max_length}"] = num_overflow
        if tag == "train":
            cnts_train = list(labels_cnt.values())
            x_train = np.arange(len(cnts_train))
            if len(x_train) > len(x_max):
                x_max = x_train
                classes_max = list(labels_cnt.keys())
        elif tag == "val":
            cnts_val = list(labels_cnt.values())
            x_val = np.arange(len(cnts_val))
            if len(x_val) > len(x_max):
                x_max = x_val
                classes_max = list(labels_cnt.keys())
    width = 0.3
    # Bar chart of the length distribution.
    os_system = platform.system().lower()
    if os_system == "windows":
        plt.rcParams["font.sans-serif"] = "FangSong"
    else:
        font = font_manager.FontProperties(fname=PINGFANG_FONT.path, size=15)
    if datatype == "LaTeXOCRDataset":
        fig, ax = plt.subplots(figsize=(15, 9), dpi=120)
        xlabel_name = "公式长度区间"  # "formula length interval"
    else:
        fig, ax = plt.subplots(figsize=(10, 5), dpi=120)
        xlabel_name = "文本字长度区间"  # "text length interval"
    ax.bar(x_train, cnts_train, width=0.3, label="train")
    ax.bar(x_val + width, cnts_val, width=0.3, label="val")
    plt.xticks(x_max + width / 2, classes_max, rotation=90)
    plt.legend(prop={"size": 18})
    ax.set_xlabel(
        xlabel_name,
        fontproperties=None if os_system == "windows" else font,
        fontsize=12,
    )
    ax.set_ylabel(
        "图片数量",  # "number of images"
        fontproperties=None if os_system == "windows" else font,
        fontsize=12,
    )
    # Render the figure with the Agg backend and grab its RGB pixel buffer.
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    fig_width, fig_height = fig.get_size_inches() * fig.get_dpi()
    pie_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
        int(fig_height), int(fig_width), 3
    )
    fig1_path = os.path.join(output, "histogram.png")
    # The rendered buffer is RGB while cv2.imwrite expects BGR, so convert before saving.
    cv2.imwrite(fig1_path, cv2.cvtColor(pie_array, cv2.COLOR_RGB2BGR))
    return {"histogram": os.path.join("check_dataset", "histogram.png")}
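
# Illustrative usage of deep_analyse (the dataset and output paths are hypothetical
# examples; writing into a "check_dataset" sub-directory is an assumption that matches
# the relative path returned above):
#
#     res = deep_analyse(
#         "./ocr_rec_dataset_examples",
#         output="./output/check_dataset",
#         datatype="MSTextRecDataset",
#     )
#     # res == {"histogram": os.path.join("check_dataset", "histogram.png")}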