x2seg.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import cv2
  17. import uuid
  18. import json
  19. import os
  20. import os.path as osp
  21. import shutil
  22. import numpy as np
  23. import PIL.Image
  24. from .base import MyEncoder, is_pic, get_encoding
  25. import math
  26. class X2Seg(object):
  27. def __init__(self):
  28. self.labels2ids = {'_background_': 0}
  29. def shapes_to_label(self, img_shape, shapes, label_name_to_value):
  30. # 该函数基于https://github.com/wkentaro/labelme/blob/master/labelme/utils/shape.py实现。
  31. def shape_to_mask(img_shape, points, shape_type=None,
  32. line_width=10, point_size=5):
  33. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  34. mask = PIL.Image.fromarray(mask)
  35. draw = PIL.ImageDraw.Draw(mask)
  36. xy = [tuple(point) for point in points]
  37. if shape_type == 'circle':
  38. assert len(xy) == 2, 'Shape of shape_type=circle must have 2 points'
  39. (cx, cy), (px, py) = xy
  40. d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
  41. draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
  42. elif shape_type == 'rectangle':
  43. assert len(xy) == 2, 'Shape of shape_type=rectangle must have 2 points'
  44. draw.rectangle(xy, outline=1, fill=1)
  45. elif shape_type == 'line':
  46. assert len(xy) == 2, 'Shape of shape_type=line must have 2 points'
  47. draw.line(xy=xy, fill=1, width=line_width)
  48. elif shape_type == 'linestrip':
  49. draw.line(xy=xy, fill=1, width=line_width)
  50. elif shape_type == 'point':
  51. assert len(xy) == 1, 'Shape of shape_type=point must have 1 points'
  52. cx, cy = xy[0]
  53. r = point_size
  54. draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
  55. else:
  56. assert len(xy) > 2, 'Polygon must have points more than 2'
  57. draw.polygon(xy=xy, outline=1, fill=1)
  58. mask = np.array(mask, dtype=bool)
  59. return mask
  60. cls = np.zeros(img_shape[:2], dtype=np.int32)
  61. ins = np.zeros_like(cls)
  62. instances = []
  63. for shape in shapes:
  64. points = shape['points']
  65. label = shape['label']
  66. group_id = shape.get('group_id')
  67. if group_id is None:
  68. group_id = uuid.uuid1()
  69. shape_type = shape.get('shape_type', None)
  70. cls_name = label
  71. instance = (cls_name, group_id)
  72. if instance not in instances:
  73. instances.append(instance)
  74. ins_id = instances.index(instance) + 1
  75. cls_id = label_name_to_value[cls_name]
  76. mask = shape_to_mask(img_shape[:2], points, shape_type)
  77. cls[mask] = cls_id
  78. ins[mask] = ins_id
  79. return cls, ins
  80. def get_color_map_list(self, num_classes):
  81. color_map = num_classes * [0, 0, 0]
  82. for i in range(0, num_classes):
  83. j = 0
  84. lab = i
  85. while lab:
  86. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  87. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  88. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  89. j += 1
  90. lab >>= 3
  91. return color_map
  92. def convert(self, image_dir, json_dir, dataset_save_dir):
  93. """转换。
  94. Args:
  95. image_dir (str): 图像文件存放的路径。
  96. json_dir (str): 与每张图像对应的json文件的存放路径。
  97. dataset_save_dir (str): 转换后数据集存放路径。
  98. """
  99. assert osp.exists(image_dir), "The image folder does not exist!"
  100. assert osp.exists(json_dir), "The json folder does not exist!"
  101. assert osp.exists(dataset_save_dir), "The save folder does not exist!"
  102. # Convert the image files.
  103. new_image_dir = osp.join(dataset_save_dir, "JPEGImages")
  104. if osp.exists(new_image_dir):
  105. shutil.rmtree(new_image_dir)
  106. os.makedirs(new_image_dir)
  107. for img_name in os.listdir(image_dir):
  108. if is_pic(img_name):
  109. shutil.copyfile(
  110. osp.join(image_dir, img_name),
  111. osp.join(new_image_dir, img_name))
  112. # Convert the json files.
  113. png_dir = osp.join(dataset_save_dir, "Annotations")
  114. if osp.exists(png_dir):
  115. shutil.rmtree(png_dir)
  116. os.makedirs(png_dir)
  117. self.get_labels2ids(new_image_dir, json_dir)
  118. self.json2png(new_image_dir, json_dir, png_dir)
  119. # Generate the labels.txt
  120. ids2labels = {v : k for k, v in self.labels2ids.items()}
  121. with open(osp.join(dataset_save_dir, 'labels.txt'), 'w') as fw:
  122. for i in range(len(ids2labels)):
  123. fw.write(ids2labels[i] + '\n')
  124. class JingLing2Seg(X2Seg):
  125. """将使用标注精灵标注的数据集转换为Seg数据集。
  126. """
  127. def __init__(self):
  128. super(JingLing2Seg, self).__init__()
  129. def get_labels2ids(self, image_dir, json_dir):
  130. for img_name in os.listdir(image_dir):
  131. img_name_part = osp.splitext(img_name)[0]
  132. json_file = osp.join(json_dir, img_name_part + ".json")
  133. if not osp.exists(json_file):
  134. os.remove(osp.join(image_dir, img_name))
  135. continue
  136. with open(json_file, mode="r", \
  137. encoding=get_encoding(json_file)) as j:
  138. json_info = json.load(j)
  139. if 'outputs' in json_info:
  140. for output in json_info['outputs']['object']:
  141. cls_name = output['name']
  142. if cls_name not in self.labels2ids:
  143. self.labels2ids[cls_name] = len(self.labels2ids)
  144. def json2png(self, image_dir, json_dir, png_dir):
  145. color_map = self.get_color_map_list(256)
  146. for img_name in os.listdir(image_dir):
  147. img_name_part = osp.splitext(img_name)[0]
  148. json_file = osp.join(json_dir, img_name_part + ".json")
  149. if not osp.exists(json_file):
  150. os.remove(os.remove(osp.join(image_dir, img_name)))
  151. continue
  152. with open(json_file, mode="r", \
  153. encoding=get_encoding(json_file)) as j:
  154. json_info = json.load(j)
  155. data_shapes = []
  156. if 'outputs' in json_info:
  157. for output in json_info['outputs']['object']:
  158. if 'polygon' in output.keys():
  159. polygon = output['polygon']
  160. name = output['name']
  161. points = []
  162. for i in range(1, int(len(polygon) / 2) + 1):
  163. points.append(
  164. [polygon['x' + str(i)], polygon['y' + str(i)]])
  165. shape = {
  166. 'label': name,
  167. 'points': points,
  168. 'shape_type': 'polygon'
  169. }
  170. data_shapes.append(shape)
  171. if 'size' not in json_info:
  172. continue
  173. img_shape = (json_info['size']['height'],
  174. json_info['size']['width'],
  175. json_info['size']['depth'])
  176. lbl, _ = self.shapes_to_label(
  177. img_shape=img_shape,
  178. shapes=data_shapes,
  179. label_name_to_value=self.labels2ids,
  180. )
  181. out_png_file = osp.join(png_dir, img_name_part + '.png')
  182. if lbl.min() >= 0 and lbl.max() <= 255:
  183. lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
  184. lbl_pil.putpalette(color_map)
  185. lbl_pil.save(out_png_file)
  186. else:
  187. raise ValueError(
  188. '[%s] Cannot save the pixel-wise class label as PNG. '
  189. 'Please consider using the .npy format.' % out_png_file)
  190. class LabelMe2Seg(X2Seg):
  191. """将使用LabelMe标注的数据集转换为Seg数据集。
  192. """
  193. def __init__(self):
  194. super(LabelMe2Seg, self).__init__()
  195. def get_labels2ids(self, image_dir, json_dir):
  196. for img_name in os.listdir(image_dir):
  197. img_name_part = osp.splitext(img_name)[0]
  198. json_file = osp.join(json_dir, img_name_part + ".json")
  199. if not osp.exists(json_file):
  200. os.remove(os.remove(osp.join(image_dir, img_name)))
  201. continue
  202. with open(json_file, mode="r", \
  203. encoding=get_encoding(json_file)) as j:
  204. json_info = json.load(j)
  205. for shape in json_info['shapes']:
  206. cls_name = shape['label']
  207. if cls_name not in self.labels2ids:
  208. self.labels2ids[cls_name] = len(self.labels2ids)
  209. def json2png(self, image_dir, json_dir, png_dir):
  210. color_map = self.get_color_map_list(256)
  211. for img_name in os.listdir(image_dir):
  212. img_name_part = osp.splitext(img_name)[0]
  213. json_file = osp.join(json_dir, img_name_part + ".json")
  214. if not osp.exists(json_file):
  215. os.remove(osp.join(image_dir, img_name))
  216. continue
  217. img_file = osp.join(image_dir, img_name)
  218. img = np.asarray(PIL.Image.open(img_file))
  219. with open(json_file, mode="r", \
  220. encoding=get_encoding(json_file)) as j:
  221. json_info = json.load(j)
  222. lbl, _ = self.shapes_to_label(
  223. img_shape=img.shape,
  224. shapes=json_info['shapes'],
  225. label_name_to_value=self.labels2ids,
  226. )
  227. out_png_file = osp.join(png_dir, img_name_part + '.png')
  228. if lbl.min() >= 0 and lbl.max() <= 255:
  229. lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
  230. lbl_pil.putpalette(color_map)
  231. lbl_pil.save(out_png_file)
  232. else:
  233. raise ValueError(
  234. '[%s] Cannot save the pixel-wise class label as PNG. '
  235. 'Please consider using the .npy format.' % out_png_file)
  236. class EasyData2Seg(X2Seg):
  237. """将使用EasyData标注的分割数据集转换为Seg数据集。
  238. """
  239. def __init__(self):
  240. super(EasyData2Seg, self).__init__()
  241. def get_labels2ids(self, image_dir, json_dir):
  242. for img_name in os.listdir(image_dir):
  243. img_name_part = osp.splitext(img_name)[0]
  244. json_file = osp.join(json_dir, img_name_part + ".json")
  245. if not osp.exists(json_file):
  246. os.remove(osp.join(image_dir, img_name))
  247. continue
  248. with open(json_file, mode="r", \
  249. encoding=get_encoding(json_file)) as j:
  250. json_info = json.load(j)
  251. for shape in json_info["labels"]:
  252. cls_name = shape['name']
  253. if cls_name not in self.labels2ids:
  254. self.labels2ids[cls_name] = len(self.labels2ids)
  255. def mask2polygon(self, mask, label):
  256. contours, hierarchy = cv2.findContours(
  257. (mask).astype(np.uint8), cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  258. segmentation = []
  259. for contour in contours:
  260. contour_list = contour.flatten().tolist()
  261. if len(contour_list) > 4:
  262. points = []
  263. for i in range(0, len(contour_list), 2):
  264. points.append(
  265. [contour_list[i], contour_list[i + 1]])
  266. shape = {
  267. 'label': label,
  268. 'points': points,
  269. 'shape_type': 'polygon'
  270. }
  271. segmentation.append(shape)
  272. return segmentation
  273. def json2png(self, image_dir, json_dir, png_dir):
  274. from pycocotools.mask import decode
  275. color_map = self.get_color_map_list(256)
  276. for img_name in os.listdir(image_dir):
  277. img_name_part = osp.splitext(img_name)[0]
  278. json_file = osp.join(json_dir, img_name_part + ".json")
  279. if not osp.exists(json_file):
  280. os.remove(os.remove(osp.join(image_dir, img_name)))
  281. continue
  282. img_file = osp.join(image_dir, img_name)
  283. img = np.asarray(PIL.Image.open(img_file))
  284. img_h = img.shape[0]
  285. img_w = img.shape[1]
  286. with open(json_file, mode="r", \
  287. encoding=get_encoding(json_file)) as j:
  288. json_info = json.load(j)
  289. data_shapes = []
  290. for shape in json_info['labels']:
  291. mask_dict = {}
  292. mask_dict['size'] = [img_h, img_w]
  293. mask_dict['counts'] = shape['mask'].encode()
  294. mask = decode(mask_dict)
  295. polygon = self.mask2polygon(mask, shape["name"])
  296. data_shapes.extend(polygon)
  297. lbl, _ = self.shapes_to_label(
  298. img_shape=img.shape,
  299. shapes=data_shapes,
  300. label_name_to_value=self.labels2ids,
  301. )
  302. out_png_file = osp.join(png_dir, img_name_part + '.png')
  303. if lbl.min() >= 0 and lbl.max() <= 255:
  304. lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
  305. lbl_pil.putpalette(color_map)
  306. lbl_pil.save(out_png_file)
  307. else:
  308. raise ValueError(
  309. '[%s] Cannot save the pixel-wise class label as PNG. '
  310. 'Please consider using the .npy format.' % out_png_file)