# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import random
from ..utils import list_files
from .utils import is_pic, replace_ext, MyEncoder, read_coco_ann, get_npy_from_coco_json
from .datasetbase import DatasetBase
import numpy as np
import json
from pycocotools.coco import COCO
  23. class InsSegDataset(DatasetBase):
  24. def __init__(self, dataset_id, path):
  25. super().__init__(dataset_id, path)
  26. self.annotation_dict = None
  27. def check_dataset(self, source_path):
  28. if not osp.isdir(osp.join(source_path, 'JPEGImages')):
  29. raise ValueError("图片文件应该放在{}目录下".format(
  30. osp.join(source_path, 'JPEGImages')))
  31. self.all_files = list_files(source_path)
  32. # 对检测数据集进行统计分析
  33. self.file_info = dict()
  34. self.label_info = dict()
  35. # 若数据集已切分
  36. if osp.exists(osp.join(source_path, 'train.json')):
  37. return self.check_splited_dataset(source_path)
  38. if not osp.exists(osp.join(source_path, 'annotations.json')):
  39. raise ValueError("标注文件annotations.json应该放在{}目录下".format(
  40. source_path))
  41. filename_set = set()
  42. anno_set = set()
  43. for f in self.all_files:
  44. items = osp.split(f)
  45. if len(items) == 2 and items[0] == "JPEGImages":
  46. if not is_pic(f) or f.startswith('.'):
  47. continue
  48. filename_set.add(items[1])
  49. # 解析包含标注信息的json文件
  50. try:
  51. coco = COCO(osp.join(source_path, 'annotations.json'))
  52. img_ids = coco.getImgIds()
  53. cat_ids = coco.getCatIds()
  54. anno_ids = coco.getAnnIds()
  55. catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  56. cid2cname = dict({
  57. clsid: coco.loadCats(catid)[0]['name']
  58. for catid, clsid in catid2clsid.items()
  59. })
  60. for img_id in img_ids:
  61. img_anno = coco.loadImgs(img_id)[0]
  62. img_name = osp.split(img_anno['file_name'])[-1]
  63. anno_set.add(img_name)
  64. anno_dict = read_coco_ann(img_id, coco, cid2cname, catid2clsid)
  65. img_path = osp.join("JPEGImages", img_name)
  66. anno_path = osp.join("Annotations", img_name)
  67. anno = replace_ext(anno_path, "npy")
  68. self.file_info[img_path] = anno
  69. img_class = list(set(anno_dict["gt_class"]))
  70. for category_name in img_class:
  71. if not category_name in self.label_info:
  72. self.label_info[category_name] = [img_path]
  73. else:
  74. self.label_info[category_name].append(img_path)
  75. for label in sorted(self.label_info.keys()):
  76. self.labels.append(label)
  77. except:
  78. raise Exception("标注文件存在错误")
  79. if len(anno_set) > len(filename_set):
  80. sub_list = list(anno_set - filename_set)
  81. raise Exception("标注文件中{}等{}个信息无对应图片".format(sub_list[0],
  82. len(sub_list)))
  83. # 生成每个图片对应的标注信息npy文件
  84. npy_path = osp.join(self.path, "Annotations")
  85. get_npy_from_coco_json(coco, npy_path, self.file_info)
  86. for label in self.labels:
  87. self.class_train_file_list[label] = list()
  88. self.class_val_file_list[label] = list()
  89. self.class_test_file_list[label] = list()
  90. # 将数据集分析信息dump到本地
  91. self.dump_statis_info()
  92. def check_splited_dataset(self, source_path):
  93. train_files_json = osp.join(source_path, "train.json")
  94. val_files_json = osp.join(source_path, "val.json")
  95. test_files_json = osp.join(source_path, "test.json")
  96. for json_file in [train_files_json, val_files_json]:
  97. if not osp.exists(json_file):
  98. raise Exception("已切分的数据集下应该包含train.json, val.json文件")
  99. filename_set = set()
  100. anno_set = set()
  101. # 获取全部图片名称
  102. for f in self.all_files:
  103. items = osp.split(f)
  104. if len(items) == 2 and items[0] == "JPEGImages":
  105. if not is_pic(f) or f.startswith('.'):
  106. continue
  107. filename_set.add(items[1])
  108. img_id_index = 0
  109. anno_id_index = 0
  110. new_img_list = list()
  111. new_cat_list = list()
  112. new_anno_list = list()
  113. for json_file in [train_files_json, val_files_json, test_files_json]:
  114. if not osp.exists(json_file):
  115. continue
  116. coco = COCO(json_file)
  117. img_ids = coco.getImgIds()
  118. cat_ids = coco.getCatIds()
  119. anno_ids = coco.getAnnIds()
  120. catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  121. clsid2catid = dict({i: catid for i, catid in enumerate(cat_ids)})
  122. cid2cname = dict({
  123. clsid: coco.loadCats(catid)[0]['name']
  124. for catid, clsid in catid2clsid.items()
  125. })
  126. # 由原train.json中的category生成新的category信息
  127. if json_file == train_files_json:
  128. cname2catid = dict({
  129. coco.loadCats(catid)[0]['name']: clsid2catid[clsid]
  130. for catid, clsid in catid2clsid.items()
  131. })
  132. new_cat_list = coco.loadCats(cat_ids)
  133. # 获取json中全部标注图片的名字
  134. for img_id in img_ids:
  135. img_anno = coco.loadImgs(img_id)[0]
  136. im_fname = img_anno['file_name']
  137. anno_set.add(im_fname)
  138. if json_file == train_files_json:
  139. self.train_files.append(osp.join("JPEGImages", im_fname))
  140. elif json_file == val_files_json:
  141. self.val_files.append(osp.join("JPEGImages", im_fname))
  142. elif json_file == test_files_json:
  143. self.test_files.append(osp.join("JPEGImages", im_fname))
  144. # 获取每张图片的对应标注信息,并记录为npy格式
  145. anno_dict = read_coco_ann(img_id, coco, cid2cname, catid2clsid)
  146. img_path = osp.join("JPEGImages", im_fname)
  147. anno_path = osp.join("Annotations", im_fname)
  148. anno = replace_ext(anno_path, "npy")
  149. self.file_info[img_path] = anno
  150. # 生成label_info
  151. img_class = list(set(anno_dict["gt_class"]))
  152. for category_name in img_class:
  153. if not category_name in self.label_info:
  154. self.label_info[category_name] = [img_path]
  155. else:
  156. self.label_info[category_name].append(img_path)
  157. if json_file == train_files_json:
  158. if category_name in self.class_train_file_list:
  159. self.class_train_file_list[category_name].append(
  160. img_path)
  161. else:
  162. self.class_train_file_list[category_name] = list()
  163. self.class_train_file_list[category_name].append(
  164. img_path)
  165. elif json_file == val_files_json:
  166. if category_name in self.class_val_file_list:
  167. self.class_val_file_list[category_name].append(
  168. img_path)
  169. else:
  170. self.class_val_file_list[category_name] = list()
  171. self.class_val_file_list[category_name].append(
  172. img_path)
  173. elif json_file == test_files_json:
  174. if category_name in self.class_test_file_list:
  175. self.class_test_file_list[category_name].append(
  176. img_path)
  177. else:
  178. self.class_test_file_list[category_name] = list()
  179. self.class_test_file_list[category_name].append(
  180. img_path)
  181. # 生成新的图片信息
  182. new_img = img_anno
  183. new_img["id"] = img_id_index
  184. img_id_index += 1
  185. new_img_list.append(new_img)
  186. # 生成新的标注信息
  187. ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0)
  188. for ins_anno_id in ins_anno_ids:
  189. anno = coco.loadAnns(ins_anno_id)[0]
  190. new_anno = anno
  191. new_anno["image_id"] = new_img["id"]
  192. new_anno["id"] = anno_id_index
  193. anno_id_index += 1
  194. cat = coco.loadCats(anno["category_id"])[0]
  195. new_anno_list.append(new_anno)
  196. if len(anno_set) > len(filename_set):
  197. sub_list = list(anno_set - filename_set)
  198. raise Exception("标注文件中{}等{}个信息无对应图片".format(sub_list[0],
  199. len(sub_list)))
  200. for label in sorted(self.label_info.keys()):
  201. self.labels.append(label)
  202. self.annotation_dict = {
  203. "images": new_img_list,
  204. "categories": new_cat_list,
  205. "annotations": new_anno_list
  206. }
  207. # 若原数据集已切分,无annotations.json文件
  208. if not osp.exists(osp.join(self.path, "annotations.json")):
  209. json_file = open(osp.join(self.path, "annotations.json"), 'w+')
  210. json.dump(self.annotation_dict, json_file, cls=MyEncoder)
  211. json_file.close()
  212. # 生成每个图片对应的标注信息npy文件
  213. coco = COCO(osp.join(self.path, "annotations.json"))
  214. npy_path = osp.join(self.path, "Annotations")
  215. get_npy_from_coco_json(coco, npy_path, self.file_info)
  216. self.dump_statis_info()
  217. def split(self, val_split, test_split):
  218. all_files = list(self.file_info.keys())
  219. val_num = int(len(all_files) * val_split)
  220. test_num = int(len(all_files) * test_split)
  221. train_num = len(all_files) - val_num - test_num
  222. assert train_num > 0, "训练集样本数量需大于0"
  223. assert val_num > 0, "验证集样本数量需大于0"
  224. self.train_files = list()
  225. self.val_files = list()
  226. self.test_files = list()
  227. coco = COCO(osp.join(self.path, 'annotations.json'))
  228. img_ids = coco.getImgIds()
  229. cat_ids = coco.getCatIds()
  230. anno_ids = coco.getAnnIds()
  231. random.shuffle(img_ids)
  232. train_files_ids = img_ids[:train_num]
  233. val_files_ids = img_ids[train_num:train_num + val_num]
  234. test_files_ids = img_ids[train_num + val_num:]
  235. for img_id_list in [train_files_ids, val_files_ids, test_files_ids]:
  236. img_anno_ids = coco.getAnnIds(imgIds=img_id_list, iscrowd=0)
  237. imgs = coco.loadImgs(img_id_list)
  238. instances = coco.loadAnns(img_anno_ids)
  239. categories = coco.loadCats(cat_ids)
  240. img_dict = {
  241. "annotations": instances,
  242. "images": imgs,
  243. "categories": categories
  244. }
  245. if img_id_list == train_files_ids:
  246. for img in imgs:
  247. self.train_files.append(
  248. osp.join("JPEGImages", img["file_name"]))
  249. json_file = open(osp.join(self.path, 'train.json'), 'w+')
  250. json.dump(img_dict, json_file, cls=MyEncoder)
  251. elif img_id_list == val_files_ids:
  252. for img in imgs:
  253. self.val_files.append(
  254. osp.join("JPEGImages", img["file_name"]))
  255. json_file = open(osp.join(self.path, 'val.json'), 'w+')
  256. json.dump(img_dict, json_file, cls=MyEncoder)
  257. elif img_id_list == test_files_ids:
  258. for img in imgs:
  259. self.test_files.append(
  260. osp.join("JPEGImages", img["file_name"]))
  261. json_file = open(osp.join(self.path, 'test.json'), 'w+')
  262. json.dump(img_dict, json_file, cls=MyEncoder)
  263. self.train_set = set(self.train_files)
  264. self.val_set = set(self.val_files)
  265. self.test_set = set(self.test_files)
  266. for label, file_list in self.label_info.items():
  267. self.class_train_file_list[label] = list()
  268. self.class_val_file_list[label] = list()
  269. self.class_test_file_list[label] = list()
  270. for f in file_list:
  271. if f in self.test_set:
  272. self.class_test_file_list[label].append(f)
  273. if f in self.val_set:
  274. self.class_val_file_list[label].append(f)
  275. if f in self.train_set:
  276. self.class_train_file_list[label].append(f)
  277. self.dump_statis_info()