| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import shutil
- import random
- import json
- from tqdm import tqdm
- from .....utils.file_interface import custom_open, write_json_file
- from .....utils.logging import info
- def split_dataset(root_dir, train_rate, val_rate):
- """ split dataset """
- assert train_rate + val_rate == 100, \
- f"The sum of train_rate({train_rate}), val_rate({val_rate}) should equal 100!"
- assert train_rate > 0 and val_rate > 0, \
- f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
- all_image_info_list = []
- all_category_dict = {}
- max_image_id = 0
- for fn in ["instance_train.json", "instance_val.json"]:
- anno_path = os.path.join(root_dir, "annotations", fn)
- if not os.path.exists(anno_path):
- info(
- f"The annotation file {anno_path} don't exists, has been ignored!"
- )
- continue
- image_info_list, category_list, max_image_id = json2list(anno_path,
- max_image_id)
- all_image_info_list.extend(image_info_list)
- for category in category_list:
- if category['id'] not in all_category_dict:
- all_category_dict[category['id']] = category
- total_num = len(all_image_info_list)
- random.shuffle(all_image_info_list)
- all_category_list = [all_category_dict[k] for k in all_category_dict]
- start = 0
- for fn, rate in [("instance_train.json", train_rate),
- ("instance_val.json", val_rate)]:
- end = start + round(total_num * rate / 100)
- save_path = os.path.join(root_dir, "annotations", fn)
- if os.path.exists(save_path):
- bak_path = save_path + ".bak"
- shutil.move(save_path, bak_path)
- info(
- f"The original annotation file {fn} has been backed up to {bak_path}."
- )
- assemble_write(all_image_info_list[start:end], all_category_list,
- save_path)
- start = end
- return root_dir
- def json2list(json_path, base_image_num):
- """ load json as list """
- assert os.path.exists(json_path), json_path
- with custom_open(json_path, 'r') as f:
- data = json.load(f)
- image_info_dict = {}
- max_image_id = 0
- for image_info in data['images']:
- # 得到全局唯一的image_id
- global_image_id = image_info['id'] + base_image_num
- max_image_id = max(global_image_id, max_image_id)
- image_info['id'] = global_image_id
- image_info_dict[global_image_id] = {"img": image_info, 'anno': []}
- image_info_dict = {
- image_info['id']: {
- "img": image_info,
- 'anno': []
- }
- for image_info in data['images']
- }
- info(f"Start loading annotation file {json_path}...")
- for anno in tqdm(data['annotations']):
- global_image_id = anno['image_id'] + base_image_num
- anno['image_id'] = global_image_id
- image_info_dict[global_image_id]['anno'].append(anno)
- image_info_list = [(image_info_dict[image_info]['img'],
- image_info_dict[image_info]['anno'])
- for image_info in image_info_dict]
- return image_info_list, data['categories'], max_image_id
- def assemble_write(image_info_list, category_list, save_path):
- """ assemble coco format and save to file """
- coco_data = {'categories': category_list}
- image_list = [i[0] for i in image_info_list]
- all_anno_list = []
- for i in image_info_list:
- all_anno_list.extend(i[1])
- anno_list = []
- for i, anno in enumerate(all_anno_list):
- anno['id'] = i + 1
- anno_list.append(anno)
- coco_data['images'] = image_list
- coco_data['annotations'] = anno_list
- write_json_file(coco_data, save_path)
- info(f"The splited annotations has been save to {save_path}.")
|