analyse_dataset.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. from collections import defaultdict
  17. import numpy as np
  18. from PIL import Image, ImageOps
  19. from .....utils.deps import function_requires_deps, is_dep_available
  20. from .....utils.file_interface import custom_open
  21. if is_dep_available("opencv-contrib-python"):
  22. import cv2
  23. if is_dep_available("matplotlib"):
  24. import matplotlib.pyplot as plt
  25. from matplotlib.backends.backend_agg import FigureCanvasAgg
  26. # show data samples
  27. def simple_analyse(dataset_path, max_recorded_sample_cnts=20, show_label=True):
  28. """
  29. Analyse the dataset samples by return not nore than
  30. max_recorded_sample_cnts image path and label path
  31. Args:
  32. dataset_path (str): dataset path
  33. max_recorded_sample_cnts (int, optional): the number to return. Default: 50.
  34. Returns:
  35. tuple: tuple of sample number, image path and label path for train, val and text subdataset.
  36. """
  37. tags = ["train", "val", "test"]
  38. sample_cnts = defaultdict(int)
  39. img_paths = defaultdict(list)
  40. lab_paths = defaultdict(list)
  41. lab_infos = defaultdict(list)
  42. res = [None] * 9
  43. delim = "\t"
  44. for tag in tags:
  45. file_list = os.path.join(dataset_path, f"{tag}.txt")
  46. if not os.path.exists(file_list):
  47. if tag in ("train", "val"):
  48. res.insert(0, "数据集不符合规范,请先通过数据校准")
  49. return res
  50. else:
  51. continue
  52. else:
  53. with custom_open(file_list, "r") as f:
  54. all_lines = f.readlines()
  55. # Each line corresponds to a sample
  56. sample_cnts[tag] = len(all_lines)
  57. for idx, line in enumerate(all_lines):
  58. parts = line.strip("\n").split(delim)
  59. if len(line.strip("\n")) < 1:
  60. continue
  61. if tag in ("train", "val"):
  62. valid_num_parts_lst = [2]
  63. else:
  64. valid_num_parts_lst = [1, 2]
  65. if len(parts) not in valid_num_parts_lst and len(line.strip("\n")) > 1:
  66. res.insert(0, "数据集的标注文件不符合规范")
  67. return res
  68. if len(parts) == 2:
  69. img_path, lab_path = parts
  70. else:
  71. # len(parts) == 1
  72. img_path = parts[0]
  73. lab_path = None
  74. # check det label
  75. if len(img_paths[tag]) < max_recorded_sample_cnts:
  76. img_path = os.path.join(dataset_path, img_path)
  77. if lab_path is not None:
  78. label = json.loads(lab_path)
  79. boxes = []
  80. for item in label:
  81. if "points" not in item or "transcription" not in item:
  82. res.insert(0, "数据集的标注文件不符合规范")
  83. return res
  84. box = np.array(item["points"])
  85. if box.shape[1] != 2:
  86. res.insert(0, "数据集的标注文件不符合规范")
  87. return res
  88. boxes.append(box)
  89. txt = item["transcription"]
  90. if not isinstance(txt, str):
  91. res.insert(0, "数据集的标注文件不符合规范")
  92. return res
  93. if show_label:
  94. lab_img = show_label_img(img_path, boxes)
  95. img_paths[tag].append(img_path)
  96. if show_label:
  97. lab_paths[tag].append(lab_img)
  98. else:
  99. lab_infos[tag].append({"img_path": img_path, "box": boxes})
  100. if show_label:
  101. return (
  102. "完成数据分析",
  103. sample_cnts[tags[0]],
  104. sample_cnts[tags[1]],
  105. sample_cnts[tags[2]],
  106. img_paths[tags[0]],
  107. img_paths[tags[1]],
  108. img_paths[tags[2]],
  109. lab_paths[tags[0]],
  110. lab_paths[tags[1]],
  111. lab_paths[tags[2]],
  112. )
  113. else:
  114. return (
  115. "完成数据分析",
  116. sample_cnts[tags[0]],
  117. sample_cnts[tags[1]],
  118. sample_cnts[tags[2]],
  119. img_paths[tags[0]],
  120. img_paths[tags[1]],
  121. img_paths[tags[2]],
  122. lab_infos[tags[0]],
  123. lab_infos[tags[1]],
  124. lab_infos[tags[2]],
  125. )
  126. @function_requires_deps("opencv-contrib-python")
  127. def show_label_img(img_path, dt_boxes):
  128. """draw ocr detection label"""
  129. img = cv2.imread(img_path)
  130. for box in dt_boxes:
  131. box = np.array(box).astype(np.int32).reshape(-1, 2)
  132. cv2.polylines(img, [box], True, color=(0, 255, 0), thickness=3)
  133. return img[:, :, ::-1]
  134. @function_requires_deps("matplotlib", "opencv-contrib-python")
  135. def deep_analyse(dataset_path, output):
  136. """class analysis for dataset"""
  137. sample_results = simple_analyse(
  138. dataset_path, max_recorded_sample_cnts=float("inf"), show_label=False
  139. )
  140. lab_infos = sample_results[-3] + sample_results[-2] + sample_results[-1]
  141. defaultdict(int)
  142. img_shapes = [] # w, h
  143. ratios_w = []
  144. ratios_h = []
  145. for info in lab_infos:
  146. img = np.asarray(ImageOps.exif_transpose(Image.open(info["img_path"])))
  147. img_h, img_w = np.shape(img)[:2]
  148. img_shapes.append([img_w, img_h])
  149. for box in info["box"]:
  150. box = np.array(box).astype(np.int32).reshape(-1, 2)
  151. box_w, box_h = np.max(box, axis=0) - np.min(box, axis=0)
  152. ratio_w = box_w / img_w
  153. ratio_h = box_h / img_h
  154. ratios_w.append(ratio_w)
  155. ratios_h.append(ratio_h)
  156. m_w_img, m_h_img = np.mean(img_shapes, axis=0) # mean img shape
  157. m_num_box = len(ratios_w) / len(lab_infos) # num box per img
  158. ratio_w = [i * 1000 for i in ratios_w]
  159. ratio_h = [i * 1000 for i in ratios_h]
  160. w_bins = int((max(ratio_w) - min(ratio_w)) // 10)
  161. h_bins = int((max(ratio_h) - min(ratio_h)) // 10)
  162. fig, ax = plt.subplots()
  163. ax.hist(ratio_w, bins=w_bins, rwidth=0.8, color="yellowgreen")
  164. ax.set_xlabel("Width rate *1000")
  165. ax.set_ylabel("number")
  166. canvas = FigureCanvasAgg(fig)
  167. canvas.draw()
  168. width, height = fig.get_size_inches() * fig.get_dpi()
  169. bar_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
  170. int(height), int(width), 3
  171. )
  172. # pie
  173. fig, ax = plt.subplots()
  174. ax.hist(ratio_h, bins=h_bins, rwidth=0.8, color="pink")
  175. ax.set_xlabel("Height rate *1000")
  176. ax.set_ylabel("number")
  177. canvas = FigureCanvasAgg(fig)
  178. canvas.draw()
  179. width, height = fig.get_size_inches() * fig.get_dpi()
  180. pie_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
  181. int(height), int(width), 3
  182. )
  183. os.makedirs(output, exist_ok=True)
  184. fig_path = os.path.join(output, "histogram.png")
  185. img_array = np.concatenate((bar_array, pie_array), axis=1)
  186. cv2.imwrite(fig_path, img_array)
  187. return {"histogram": os.path.join("check_dataset", "histogram.png")}
  188. # return {
  189. # "图像平均宽度": m_w_img,
  190. # "图像平均高度": m_h_img,
  191. # "每张图平均文本检测框数量": m_num_box,
  192. # "检测框相对宽度分布图": fig1_path,
  193. # "检测框相对高度分布图": fig2_path
  194. # }