seg_dataset.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os.path as osp
  15. import cv2
  16. from ..utils import list_files
  17. from .utils import is_pic, replace_ext, get_encoding, check_list_txt, read_seg_ann
  18. from .datasetbase import DatasetBase
  19. class SegDataset(DatasetBase):
  20. def __init__(self, dataset_id, path):
  21. super().__init__(dataset_id, path)
  22. def check_dataset(self, source_path):
  23. if not osp.isdir(osp.join(source_path, 'Annotations')):
  24. raise ValueError("标注文件应该放在{}目录下".format(
  25. osp.join(source_path, 'Annotations')))
  26. if not osp.isdir(osp.join(source_path, 'JPEGImages')):
  27. raise ValueError("图片文件应该放在{}目录下".format(
  28. osp.join(source_path, 'JPEGImages')))
  29. labels_txt = osp.join(source_path, 'labels.txt')
  30. if osp.exists(labels_txt):
  31. with open(labels_txt, encoding=get_encoding(labels_txt)) as fid:
  32. lines = fid.readlines()
  33. for line in lines:
  34. self.labels.append(line.strip())
  35. self.all_files = list_files(source_path)
  36. # 对语义分割数据集进行统计分析
  37. self.file_info = dict()
  38. self.label_info = dict()
  39. if osp.exists(osp.join(source_path, 'train_list.txt')):
  40. return self.check_splited_dataset(source_path)
  41. for f in self.all_files:
  42. if not is_pic(f):
  43. continue
  44. items = osp.split(f)
  45. if len(items) == 2 and items[0] == "JPEGImages":
  46. anno_name = replace_ext(items[1], "png")
  47. full_anno_path = osp.join(
  48. (osp.join(source_path, 'Annotations')), anno_name)
  49. if osp.exists(full_anno_path):
  50. self.file_info[f] = osp.join('Annotations', anno_name)
  51. # 解析PNG标注文件,获取类别信息
  52. labels, ann_img_shape = read_seg_ann(full_anno_path)
  53. img_shape = cv2.imread(osp.join(source_path, f)).shape
  54. if img_shape[0] != ann_img_shape[0] or img_shape[
  55. 1] != ann_img_shape[1]:
  56. raise ValueError("文件{}与标注图片尺寸不一致".format(items[1]))
  57. for i in labels:
  58. if str(i) not in self.label_info:
  59. self.label_info[str(i)] = list()
  60. self.label_info[str(i)].append(f)
  61. # 如果类标签的最大值大于类别数,统计相应的类别为零
  62. max_label = max([int(i) for i in self.label_info]) + 1
  63. for i in range(max_label):
  64. if str(i) not in self.label_info:
  65. self.label_info[str(i)] = list()
  66. if len(self.labels) == 0:
  67. self.labels = [int(i) for i in self.label_info]
  68. self.labels.sort()
  69. self.labels = [str(i) for i in self.labels]
  70. else:
  71. keys = list(self.label_info.keys())
  72. try:
  73. for key in keys:
  74. label = self.labels[int(key)]
  75. self.label_info[label] = self.label_info.pop(key)
  76. except:
  77. raise ValueError("标注信息与实际类别不一致")
  78. for label in self.labels:
  79. self.class_train_file_list[label] = list()
  80. self.class_val_file_list[label] = list()
  81. self.class_test_file_list[label] = list()
  82. # 将数据集分析信息dump到本地
  83. self.dump_statis_info()
  84. def check_splited_dataset(self, source_path):
  85. labels_txt = osp.join(source_path, "labels.txt")
  86. train_list_txt = osp.join(source_path, "train_list.txt")
  87. val_list_txt = osp.join(source_path, "val_list.txt")
  88. test_list_txt = osp.join(source_path, "test_list.txt")
  89. for txt_file in [train_list_txt, val_list_txt]:
  90. if not osp.exists(txt_file):
  91. raise Exception("已切分的数据集下应该包含train_list.txt, val_list.txt文件")
  92. check_list_txt([train_list_txt, val_list_txt, test_list_txt])
  93. if osp.exists(labels_txt):
  94. self.labels = open(
  95. labels_txt, 'r',
  96. encoding=get_encoding(labels_txt)).read().strip().split('\n')
  97. for txt_file in [train_list_txt, val_list_txt, test_list_txt]:
  98. if not osp.exists(txt_file):
  99. continue
  100. with open(txt_file, "r") as f:
  101. for line in f:
  102. items = line.strip().split()
  103. img_file, png_file = [items[0], items[1]]
  104. if not osp.isfile(osp.join(source_path, png_file)):
  105. raise ValueError("数据目录{}中不存在标注文件{}".format(
  106. osp.split(txt_file)[-1], png_file))
  107. if not osp.isfile(osp.join(source_path, img_file)):
  108. raise ValueError("数据目录{}中不存在图片文件{}".format(
  109. osp.split(txt_file)[-1], img_file))
  110. if not png_file.split('.')[-1] == 'png':
  111. raise ValueError("标注文件{}不是png文件".format(png_file))
  112. img_file_name = osp.split(img_file)[-1]
  113. if not is_pic(img_file_name) or img_file_name.startswith(
  114. '.'):
  115. raise ValueError("文件{}不是图片格式".format(img_file_name))
  116. self.file_info[img_file] = png_file
  117. if txt_file == train_list_txt:
  118. self.train_files.append(img_file)
  119. elif txt_file == val_list_txt:
  120. self.val_files.append(img_file)
  121. elif txt_file == test_list_txt:
  122. self.test_files.append(img_file)
  123. # 解析PNG标注文件
  124. labels, ann_img_shape = read_seg_ann(
  125. osp.join(source_path, png_file))
  126. img_shape = cv2.imread(osp.join(source_path,
  127. img_file)).shape
  128. if img_shape[0] != ann_img_shape[0] or img_shape[
  129. 1] != ann_img_shape[1]:
  130. raise ValueError("文件{}与标注图片尺寸不一致".format(
  131. img_file_name))
  132. for i in labels:
  133. if str(i) not in self.label_info:
  134. self.label_info[str(i)] = list()
  135. self.label_info[str(i)].append(img_file)
  136. # 如果类标签的最大值大于类别数,统计相应的类别为零
  137. max_label = max([int(i) for i in self.label_info]) + 1
  138. for i in range(max_label):
  139. if str(i) not in self.label_info:
  140. self.label_info[str(i)] = list()
  141. if len(self.labels) == 0:
  142. self.labels = [int(i) for i in self.label_info]
  143. self.labels.sort()
  144. self.labels = [str(i) for i in self.labels]
  145. else:
  146. keys = list(self.label_info.keys())
  147. try:
  148. for key in keys:
  149. label = self.labels[int(key)]
  150. self.label_info[label] = self.label_info.pop(key)
  151. except:
  152. raise ValueError("标注信息与实际类别不一致")
  153. self.train_set = set(self.train_files)
  154. self.val_set = set(self.val_files)
  155. self.test_set = set(self.test_files)
  156. for label, file_list in self.label_info.items():
  157. self.class_train_file_list[label] = list()
  158. self.class_val_file_list[label] = list()
  159. self.class_test_file_list[label] = list()
  160. for f in file_list:
  161. if f in self.test_set:
  162. self.class_test_file_list[label].append(f)
  163. if f in self.val_set:
  164. self.class_val_file_list[label].append(f)
  165. if f in self.train_set:
  166. self.class_train_file_list[label].append(f)
  167. # 将数据集分析信息dump到本地
  168. self.dump_statis_info()
  169. def split(self, val_split, test_split):
  170. super().split(val_split, test_split)
  171. with open(
  172. osp.join(self.path, 'train_list.txt'), mode='w',
  173. encoding='utf-8') as f:
  174. for x in self.train_files:
  175. label = self.file_info[x]
  176. f.write('{} {}\n'.format(x, label))
  177. with open(
  178. osp.join(self.path, 'val_list.txt'), mode='w',
  179. encoding='utf-8') as f:
  180. for x in self.val_files:
  181. label = self.file_info[x]
  182. f.write('{} {}\n'.format(x, label))
  183. with open(
  184. osp.join(self.path, 'test_list.txt'), mode='w',
  185. encoding='utf-8') as f:
  186. for x in self.test_files:
  187. label = self.file_info[x]
  188. f.write('{} {}\n'.format(x, label))
  189. if not osp.exists(osp.join(self.path, 'labels.txt')):
  190. with open(
  191. osp.join(self.path, 'labels.txt'), mode='w',
  192. encoding='utf-8') as f:
  193. max_label = max([int(i) for i in self.labels]) + 1
  194. for i in range(max_label):
  195. f.write('{}\n'.format(str(i)))
  196. self.dump_statis_info()