ins_seg_dataset.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import random
import json

from pycocotools.coco import COCO

from ..utils import list_files
from .utils import is_pic, replace_ext, MyEncoder, read_coco_ann, get_npy_from_coco_json
from .datasetbase import DatasetBase


class InsSegDataset(DatasetBase):
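    """Instance segmentation dataset in COCO format.

    Images are expected under ``JPEGImages``; annotations come either from a
    single ``annotations.json`` file or from pre-split ``train.json`` /
    ``val.json`` (and optionally ``test.json``) files.
    """
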
    def __init__(self, dataset_id, path):
        super().__init__(dataset_id, path)
        self.annotation_dict = None

    def check_dataset(self, source_path):
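        """Validate the dataset layout and collect its statistics.

        If ``train.json`` exists, the dataset is treated as pre-split and
        handled by ``check_splited_dataset``. Otherwise ``annotations.json``
        is parsed to build ``file_info`` and ``label_info``, per-image npy
        annotation files are generated, and the statistics are dumped to disk.
        """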
        if not osp.isdir(osp.join(source_path, 'JPEGImages')):
            raise ValueError("图片文件应该放在{}目录下".format(
                osp.join(source_path, 'JPEGImages')))
        self.all_files = list_files(source_path)
        # Run statistical analysis on the dataset
        self.file_info = dict()
        self.label_info = dict()
        # If the dataset has already been split
        if osp.exists(osp.join(source_path, 'train.json')):
            return self.check_splited_dataset(source_path)
        if not osp.exists(osp.join(source_path, 'annotations.json')):
            raise ValueError("标注文件annotations.json应该放在{}目录下".format(
                source_path))
        filename_set = set()
        anno_set = set()
        for f in self.all_files:
            items = osp.split(f)
            if len(items) == 2 and items[0] == "JPEGImages":
                if not is_pic(f) or f.startswith('.'):
                    continue
                filename_set.add(items[1])
        # Parse the JSON file that contains the annotations
        try:
            coco = COCO(osp.join(source_path, 'annotations.json'))
            img_ids = coco.getImgIds()
            cat_ids = coco.getCatIds()
            anno_ids = coco.getAnnIds()
            catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
            cid2cname = dict({
                clsid: coco.loadCats(catid)[0]['name']
                for catid, clsid in catid2clsid.items()
            })
            for img_id in img_ids:
                img_anno = coco.loadImgs(img_id)[0]
                img_name = osp.split(img_anno['file_name'])[-1]
                anno_set.add(img_name)
                anno_dict = read_coco_ann(img_id, coco, cid2cname, catid2clsid)
                img_path = osp.join("JPEGImages", img_name)
                anno_path = osp.join("Annotations", img_name)
                anno = replace_ext(anno_path, "npy")
                self.file_info[img_path] = anno
                img_class = list(set(anno_dict["gt_class"]))
                for category_name in img_class:
                    if category_name not in self.label_info:
                        self.label_info[category_name] = [img_path]
                    else:
                        self.label_info[category_name].append(img_path)
            for label in sorted(self.label_info.keys()):
                self.labels.append(label)
        except:
            raise Exception("标注文件存在错误")
        if len(anno_set) > len(filename_set):
            sub_list = list(anno_set - filename_set)
            raise Exception("标注文件中{}等{}个信息无对应图片".format(
                sub_list[0], len(sub_list)))
        # Generate an npy annotation file for every image
        npy_path = osp.join(self.path, "Annotations")
        get_npy_from_coco_json(coco, npy_path, self.file_info)
        for label in self.labels:
            self.class_train_file_list[label] = list()
            self.class_val_file_list[label] = list()
            self.class_test_file_list[label] = list()
        # Dump the dataset statistics to disk
        self.dump_statis_info()

    def check_splited_dataset(self, source_path):
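        """Validate a dataset that ships already split into train/val(/test).

        Merges ``train.json``, ``val.json`` and, if present, ``test.json``
        into a single ``annotation_dict`` with re-indexed image and annotation
        ids, writes ``annotations.json`` when it does not exist yet, generates
        the per-image npy annotation files, and dumps the statistics to disk.
        """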
        train_files_json = osp.join(source_path, "train.json")
        val_files_json = osp.join(source_path, "val.json")
        test_files_json = osp.join(source_path, "test.json")
        for json_file in [train_files_json, val_files_json]:
            if not osp.exists(json_file):
                raise Exception("已切分的数据集下应该包含train.json, val.json文件")
        filename_set = set()
        anno_set = set()
        # Collect the names of all images
        for f in self.all_files:
            items = osp.split(f)
            if len(items) == 2 and items[0] == "JPEGImages":
                if not is_pic(f) or f.startswith('.'):
                    continue
                filename_set.add(items[1])
        img_id_index = 0
        anno_id_index = 0
        new_img_list = list()
        new_cat_list = list()
        new_anno_list = list()
        for json_file in [train_files_json, val_files_json, test_files_json]:
            if not osp.exists(json_file):
                continue
            coco = COCO(json_file)
            img_ids = coco.getImgIds()
            cat_ids = coco.getCatIds()
            anno_ids = coco.getAnnIds()
            catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
            clsid2catid = dict({i: catid for i, catid in enumerate(cat_ids)})
            cid2cname = dict({
                clsid: coco.loadCats(catid)[0]['name']
                for catid, clsid in catid2clsid.items()
            })
            # Build the new category info from the categories of the original train.json
            if json_file == train_files_json:
                cname2catid = dict({
                    coco.loadCats(catid)[0]['name']: clsid2catid[clsid]
                    for catid, clsid in catid2clsid.items()
                })
                new_cat_list = coco.loadCats(cat_ids)
            # Collect the names of all annotated images in this JSON file
            for img_id in img_ids:
                img_anno = coco.loadImgs(img_id)[0]
                im_fname = img_anno['file_name']
                anno_set.add(im_fname)
                if json_file == train_files_json:
                    self.train_files.append(osp.join("JPEGImages", im_fname))
                elif json_file == val_files_json:
                    self.val_files.append(osp.join("JPEGImages", im_fname))
                elif json_file == test_files_json:
                    self.test_files.append(osp.join("JPEGImages", im_fname))
                # Read the annotations of each image; they are saved later in npy format
                anno_dict = read_coco_ann(img_id, coco, cid2cname, catid2clsid)
                img_path = osp.join("JPEGImages", im_fname)
                anno_path = osp.join("Annotations", im_fname)
                anno = replace_ext(anno_path, "npy")
                self.file_info[img_path] = anno
                # Build label_info
                img_class = list(set(anno_dict["gt_class"]))
                for category_name in img_class:
                    if category_name not in self.label_info:
                        self.label_info[category_name] = [img_path]
                    else:
                        self.label_info[category_name].append(img_path)
                    if json_file == train_files_json:
                        if category_name in self.class_train_file_list:
                            self.class_train_file_list[category_name].append(
                                img_path)
                        else:
                            self.class_train_file_list[category_name] = list()
                            self.class_train_file_list[category_name].append(
                                img_path)
                    elif json_file == val_files_json:
                        if category_name in self.class_val_file_list:
                            self.class_val_file_list[category_name].append(
                                img_path)
                        else:
                            self.class_val_file_list[category_name] = list()
                            self.class_val_file_list[category_name].append(
                                img_path)
                    elif json_file == test_files_json:
                        if category_name in self.class_test_file_list:
                            self.class_test_file_list[category_name].append(
                                img_path)
                        else:
                            self.class_test_file_list[category_name] = list()
                            self.class_test_file_list[category_name].append(
                                img_path)
                # Build the new image record
                new_img = img_anno
                new_img["id"] = img_id_index
                img_id_index += 1
                new_img_list.append(new_img)
                # Build the new annotation records
                ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0)
                for ins_anno_id in ins_anno_ids:
                    anno = coco.loadAnns(ins_anno_id)[0]
                    new_anno = anno
                    new_anno["image_id"] = new_img["id"]
                    new_anno["id"] = anno_id_index
                    anno_id_index += 1
                    cat = coco.loadCats(anno["category_id"])[0]
                    new_anno_list.append(new_anno)
        if len(anno_set) > len(filename_set):
            sub_list = list(anno_set - filename_set)
            raise Exception("标注文件中{}等{}个信息无对应图片".format(
                sub_list[0], len(sub_list)))
        for label in sorted(self.label_info.keys()):
            self.labels.append(label)
        self.annotation_dict = {
            "images": new_img_list,
            "categories": new_cat_list,
            "annotations": new_anno_list
        }
        # The pre-split source dataset has no annotations.json, so write one
        if not osp.exists(osp.join(self.path, "annotations.json")):
            json_file = open(osp.join(self.path, "annotations.json"), 'w+')
            json.dump(self.annotation_dict, json_file, cls=MyEncoder)
            json_file.close()
        # Generate an npy annotation file for every image
        coco = COCO(osp.join(self.path, "annotations.json"))
        npy_path = osp.join(self.path, "Annotations")
        get_npy_from_coco_json(coco, npy_path, self.file_info)
        self.dump_statis_info()

    def split(self, val_split, test_split):
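        """Randomly split the dataset into train/val/test subsets.

        ``val_split`` and ``test_split`` are fractions of the whole dataset;
        the remaining images become the training set. Each subset is written
        back to ``train.json``, ``val.json`` and ``test.json`` respectively.
        """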
        all_files = list(self.file_info.keys())
        val_num = int(len(all_files) * val_split)
        test_num = int(len(all_files) * test_split)
        train_num = len(all_files) - val_num - test_num
        assert train_num > 0, "训练集样本数量需大于0"
        assert val_num > 0, "验证集样本数量需大于0"
        self.train_files = list()
        self.val_files = list()
        self.test_files = list()
        coco = COCO(osp.join(self.path, 'annotations.json'))
        img_ids = coco.getImgIds()
        cat_ids = coco.getCatIds()
        anno_ids = coco.getAnnIds()
        random.shuffle(img_ids)
        train_files_ids = img_ids[:train_num]
        val_files_ids = img_ids[train_num:train_num + val_num]
        test_files_ids = img_ids[train_num + val_num:]
        for img_id_list in [train_files_ids, val_files_ids, test_files_ids]:
            img_anno_ids = coco.getAnnIds(imgIds=img_id_list, iscrowd=0)
            imgs = coco.loadImgs(img_id_list)
            instances = coco.loadAnns(img_anno_ids)
            categories = coco.loadCats(cat_ids)
            img_dict = {
                "annotations": instances,
                "images": imgs,
                "categories": categories
            }
            if img_id_list == train_files_ids:
                for img in imgs:
                    self.train_files.append(
                        osp.join("JPEGImages", img["file_name"]))
                json_file = open(osp.join(self.path, 'train.json'), 'w+')
                json.dump(img_dict, json_file, cls=MyEncoder)
            elif img_id_list == val_files_ids:
                for img in imgs:
                    self.val_files.append(
                        osp.join("JPEGImages", img["file_name"]))
                json_file = open(osp.join(self.path, 'val.json'), 'w+')
                json.dump(img_dict, json_file, cls=MyEncoder)
            elif img_id_list == test_files_ids:
                for img in imgs:
                    self.test_files.append(
                        osp.join("JPEGImages", img["file_name"]))
                json_file = open(osp.join(self.path, 'test.json'), 'w+')
                json.dump(img_dict, json_file, cls=MyEncoder)
        self.train_set = set(self.train_files)
        self.val_set = set(self.val_files)
        self.test_set = set(self.test_files)
        for label, file_list in self.label_info.items():
            self.class_train_file_list[label] = list()
            self.class_val_file_list[label] = list()
            self.class_test_file_list[label] = list()
            for f in file_list:
                if f in self.test_set:
                    self.class_test_file_list[label].append(f)
                if f in self.val_set:
                    self.class_val_file_list[label].append(f)
                if f in self.train_set:
                    self.class_train_file_list[label].append(f)
        self.dump_statis_info()
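
# A minimal usage sketch (not part of the original module); the dataset id,
# workspace path, raw-data path and split ratios below are hypothetical:
#
#     dataset = InsSegDataset("D0001", "/path/to/workspace/datasets/D0001")
#     dataset.check_dataset("/path/to/raw_coco_dataset")   # validate and analyze
#     dataset.split(val_split=0.2, test_split=0.1)         # write train/val/test.json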