# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import shutil
import random
import math
import pandas as pd
from tqdm import tqdm

from .....utils.logging import info


def split_dataset(root_dir, train_rate, val_rate, group_id='group_id'):
    """Re-split the grouped CSV annotations under ``root_dir`` by group.

    Reads ``train.csv`` and ``val.csv`` from ``root_dir`` (whichever exist),
    merges them, partitions the distinct values of the ``group_id`` column
    into a train part and a val part according to the given percentages, and
    writes the result back to ``train.csv`` / ``val.csv``. Any file that is
    about to be overwritten is first moved to ``<name>.bak``.

    Args:
        root_dir: Directory containing the annotation CSV files.
        train_rate: Percentage (0-100) of groups assigned to the train split.
        val_rate: Percentage (0-100) of groups assigned to the val split.
            ``train_rate + val_rate`` must equal 100.
        group_id: Name of the column that identifies a group.

    Returns:
        ``root_dir``, unchanged.

    Raises:
        AssertionError: If the rates are invalid, no annotation file is
            found, the ``group_id`` column is missing, or either resulting
            split would contain no group.
    """
    assert train_rate + val_rate == 100, \
        f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
    assert train_rate > 0 and val_rate > 0, \
        f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"

    tags = ['train.csv', 'val.csv']
    df = pd.DataFrame()
    seen_groups = set()
    for tag in tags:
        csv_path = osp.join(root_dir, tag)
        if not os.path.exists(csv_path):
            continue
        df_one = pd.read_csv(csv_path)
        assert group_id in df_one.columns, \
            f"The default group_id '{group_id}' is not found in the df columns."
        if df.empty:
            df = df_one
            seen_groups = set(df[group_id].unique())
            continue

        # A group id that already occurs in a previously loaded file denotes
        # a different group here, so give it a distinct name ("<id>_").
        renames = {}
        for gid in sorted(df_one[group_id].unique()):
            if gid in seen_groups:
                renames[gid] = str(gid) + "_"
                seen_groups.add(renames[gid])
        if renames:
            # Assign the result back instead of a chained in-place replace,
            # which does not modify df_one under pandas Copy-on-Write. All
            # renames are applied in one pass so a freshly renamed id cannot
            # be renamed a second time.
            df_one[group_id] = df_one[group_id].map(
                lambda gid: renames.get(gid, gid))
        df = pd.concat([df, df_one], axis=0)

    # Every loaded file is checked for the column above, so it can only be
    # missing here when no file was loaded at all.
    assert group_id in df.columns, \
        f"No annotation file ({', '.join(tags)}) was found in {root_dir}."

    df = df.drop_duplicates(keep='first')

    # One frame per group, in order of first appearance; within a group only
    # the first row for each 'time' value is kept.
    dfs = []
    for group in df[group_id].unique():
        group_df = df[df[group_id].isin([group])]
        group_df = group_df.drop_duplicates(subset=['time'], keep='first')
        dfs.append(group_df)

    group_len = len(dfs)
    point_train = math.floor((group_len * train_rate / 100))
    point_val = math.floor((group_len * (train_rate + val_rate) / 100))

    assert point_train > 0, \
        "The train_len is 0, the train_percent should be greater ."
    assert point_val - point_train > 0, \
        "The val_len is 0, the val_percent should be greater ."

    train_df = pd.concat(dfs[:point_train], axis=0)
    val_df = pd.concat(dfs[point_train:point_val], axis=0)
    df_dict = {'train.csv': train_df, 'val.csv': val_df}
    # NOTE(review): with the rates asserted to sum to 100, point_val should
    # equal group_len, so this branch looks reachable only when assertions
    # are disabled (python -O). The "- 1" bound also skips a single leftover
    # group - confirm the intent before relying on test.csv.
    if point_val < group_len - 1:
        test_df = pd.concat(dfs[point_val:], axis=0)
        df_dict.update({'test.csv': test_df})

    for tag, split_df in df_dict.items():
        save_path = osp.join(root_dir, tag)
        if os.path.exists(save_path):
            # Keep the previous file next to the new one before overwriting.
            bak_path = save_path + ".bak"
            shutil.move(save_path, bak_path)
            info(
                f"The original annotation file {tag} has been backed up to {bak_path}."
            )
        split_df.to_csv(save_path, index=False)
    return root_dir