# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import json
import random

import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm

from .....utils.file_interface import custom_open, write_json_file
from .....utils.errors import ConvertFailedError
from .....utils.logging import info, warning


class Indexer(object):
    """Assigns stable 1-based integer ids to keys in first-seen order."""

    def __init__(self):
        """init indexer"""
        self._map = {}
        self.idx = 0

    def get_id(self, key):
        """Return the id for `key`, allocating a new id on first use."""
        if key not in self._map:
            self.idx += 1
            self._map[key] = self.idx
        return self._map[key]

    def get_list(self, key_name):
        """Return a list of {key_name: key, "id": id} records for all known keys."""
        return [{key_name: key, "id": val} for key, val in self._map.items()]
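
# Usage sketch (illustrative, not part of the conversion pipeline):
#
#   indexer = Indexer()
#   indexer.get_id("cat")      # -> 1
#   indexer.get_id("dog")      # -> 2
#   indexer.get_id("cat")      # -> 1 (existing keys keep their id)
#   indexer.get_list("name")   # -> [{"name": "cat", "id": 1},
#                              #     {"name": "dog", "id": 2}]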


class Extension(object):
    """Ordered list of file extensions with a move-to-front update."""

    def __init__(self, exts_list):
        """init extension"""
        self._exts_list = ["." + ext for ext in exts_list]

    def __iter__(self):
        """iterator"""
        return iter(self._exts_list)

    def update(self, ext):
        """Move `ext` to the front so the most recent match is tried first."""
        self._exts_list.remove(ext)
        self._exts_list.insert(0, ext)
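
# Usage sketch (illustrative): iteration yields dotted extensions; `update`
# promotes the last matched extension to the front.
#
#   exts = Extension(["jpg", "png"])
#   list(exts)           # -> [".jpg", ".png"]
#   exts.update(".png")
#   list(exts)           # -> [".png", ".jpg"]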


def check_src_dataset(root_dir, dataset_type):
    """Check the validity of the source dataset and return the annotation-list mapping."""
    if dataset_type != "LabelMe":
        raise ConvertFailedError(
            message=f"Dataset format conversion failed! The {dataset_type} dataset "
            "format is not supported; only the LabelMe format is currently supported."
        )

    err_msg_prefix = (
        "Dataset format conversion failed! Please check the dataset to be converted "
        f"against the `{dataset_type} dataset example` above."
    )

    anno_map = {}
    for dst_anno, src_anno in [
        ("instance_train.json", "train_anno_list.txt"),
        ("instance_val.json", "val_anno_list.txt"),
    ]:
        src_anno_path = os.path.join(root_dir, src_anno)
        if not os.path.exists(src_anno_path):
            if dst_anno == "instance_train.json":
                raise ConvertFailedError(
                    message=f"{err_msg_prefix} Please make sure the file "
                    f"{src_anno_path} exists."
                )
            continue
        with custom_open(src_anno_path, "r") as f:
            anno_list = f.readlines()
        for anno_fn in anno_list:
            anno_fn = anno_fn.strip().split(" ")[-1]
            anno_path = os.path.join(root_dir, anno_fn)
            if not os.path.exists(anno_path):
                raise ConvertFailedError(
                    message=f'{err_msg_prefix} Please make sure the file "{anno_fn}" '
                    f'listed in "{src_anno_path}" exists.'
                )
        anno_map[dst_anno] = src_anno_path
    return anno_map
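
# Expected layout (illustrative paths): the dataset root holds one annotation
# list per split, and each non-empty line names a LabelMe .json file relative
# to the root; val_anno_list.txt is optional (see split_anno_list below).
#
#   root_dir/
#     train_anno_list.txt     # e.g. a line reading "annos/img_0001.json"
#     val_anno_list.txt       # optional
#     annos/img_0001.json     # LabelMe annotation files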


def convert(dataset_type, input_dir):
    """Convert a dataset to COCO format."""
    # Check format validity first.
    anno_map = check_src_dataset(input_dir, dataset_type)
    if dataset_type == "LabelMe":
        convert_labelme_dataset(input_dir, anno_map)
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")
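
# Typical entry point (sketch; the path is a placeholder): converting a
# LabelMe dataset in place writes COCO-style files under <input_dir>/annotations/.
#
#   convert("LabelMe", "/path/to/dataset")
#   # -> /path/to/dataset/annotations/instance_train.json
#   # -> /path/to/dataset/annotations/instance_val.json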


def split_anno_list(root_dir, anno_map):
    """Split the annotation list into 80% train and 20% val."""
    anno_list_bak = os.path.join(root_dir, "train_anno_list.txt.bak")
    shutil.move(anno_map["instance_train.json"], anno_list_bak)
    with custom_open(anno_list_bak, "r") as f:
        src_anno = f.readlines()
    random.shuffle(src_anno)
    train_anno_list = src_anno[: int(len(src_anno) * 0.8)]
    val_anno_list = src_anno[int(len(src_anno) * 0.8) :]
    with custom_open(os.path.join(root_dir, "train_anno_list.txt"), "w") as f:
        f.writelines(train_anno_list)
    with custom_open(os.path.join(root_dir, "val_anno_list.txt"), "w") as f:
        f.writelines(val_anno_list)
    anno_map["instance_train.json"] = os.path.join(root_dir, "train_anno_list.txt")
    anno_map["instance_val.json"] = os.path.join(root_dir, "val_anno_list.txt")
    msg = (
        f"{os.path.join(root_dir, 'val_anno_list.txt')} does not exist. The dataset "
        "has been split into 80% train and 20% val by default, and the original "
        "'train_anno_list.txt' has been renamed to 'train_anno_list.txt.bak'."
    )
    warning(msg)
    return anno_map
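
# Worked example (sketch): with 10 lines in train_anno_list.txt, the shuffled
# split yields 8 train lines and 2 val lines (int(10 * 0.8) == 8). The split
# is random per run, since no RNG seed is set here.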


def convert_labelme_dataset(root_dir, anno_map):
    """Convert a dataset labeled with LabelMe to COCO format."""
    label_indexer = Indexer()
    img_indexer = Indexer()

    annotations_dir = os.path.join(root_dir, "annotations")
    if not os.path.exists(annotations_dir):
        os.makedirs(annotations_dir)

    # If val_anno_list.txt does not exist, split the original dataset.
    if "instance_val.json" not in anno_map:
        anno_map = split_anno_list(root_dir, anno_map)

    for dst_anno in anno_map:
        labelme2coco(
            label_indexer,
            img_indexer,
            root_dir,
            anno_map[dst_anno],
            os.path.join(annotations_dir, dst_anno),
        )
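
# Design note: a single label_indexer/img_indexer pair is shared across the
# conversions above, so category ids stay consistent between
# instance_train.json and instance_val.json.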


def labelme2coco(label_indexer, img_indexer, root_dir, anno_path, save_path):
    """Convert JSON files generated by LabelMe to COCO format and save them to file."""
    import pycocotools.mask as mask_util

    with custom_open(anno_path, "r") as f:
        json_list = f.readlines()

    anno_num = 0
    anno_list = []
    image_list = []
    info(f"Start loading json annotation files from {anno_path} ...")
    for json_path in tqdm(json_list):
        json_path = json_path.strip()
        assert json_path.endswith(".json"), json_path
        with custom_open(os.path.join(root_dir, json_path), "r") as f:
            labelme_data = json.load(f)
        img_id = img_indexer.get_id(labelme_data["imagePath"])
        height = labelme_data["imageHeight"]
        width = labelme_data["imageWidth"]
        image_list.append(
            {
                "id": img_id,
                "file_name": labelme_data["imagePath"].split("/")[-1],
                "width": width,
                "height": height,
            }
        )
        for shape in labelme_data["shapes"]:
            assert shape["shape_type"] == "polygon", "Only polygons are supported."
            category_id = label_indexer.get_id(shape["label"])
            points = shape["points"]
            segmentation = [np.asarray(points).flatten().tolist()]

            # Rasterize the polygon, RLE-encode it, and derive area and bbox.
            mask = points_to_mask([height, width], points)
            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = mask_util.encode(mask)
            area = float(mask_util.area(mask))
            bbox = mask_util.toBbox(mask).flatten().tolist()

            anno_num += 1
            anno_list.append(
                {
                    "image_id": img_id,
                    "bbox": bbox,
                    "segmentation": segmentation,
                    "category_id": category_id,
                    "id": anno_num,
                    "iscrowd": 0,
                    "area": area,
                    "ignore": 0,
                }
            )

    category_list = label_indexer.get_list(key_name="name")
    data_coco = {
        "images": image_list,
        "categories": category_list,
        "annotations": anno_list,
    }
    write_json_file(data_coco, save_path)
    info(f"The converted annotations have been saved to {save_path}.")
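
# Minimal LabelMe record consumed above (illustrative values; only the keys
# read above are required):
#
#   {
#     "imagePath": "images/img_0001.jpg",
#     "imageHeight": 480,
#     "imageWidth": 640,
#     "shapes": [
#       {"label": "cat", "shape_type": "polygon",
#        "points": [[10, 10], [100, 10], [100, 80]]}
#     ]
#   }
#
# Each polygon becomes one COCO annotation, with its bbox and area computed
# from the rasterized mask rather than from the raw points.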


def points_to_mask(img_shape, points):
    """Convert polygon points to a binary mask."""
    mask = np.zeros(img_shape[:2], dtype=np.uint8)
    mask = Image.fromarray(mask)
    draw = ImageDraw.Draw(mask)
    xy = [tuple(point) for point in points]
    assert len(xy) > 2, "Polygon must have more than 2 points."
    draw.polygon(xy=xy, outline=1, fill=1)
    mask = np.asarray(mask, dtype=bool)
    return mask
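
# Example (illustrative): a right triangle rasterized into a 5x5 mask.
#
#   mask = points_to_mask([5, 5], [[0, 0], [4, 0], [0, 4]])
#   mask.dtype   # -> bool
#   mask.shape   # -> (5, 5); True inside and on the polygon edge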