x2seg.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import cv2
  15. import uuid
  16. import json
  17. import os
  18. import os.path as osp
  19. import shutil
  20. import numpy as np
  21. import PIL.Image
  22. from paddlex.utils import is_pic, get_encoding
  23. import math
  class X2Seg(object):
      """Base class for converting annotated datasets into the Seg layout.

      ``convert`` drives the conversion; tool-specific subclasses provide
      ``get_labels2ids`` (collect class names) and ``json2png`` (render the
      annotation files to indexed PNG label images).
      """

      def __init__(self):
          # Class name -> integer id. Id 0 is reserved for the background.
          self.labels2ids = {'_background_': 0}

      def shapes_to_label(self, img_shape, shapes, label_name_to_value):
          """Rasterize annotation shapes into class-id and instance-id maps.

          Based on
          https://github.com/wkentaro/labelme/blob/master/labelme/utils/shape.py

          Args:
              img_shape (tuple): Image shape; only the first two entries
                  (height, width) are used.
              shapes (list[dict]): Shapes with the keys 'points' and 'label'
                  and the optional keys 'group_id' and 'shape_type'.
              label_name_to_value (dict): Mapping from class name to class id.

          Returns:
              tuple: ``(cls, ins)``, two int32 arrays of size height x width
                  holding the class id and the instance id of every pixel.
                  Pixels covered by no shape stay 0. Later shapes overwrite
                  earlier ones where they overlap.

          Raises:
              KeyError: If a shape label is missing from label_name_to_value.
              AssertionError: If a shape has the wrong number of points for
                  its shape type.
          """
          # ``import PIL.Image`` at file level does not guarantee that the
          # ``PIL.ImageDraw`` submodule is loaded, so import it explicitly.
          from PIL import ImageDraw

          def shape_to_mask(img_shape,
                            points,
                            shape_type=None,
                            line_width=10,
                            point_size=5):
              """Draw a single shape and return it as a boolean mask."""
              mask = np.zeros(img_shape[:2], dtype=np.uint8)
              mask = PIL.Image.fromarray(mask)
              draw = ImageDraw.Draw(mask)
              xy = [tuple(point) for point in points]
              if shape_type == 'circle':
                  assert len(
                      xy) == 2, 'Shape of shape_type=circle must have 2 points'
                  # First point is the centre, second lies on the circle.
                  (cx, cy), (px, py) = xy
                  d = math.sqrt((cx - px)**2 + (cy - py)**2)
                  draw.ellipse(
                      [cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
              elif shape_type == 'rectangle':
                  assert len(
                      xy
                  ) == 2, 'Shape of shape_type=rectangle must have 2 points'
                  draw.rectangle(xy, outline=1, fill=1)
              elif shape_type == 'line':
                  assert len(
                      xy) == 2, 'Shape of shape_type=line must have 2 points'
                  draw.line(xy=xy, fill=1, width=line_width)
              elif shape_type == 'linestrip':
                  draw.line(xy=xy, fill=1, width=line_width)
              elif shape_type == 'point':
                  assert len(
                      xy) == 1, 'Shape of shape_type=point must have 1 points'
                  cx, cy = xy[0]
                  r = point_size
                  draw.ellipse(
                      [cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
              else:
                  # Any other shape type (including None) is a polygon.
                  assert len(xy) > 2, 'Polygon must have points more than 2'
                  draw.polygon(xy=xy, outline=1, fill=1)
              return np.array(mask, dtype=bool)

          cls = np.zeros(img_shape[:2], dtype=np.int32)
          ins = np.zeros_like(cls)
          # (class name, group id) -> 1-based instance id, in order of first
          # appearance. A dict keeps the lookup O(1) per shape.
          instance_ids = {}
          for shape in shapes:
              points = shape['points']
              cls_name = shape['label']
              group_id = shape.get('group_id')
              if group_id is None:
                  # Shapes without a group each form their own instance.
                  group_id = uuid.uuid1()
              shape_type = shape.get('shape_type', None)
              instance = (cls_name, group_id)
              ins_id = instance_ids.setdefault(instance,
                                               len(instance_ids) + 1)
              cls_id = label_name_to_value[cls_name]
              mask = shape_to_mask(img_shape[:2], points, shape_type)
              cls[mask] = cls_id
              ins[mask] = ins_id
          return cls, ins

      def get_color_map_list(self, num_classes):
          """Build a flat RGB palette for ``num_classes`` class ids.

          The bits of each class id are distributed round-robin over the
          R, G and B channels, starting at the most significant bit of each
          channel.

          Args:
              num_classes (int): Number of palette entries to generate.

          Returns:
              list[int]: ``3 * num_classes`` values laid out as
                  ``[r0, g0, b0, r1, g1, b1, ...]``.
          """
          color_map = num_classes * [0, 0, 0]
          for i in range(0, num_classes):
              j = 0
              lab = i
              while lab:
                  color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
                  color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
                  color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
                  j += 1
                  lab >>= 3
          return color_map

      def get_labels2ids(self, image_dir, json_dir):
          """Collect class names into ``self.labels2ids``; subclass hook."""
          raise NotImplementedError(
              "get_labels2ids must be implemented by a subclass")

      def json2png(self, image_dir, json_dir, png_dir):
          """Render annotation files to PNG label images; subclass hook."""
          raise NotImplementedError(
              "json2png must be implemented by a subclass")

      def convert(self, image_dir, json_dir, dataset_save_dir):
          """Convert an annotated dataset.

          Copies the images to ``<dataset_save_dir>/JPEGImages``, writes the
          label PNGs to ``<dataset_save_dir>/Annotations`` and writes the
          class names, ordered by id, to ``<dataset_save_dir>/labels.txt``.

          Args:
              image_dir (str): Directory holding the image files.
              json_dir (str): Directory holding the json file that belongs
                  to each image.
              dataset_save_dir (str): Directory to store the converted
                  dataset in.

          Raises:
              AssertionError: If image_dir or json_dir does not exist.
              Exception: If the JPEGImages output directory already exists.
          """
          assert osp.exists(image_dir), "The image folder does not exist!"
          assert osp.exists(json_dir), "The json folder does not exist!"
          if not osp.exists(dataset_save_dir):
              os.makedirs(dataset_save_dir)
          # Convert the image files.
          new_image_dir = osp.join(dataset_save_dir, "JPEGImages")
          if osp.exists(new_image_dir):
              raise Exception(
                  "The directory {} is already exist, please remove the directory first".
                  format(new_image_dir))
          os.makedirs(new_image_dir)
          for img_name in os.listdir(image_dir):
              if is_pic(img_name):
                  shutil.copyfile(
                      osp.join(image_dir, img_name),
                      osp.join(new_image_dir, img_name))
          # Convert the json files. An existing Annotations directory is
          # replaced, unlike JPEGImages above.
          png_dir = osp.join(dataset_save_dir, "Annotations")
          if osp.exists(png_dir):
              shutil.rmtree(png_dir)
          os.makedirs(png_dir)
          self.get_labels2ids(new_image_dir, json_dir)
          self.json2png(new_image_dir, json_dir, png_dir)
          # Generate the labels.txt
          ids2labels = {v: k for k, v in self.labels2ids.items()}
          with open(osp.join(dataset_save_dir, 'labels.txt'), 'w') as fw:
              for i in range(len(ids2labels)):
                  fw.write(ids2labels[i] + '\n')
  class JingLing2Seg(X2Seg):
      """Convert a dataset annotated with JingLing into a Seg dataset."""

      def __init__(self):
          super(JingLing2Seg, self).__init__()

      def get_labels2ids(self, image_dir, json_dir):
          """Register every annotated class name in ``self.labels2ids``.

          Images in image_dir that have no ``<name>.json`` in json_dir are
          deleted from image_dir.

          Args:
              image_dir (str): Directory holding the images.
              json_dir (str): Directory holding the json annotation files.
          """
          for img_name in os.listdir(image_dir):
              img_name_part = osp.splitext(img_name)[0]
              json_file = osp.join(json_dir, img_name_part + ".json")
              if not osp.exists(json_file):
                  os.remove(osp.join(image_dir, img_name))
                  continue
              with open(json_file, mode="r",
                        encoding=get_encoding(json_file)) as j:
                  json_info = json.load(j)
              if 'outputs' not in json_info:
                  continue
              # 'outputs' may lack the 'object' list (presumably for images
              # without annotations); treat that as "no objects".
              for output in json_info['outputs'].get('object', []):
                  cls_name = output['name']
                  if cls_name not in self.labels2ids:
                      self.labels2ids[cls_name] = len(self.labels2ids)

      def json2png(self, image_dir, json_dir, png_dir):
          """Render the polygon annotations to indexed PNG label images.

          Images without a json file are deleted from image_dir; json files
          without a 'size' entry are skipped. Only objects that carry a
          'polygon' entry are drawn.

          Args:
              image_dir (str): Directory holding the images.
              json_dir (str): Directory holding the json annotation files.
              png_dir (str): Directory the ``<name>.png`` files are written to.

          Raises:
              ValueError: If a class id does not fit into an 8-bit PNG.
          """
          color_map = self.get_color_map_list(256)
          for img_name in os.listdir(image_dir):
              img_name_part = osp.splitext(img_name)[0]
              json_file = osp.join(json_dir, img_name_part + ".json")
              if not osp.exists(json_file):
                  os.remove(osp.join(image_dir, img_name))
                  continue
              with open(json_file, mode="r",
                        encoding=get_encoding(json_file)) as j:
                  json_info = json.load(j)
              data_shapes = []
              if 'outputs' in json_info:
                  for output in json_info['outputs'].get('object', []):
                      if 'polygon' not in output:
                          continue
                      polygon = output['polygon']
                      # The polygon dict stores its vertices under the keys
                      # x1, y1, x2, y2, ...
                      points = [[polygon['x' + str(i)], polygon['y' + str(i)]]
                                for i in range(1, len(polygon) // 2 + 1)]
                      data_shapes.append({
                          'label': output['name'],
                          'points': points,
                          'shape_type': 'polygon'
                      })
              if 'size' not in json_info:
                  continue
              img_shape = (json_info['size']['height'],
                           json_info['size']['width'],
                           json_info['size']['depth'])
              lbl, _ = self.shapes_to_label(
                  img_shape=img_shape,
                  shapes=data_shapes,
                  label_name_to_value=self.labels2ids, )
              out_png_file = osp.join(png_dir, img_name_part + '.png')
              if lbl.min() >= 0 and lbl.max() <= 255:
                  lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
                  lbl_pil.putpalette(color_map)
                  lbl_pil.save(out_png_file)
              else:
                  raise ValueError(
                      '[%s] Cannot save the pixel-wise class label as PNG. '
                      'Please consider using the .npy format.' % out_png_file)
  class LabelMe2Seg(X2Seg):
      """Convert a dataset annotated with LabelMe into a Seg dataset."""

      def __init__(self):
          super(LabelMe2Seg, self).__init__()

      def get_labels2ids(self, image_dir, json_dir):
          """Register every annotated class name in ``self.labels2ids``.

          Images in image_dir that have no ``<name>.json`` in json_dir are
          deleted from image_dir.

          Args:
              image_dir (str): Directory holding the images.
              json_dir (str): Directory holding the json annotation files.
          """
          for img_name in os.listdir(image_dir):
              img_name_part = osp.splitext(img_name)[0]
              json_file = osp.join(json_dir, img_name_part + ".json")
              if not osp.exists(json_file):
                  os.remove(osp.join(image_dir, img_name))
                  continue
              with open(json_file, mode="r",
                        encoding=get_encoding(json_file)) as j:
                  json_info = json.load(j)
              for shape in json_info['shapes']:
                  cls_name = shape['label']
                  if cls_name not in self.labels2ids:
                      self.labels2ids[cls_name] = len(self.labels2ids)

      def json2png(self, image_dir, json_dir, png_dir):
          """Render the LabelMe shapes to indexed PNG label images.

          Images without a json file are deleted from image_dir.

          Args:
              image_dir (str): Directory holding the images.
              json_dir (str): Directory holding the json annotation files.
              png_dir (str): Directory the ``<name>.png`` files are written to.

          Raises:
              ValueError: If a class id does not fit into an 8-bit PNG.
          """
          color_map = self.get_color_map_list(256)
          for img_name in os.listdir(image_dir):
              img_name_part = osp.splitext(img_name)[0]
              json_file = osp.join(json_dir, img_name_part + ".json")
              if not osp.exists(json_file):
                  os.remove(osp.join(image_dir, img_name))
                  continue
              img_file = osp.join(image_dir, img_name)
              # Only the image size is needed, so read it from the header
              # instead of decoding all pixels, and close the file promptly.
              with PIL.Image.open(img_file) as img:
                  img_w, img_h = img.size
              with open(json_file, mode="r",
                        encoding=get_encoding(json_file)) as j:
                  json_info = json.load(j)
              lbl, _ = self.shapes_to_label(
                  img_shape=(img_h, img_w),
                  shapes=json_info['shapes'],
                  label_name_to_value=self.labels2ids, )
              out_png_file = osp.join(png_dir, img_name_part + '.png')
              if lbl.min() >= 0 and lbl.max() <= 255:
                  lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
                  lbl_pil.putpalette(color_map)
                  lbl_pil.save(out_png_file)
              else:
                  raise ValueError(
                      '[%s] Cannot save the pixel-wise class label as PNG. '
                      'Please consider using the .npy format.' % out_png_file)
  class EasyData2Seg(X2Seg):
      """Convert a segmentation dataset annotated with EasyData into a Seg dataset."""

      def __init__(self):
          super(EasyData2Seg, self).__init__()

      def get_labels2ids(self, image_dir, json_dir):
          """Register every annotated class name in ``self.labels2ids``.

          Images in image_dir that have no ``<name>.json`` in json_dir are
          deleted from image_dir.

          Args:
              image_dir (str): Directory holding the images.
              json_dir (str): Directory holding the json annotation files.
          """
          for img_name in os.listdir(image_dir):
              img_name_part = osp.splitext(img_name)[0]
              json_file = osp.join(json_dir, img_name_part + ".json")
              if not osp.exists(json_file):
                  os.remove(osp.join(image_dir, img_name))
                  continue
              with open(json_file, mode="r",
                        encoding=get_encoding(json_file)) as j:
                  json_info = json.load(j)
              for shape in json_info["labels"]:
                  cls_name = shape['name']
                  if cls_name not in self.labels2ids:
                      self.labels2ids[cls_name] = len(self.labels2ids)

      def mask2polygon(self, mask, label):
          """Convert a binary mask into polygon shapes.

          Args:
              mask (np.ndarray): Mask array; it is cast to uint8 before the
                  contours are extracted.
              label (str): Class name assigned to every returned shape.

          Returns:
              list[dict]: One shape dict ('label', 'points', 'shape_type')
                  per contour with at least three points.
          """
          # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
          # (image, contours, hierarchy) in 3.x; the contours are always the
          # second-to-last element.
          contours = cv2.findContours(
              (mask).astype(np.uint8), cv2.RETR_TREE,
              cv2.CHAIN_APPROX_SIMPLE)[-2]
          segmentation = []
          for contour in contours:
              # Flatten whatever layout OpenCV returns into [x, y] pairs.
              points = contour.reshape(-1, 2).tolist()
              if len(points) > 2:
                  segmentation.append({
                      'label': label,
                      'points': points,
                      'shape_type': 'polygon'
                  })
          return segmentation

      def json2png(self, image_dir, json_dir, png_dir):
          """Render the EasyData masks to indexed PNG label images.

          Each entry of the json 'labels' list has its 'mask' string decoded
          with pycocotools, converted to polygons and drawn. Images without
          a json file are deleted from image_dir.

          Args:
              image_dir (str): Directory holding the images.
              json_dir (str): Directory holding the json annotation files.
              png_dir (str): Directory the ``<name>.png`` files are written to.

          Raises:
              ValueError: If a class id does not fit into an 8-bit PNG.
          """
          from pycocotools.mask import decode
          color_map = self.get_color_map_list(256)
          for img_name in os.listdir(image_dir):
              img_name_part = osp.splitext(img_name)[0]
              json_file = osp.join(json_dir, img_name_part + ".json")
              if not osp.exists(json_file):
                  os.remove(osp.join(image_dir, img_name))
                  continue
              img_file = osp.join(image_dir, img_name)
              # Only the image size is needed, so read it from the header
              # instead of decoding all pixels, and close the file promptly.
              with PIL.Image.open(img_file) as img:
                  img_w, img_h = img.size
              with open(json_file, mode="r",
                        encoding=get_encoding(json_file)) as j:
                  json_info = json.load(j)
              data_shapes = []
              for shape in json_info['labels']:
                  # Build the dict that pycocotools' decode expects.
                  mask_dict = {
                      'size': [img_h, img_w],
                      'counts': shape['mask'].encode(),
                  }
                  mask = decode(mask_dict)
                  data_shapes.extend(self.mask2polygon(mask, shape["name"]))
              lbl, _ = self.shapes_to_label(
                  img_shape=(img_h, img_w),
                  shapes=data_shapes,
                  label_name_to_value=self.labels2ids, )
              out_png_file = osp.join(png_dir, img_name_part + '.png')
              if lbl.min() >= 0 and lbl.max() <= 255:
                  lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
                  lbl_pil.putpalette(color_map)
                  lbl_pil.save(out_png_file)
              else:
                  raise ValueError(
                      '[%s] Cannot save the pixel-wise class label as PNG. '
                      'Please consider using the .npy format.' % out_png_file)