analyse_dataset.py

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
from collections import defaultdict

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image, ImageOps

from .....utils.file_interface import custom_open


# show data samples
def simple_analyse(dataset_path, max_recorded_sample_cnts=20, show_label=True):
    """
    Analyse the dataset samples by returning no more than
    max_recorded_sample_cnts image paths and label paths.

    Args:
        dataset_path (str): dataset path.
        max_recorded_sample_cnts (int, optional): the maximum number of samples to record. Default: 20.

    Returns:
        tuple: tuple of sample number, image path and label path for the train, val and test subdatasets.
    """
    tags = ["train", "val", "test"]
    sample_cnts = defaultdict(int)
    img_paths = defaultdict(list)
    lab_paths = defaultdict(list)
    lab_infos = defaultdict(list)
    res = [None] * 9
    delim = "\t"
    valid_num_parts = 2
    for tag in tags:
        file_list = os.path.join(dataset_path, f"{tag}.txt")
        if not os.path.exists(file_list):
            if tag in ("train", "val"):
                # "The dataset is invalid, please run the dataset check first."
                res.insert(0, "数据集不符合规范,请先通过数据校准")
                return res
            else:
                continue
        else:
            with custom_open(file_list, "r") as f:
                all_lines = f.readlines()

            # Each line corresponds to a sample.
            sample_cnts[tag] = len(all_lines)
            for idx, line in enumerate(all_lines):
                parts = line.strip("\n").split(delim)
                if len(line.strip("\n")) < 1:
                    continue
                # train/val lines must contain both an image path and a label;
                # test lines may omit the label.
                if tag in ("train", "val"):
                    valid_num_parts_lst = [2]
                else:
                    valid_num_parts_lst = [1, 2]
                if len(parts) not in valid_num_parts_lst and len(line.strip("\n")) > 1:
                    # "The dataset annotation file is invalid."
                    res.insert(0, "数据集的标注文件不符合规范")
                    return res
                if len(parts) == 2:
                    img_path, lab_path = parts
                else:
                    # len(parts) == 1
                    img_path = parts[0]
                    lab_path = None
                # check det label
                if len(img_paths[tag]) < max_recorded_sample_cnts:
                    img_path = os.path.join(dataset_path, img_path)
                    if lab_path is not None:
                        label = json.loads(lab_path)
                        boxes = []
                        for item in label:
                            if "points" not in item or "transcription" not in item:
                                res.insert(0, "数据集的标注文件不符合规范")
                                return res
                            box = np.array(item["points"])
                            if box.shape[1] != 2:
                                res.insert(0, "数据集的标注文件不符合规范")
                                return res
                            boxes.append(box)
                            txt = item["transcription"]
                            if not isinstance(txt, str):
                                res.insert(0, "数据集的标注文件不符合规范")
                                return res
                        if show_label:
                            lab_img = show_label_img(img_path, boxes)
                    img_paths[tag].append(img_path)
                    if show_label:
                        lab_paths[tag].append(lab_img)
                    else:
                        lab_infos[tag].append({"img_path": img_path, "box": boxes})

    if show_label:
        return (
            "完成数据分析",  # "Data analysis finished."
            sample_cnts[tags[0]],
            sample_cnts[tags[1]],
            sample_cnts[tags[2]],
            img_paths[tags[0]],
            img_paths[tags[1]],
            img_paths[tags[2]],
            lab_paths[tags[0]],
            lab_paths[tags[1]],
            lab_paths[tags[2]],
        )
    else:
        return (
            "完成数据分析",  # "Data analysis finished."
            sample_cnts[tags[0]],
            sample_cnts[tags[1]],
            sample_cnts[tags[2]],
            img_paths[tags[0]],
            img_paths[tags[1]],
            img_paths[tags[2]],
            lab_infos[tags[0]],
            lab_infos[tags[1]],
            lab_infos[tags[2]],
        )
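
# A hypothetical annotation line that the parser above accepts (the file name,
# box coordinates and transcription are invented for illustration; the
# tab-separated layout and JSON schema follow the checks in simple_analyse):
#
#   images/img_001.jpg\t[{"points": [[10, 20], [110, 20], [110, 60], [10, 60]], "transcription": "hello"}]
#
# i.e. an image path, a tab, and a JSON list of quadrilateral boxes, each with
# its text transcription.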


def show_label_img(img_path, dt_boxes):
    """draw ocr detection label"""
    img = cv2.imread(img_path)
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [box], True, color=(0, 255, 0), thickness=3)
    # cv2 reads BGR; flip the channel order to return an RGB array
    return img[:, :, ::-1]


def deep_analyse(dataset_path, output):
    """class analysis for dataset"""
    sample_results = simple_analyse(
        dataset_path, max_recorded_sample_cnts=float("inf"), show_label=False
    )
    lab_infos = sample_results[-3] + sample_results[-2] + sample_results[-1]
    labels_cnt = defaultdict(int)
    img_shapes = []  # w, h
    ratios_w = []
    ratios_h = []
    for info in lab_infos:
        img = np.asarray(ImageOps.exif_transpose(Image.open(info["img_path"])))
        img_h, img_w = np.shape(img)[:2]
        img_shapes.append([img_w, img_h])
        for box in info["box"]:
            box = np.array(box).astype(np.int32).reshape(-1, 2)
            box_w, box_h = np.max(box, axis=0) - np.min(box, axis=0)
            ratio_w = box_w / img_w
            ratio_h = box_h / img_h
            ratios_w.append(ratio_w)
            ratios_h.append(ratio_h)
    m_w_img, m_h_img = np.mean(img_shapes, axis=0)  # mean img shape
    m_num_box = len(ratios_w) / len(lab_infos)  # num box per img
    ratio_w = [i * 1000 for i in ratios_w]
    ratio_h = [i * 1000 for i in ratios_h]
    w_bins = int((max(ratio_w) - min(ratio_w)) // 10)
    h_bins = int((max(ratio_h) - min(ratio_h)) // 10)
    # histogram of box-width ratios
    fig, ax = plt.subplots()
    ax.hist(ratio_w, bins=w_bins, rwidth=0.8, color="yellowgreen")
    ax.set_xlabel("Width rate *1000")
    ax.set_ylabel("number")
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    bar_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
        int(height), int(width), 3
    )
    # histogram of box-height ratios
    fig, ax = plt.subplots()
    ax.hist(ratio_h, bins=h_bins, rwidth=0.8, color="pink")
    ax.set_xlabel("Height rate *1000")
    ax.set_ylabel("number")
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    pie_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
        int(height), int(width), 3
    )
    os.makedirs(output, exist_ok=True)
    fig_path = os.path.join(output, "histogram.png")
    img_array = np.concatenate((bar_array, pie_array), axis=1)
    cv2.imwrite(fig_path, img_array)
    return {"histogram": os.path.join("check_dataset", "histogram.png")}
    # return {
    #     "图像平均宽度": m_w_img,  # mean image width
    #     "图像平均高度": m_h_img,  # mean image height
    #     "每张图平均文本检测框数量": m_num_box,  # mean number of text boxes per image
    #     "检测框相对宽度分布图": fig1_path,  # distribution plot of relative box widths
    #     "检测框相对高度分布图": fig2_path,  # distribution plot of relative box heights
    # }
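
# Minimal usage sketch (an assumption, not part of the original module): the
# relative import of custom_open means this file is meant to be imported from
# within its package rather than run directly; "my_ocr_dataset" and
# "output_dir" are hypothetical paths.
#
#     msg, n_train, n_val, n_test, *recorded = simple_analyse("my_ocr_dataset")
#     figures = deep_analyse("my_ocr_dataset", "output_dir")
#     # figures -> {"histogram": "check_dataset/histogram.png"}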