convert_dataset.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import glob
  15. import json
  16. import os
  17. import os.path as osp
  18. import shutil
  19. import cv2
  20. import numpy as np
  21. from PIL import Image, ImageDraw
  22. from .....utils.file_interface import custom_open
  23. from .....utils import logging
  24. def convert_dataset(dataset_type, input_dir):
  25. """convert to paddlex official format"""
  26. if dataset_type == "LabelMe":
  27. return convert_labelme_dataset(input_dir)
  28. else:
  29. raise NotImplementedError(dataset_type)
  30. def convert_labelme_dataset(input_dir):
  31. """convert labelme format to paddlex official format"""
  32. bg_name = "_background_"
  33. ignore_name = "__ignore__"
  34. # prepare dir
  35. output_img_dir = osp.join(input_dir, "images")
  36. output_annot_dir = osp.join(input_dir, "annotations")
  37. if not osp.exists(output_img_dir):
  38. os.makedirs(output_img_dir)
  39. if not osp.exists(output_annot_dir):
  40. os.makedirs(output_annot_dir)
  41. # collect class_names and set class_name_to_id
  42. class_names = []
  43. class_name_to_id = {}
  44. split_tags = ["train", "val"]
  45. for tag in split_tags:
  46. mapping_file = osp.join(input_dir, f"{tag}_anno_list.txt")
  47. with open(mapping_file, "r") as f:
  48. label_files = [
  49. osp.join(input_dir, line.strip("\n")) for line in f.readlines()
  50. ]
  51. for label_file in label_files:
  52. with custom_open(label_file, "r") as fp:
  53. data = json.load(fp)
  54. for shape in data["shapes"]:
  55. cls_name = shape["label"]
  56. if cls_name not in class_names:
  57. class_names.append(cls_name)
  58. if ignore_name in class_names:
  59. class_name_to_id[ignore_name] = 255
  60. class_names.remove(ignore_name)
  61. if bg_name in class_names:
  62. class_names.remove(bg_name)
  63. class_name_to_id[bg_name] = 0
  64. for i, name in enumerate(class_names):
  65. class_name_to_id[name] = i + 1
  66. if len(class_names) > 256:
  67. raise ValueError(
  68. f"There are {len(class_names)} categories in the annotation file, "
  69. f"exceeding 256, Not compliant with paddlex official format!"
  70. )
  71. # create annotated images and copy origin images
  72. color_map = get_color_map_list(256)
  73. img_file_list = []
  74. label_file_list = []
  75. for i, label_file in enumerate(label_files):
  76. filename = osp.splitext(osp.basename(label_file))[0]
  77. annotated_img_path = osp.join(output_annot_dir, filename + ".png")
  78. with custom_open(label_file, "r") as f:
  79. data = json.load(f)
  80. img_path = osp.join(osp.dirname(label_file), data["imagePath"])
  81. if not os.path.exists(img_path):
  82. logging.info("%s is not existed, skip this image" % img_path)
  83. continue
  84. img_name = img_path.split("/")[-1]
  85. img_file_list.append(f"images/{img_name}")
  86. label_img_name = annotated_img_path.split("/")[-1]
  87. label_file_list.append(f"annotations/{label_img_name}")
  88. img = np.asarray(cv2.imread(img_path))
  89. lbl = shape2label(
  90. img_size=img.shape,
  91. shapes=data["shapes"],
  92. class_name_mapping=class_name_to_id,
  93. )
  94. lbl_pil = Image.fromarray(lbl.astype(np.uint8), mode="P")
  95. lbl_pil.putpalette(color_map)
  96. lbl_pil.save(annotated_img_path)
  97. shutil.copy(img_path, output_img_dir)
  98. with custom_open(osp.join(input_dir, f"{tag}.txt"), "w") as fp:
  99. for img_path, lbl_path in zip(img_file_list, label_file_list):
  100. fp.write(f"{img_path} {lbl_path}\n")
  101. with custom_open(osp.join(input_dir, "class_name.txt"), "w") as fp:
  102. for name in class_names:
  103. fp.write(f"{name}{os.linesep}")
  104. with custom_open(osp.join(input_dir, "class_name_to_id.txt"), "w") as fp:
  105. for key, val in class_name_to_id.items():
  106. fp.write(f"{val}: {key}{os.linesep}")
  107. return input_dir
  108. def get_color_map_list(num_classes):
  109. """get color map list"""
  110. num_classes += 1
  111. color_map = num_classes * [0, 0, 0]
  112. for i in range(0, num_classes):
  113. j = 0
  114. lab = i
  115. while lab:
  116. color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
  117. color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
  118. color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
  119. j += 1
  120. lab >>= 3
  121. color_map = color_map[3:]
  122. return color_map
  123. def shape2label(img_size, shapes, class_name_mapping):
  124. """根据输入的形状列表,将图像的标签矩阵填充为对应形状的类别编号"""
  125. label = np.zeros(img_size[:2], dtype=np.int32)
  126. for shape in shapes:
  127. points = shape["points"]
  128. class_name = shape["label"]
  129. label_mask = polygon2mask(img_size[:2], points)
  130. label[label_mask] = class_name_mapping[class_name]
  131. return label
  132. def polygon2mask(img_size, points):
  133. """将给定形状的点转换成对应的掩膜"""
  134. label_mask = Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8))
  135. image_draw = ImageDraw.Draw(label_mask)
  136. points_list = [tuple(point) for point in points]
  137. assert len(points_list) > 2, ValueError("Polygon must have points more than 2")
  138. image_draw.polygon(xy=points_list, outline=1, fill=1)
  139. return np.array(label_mask, dtype=bool)