# convert_dataset.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import os.path as osp
  17. import shutil
  18. import numpy as np
  19. from PIL import Image, ImageDraw
  20. from .....utils import logging
  21. from .....utils.deps import function_requires_deps, is_dep_available
  22. from .....utils.file_interface import custom_open
  23. from .....utils.logging import info
  24. if is_dep_available("opencv-contrib-python"):
  25. import cv2
  26. def convert_dataset(dataset_type, input_dir):
  27. """convert to paddlex official format"""
  28. if dataset_type == "LabelMe":
  29. return convert_labelme_dataset(input_dir)
  30. elif dataset_type == "MVTec_AD":
  31. return convert_mvtec_dataset(input_dir)
  32. else:
  33. raise NotImplementedError(dataset_type)
  34. @function_requires_deps("opencv-contrib-python")
  35. def convert_labelme_dataset(input_dir):
  36. """convert labelme format to paddlex official format"""
  37. bg_name = "_background_"
  38. ignore_name = "__ignore__"
  39. # prepare dir
  40. output_img_dir = osp.join(input_dir, "images")
  41. output_annot_dir = osp.join(input_dir, "annotations")
  42. if not osp.exists(output_img_dir):
  43. os.makedirs(output_img_dir)
  44. if not osp.exists(output_annot_dir):
  45. os.makedirs(output_annot_dir)
  46. # collect class_names and set class_name_to_id
  47. class_names = []
  48. class_name_to_id = {}
  49. split_tags = ["train", "val"]
  50. for tag in split_tags:
  51. mapping_file = osp.join(input_dir, f"{tag}_anno_list.txt")
  52. with open(mapping_file, "r") as f:
  53. label_files = [
  54. osp.join(input_dir, line.strip("\n")) for line in f.readlines()
  55. ]
  56. for label_file in label_files:
  57. with custom_open(label_file, "r") as fp:
  58. data = json.load(fp)
  59. for shape in data["shapes"]:
  60. cls_name = shape["label"]
  61. if cls_name not in class_names:
  62. class_names.append(cls_name)
  63. if ignore_name in class_names:
  64. class_name_to_id[ignore_name] = 255
  65. class_names.remove(ignore_name)
  66. if bg_name in class_names:
  67. class_names.remove(bg_name)
  68. class_name_to_id[bg_name] = 0
  69. for i, name in enumerate(class_names):
  70. class_name_to_id[name] = i + 1
  71. if len(class_names) > 256:
  72. raise ValueError(
  73. f"There are {len(class_names)} categories in the annotation file, "
  74. f"exceeding 256, Not compliant with paddlex official format!"
  75. )
  76. # create annotated images and copy origin images
  77. color_map = get_color_map_list(256)
  78. img_file_list = []
  79. label_file_list = []
  80. for i, label_file in enumerate(label_files):
  81. filename = osp.splitext(osp.basename(label_file))[0]
  82. annotated_img_path = osp.join(output_annot_dir, filename + ".png")
  83. with custom_open(label_file, "r") as f:
  84. data = json.load(f)
  85. img_path = osp.join(osp.dirname(label_file), data["imagePath"])
  86. if not os.path.exists(img_path):
  87. logging.info("%s is not existed, skip this image" % img_path)
  88. continue
  89. img_name = osp.basename(img_path)
  90. img_file_list.append(f"images/{img_name}")
  91. label_img_name = osp.basename(annotated_img_path)
  92. label_file_list.append(f"annotations/{label_img_name}")
  93. img = np.asarray(cv2.imread(img_path))
  94. lbl = shape2label(
  95. img_size=img.shape,
  96. shapes=data["shapes"],
  97. class_name_mapping=class_name_to_id,
  98. )
  99. lbl_pil = Image.fromarray(lbl.astype(np.uint8), mode="P")
  100. lbl_pil.putpalette(color_map)
  101. lbl_pil.save(annotated_img_path)
  102. shutil.copy(img_path, output_img_dir)
  103. with custom_open(osp.join(input_dir, f"{tag}.txt"), "w") as fp:
  104. for img_path, lbl_path in zip(img_file_list, label_file_list):
  105. fp.write(f"{img_path} {lbl_path}\n")
  106. with custom_open(osp.join(input_dir, "class_name.txt"), "w") as fp:
  107. for name in class_names:
  108. fp.write(f"{name}{os.linesep}")
  109. with custom_open(osp.join(input_dir, "class_name_to_id.txt"), "w") as fp:
  110. for key, val in class_name_to_id.items():
  111. fp.write(f"{val}: {key}{os.linesep}")
  112. return input_dir
  113. def get_color_map_list(num_classes):
  114. """get color map list"""
  115. num_classes += 1
  116. color_map = num_classes * [0, 0, 0]
  117. for i in range(0, num_classes):
  118. j = 0
  119. lab = i
  120. while lab:
  121. color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
  122. color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
  123. color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
  124. j += 1
  125. lab >>= 3
  126. color_map = color_map[3:]
  127. return color_map
  128. def shape2label(img_size, shapes, class_name_mapping):
  129. """根据输入的形状列表,将图像的标签矩阵填充为对应形状的类别编号"""
  130. label = np.zeros(img_size[:2], dtype=np.int32)
  131. for shape in shapes:
  132. points = shape["points"]
  133. class_name = shape["label"]
  134. label_mask = polygon2mask(img_size[:2], points)
  135. label[label_mask] = class_name_mapping[class_name]
  136. return label
  137. def polygon2mask(img_size, points):
  138. """将给定形状的点转换成对应的掩膜"""
  139. label_mask = Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8))
  140. image_draw = ImageDraw.Draw(label_mask)
  141. points_list = [tuple(point) for point in points]
  142. assert len(points_list) > 2, ValueError("Polygon must have points more than 2")
  143. image_draw.polygon(xy=points_list, outline=1, fill=1)
  144. return np.array(label_mask, dtype=bool)
  145. def save_item_to_txt(items, file_path):
  146. try:
  147. with open(file_path, "a") as file:
  148. file.write(items)
  149. file.close()
  150. except Exception as e:
  151. print(f"Saving_error: {e}")
  152. def save_training_txt(cls_root, mode, cat):
  153. imgs = os.listdir(os.path.join(cls_root, mode, cat))
  154. imgs.sort()
  155. for img in imgs:
  156. if mode == "train":
  157. item = os.path.join(cls_root, mode, cat, img)
  158. items = item + " " + item + "\n"
  159. save_item_to_txt(items, os.path.join(cls_root, "train.txt"))
  160. elif mode == "test" and cat != "good":
  161. item1 = os.path.join(cls_root, mode, cat, img)
  162. item2 = os.path.join(
  163. cls_root, "ground_truth", cat, img.split(".")[0] + "_mask.png"
  164. )
  165. items = item1 + " " + item2 + "\n"
  166. save_item_to_txt(items, os.path.join(cls_root, "val.txt"))
  167. def check_old_txt(cls_pth, mode):
  168. set_name = "train.txt" if mode == "train" else "val.txt"
  169. pth = os.path.join(cls_pth, set_name)
  170. if os.path.exists(pth):
  171. os.remove(pth)
  172. def convert_mvtec_dataset(input_dir):
  173. classes = [
  174. "bottle",
  175. "cable",
  176. "capsule",
  177. "hazelnut",
  178. "metal_nut",
  179. "pill",
  180. "screw",
  181. "toothbrush",
  182. "transistor",
  183. "zipper",
  184. "carpet",
  185. "grid",
  186. "leather",
  187. "tile",
  188. "wood",
  189. ]
  190. clas = os.path.split(input_dir)[-1]
  191. assert clas in classes, info(
  192. f"Make sure your class: '{clas}' in your dataset root in\n {classes}"
  193. )
  194. modes = ["train", "test"]
  195. cls_root = input_dir
  196. for mode in modes:
  197. check_old_txt(cls_root, mode)
  198. cats = os.listdir(os.path.join(cls_root, mode))
  199. for cat in cats:
  200. save_training_txt(cls_root, mode, cat)
  201. info(f"Add train.txt/val.txt successfully for {input_dir}")