| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import os.path as osp
- import shutil
- import random
- import math
- import pandas as pd
- from tqdm import tqdm
- from .....utils.logging import info
def split_dataset(root_dir, train_rate, val_rate, group_id='group_id'):
    """Re-partition CSV time-series annotations in ``root_dir`` by group.

    Reads ``train.csv`` and ``val.csv`` from ``root_dir`` (whichever exist),
    merges them while renaming group ids that collide across files,
    de-duplicates rows, then reassigns whole groups to train/val splits
    according to ``train_rate`` / ``val_rate``. Existing output files are
    backed up with a ``.bak`` suffix before being overwritten.

    Args:
        root_dir (str): Directory containing the CSV annotation files.
        train_rate (int): Percentage of groups for the train split (> 0).
        val_rate (int): Percentage of groups for the val split (> 0);
            ``train_rate + val_rate`` must equal 100.
        group_id (str): Name of the column that identifies each series/group.

    Returns:
        str: ``root_dir``, unchanged, for caller convenience.
    """
    assert train_rate + val_rate == 100, \
        f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
    assert train_rate > 0 and val_rate > 0, \
        f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
    tags = ['train.csv', 'val.csv']
    df = pd.DataFrame()
    group_unique = None
    for tag in tags:
        csv_path = osp.join(root_dir, tag)
        if os.path.exists(csv_path):
            df_one = pd.read_csv(csv_path)
            cols = df_one.columns.values.tolist()
            assert group_id in cols, f"The default group_id '{group_id}' is not found in the df columns."
            if df.empty:
                df = df_one
                group_unique = sorted(df[group_id].unique())
            else:
                # Rename group ids that collide with ids already seen so that
                # groups coming from different files remain distinct after concat.
                group_unique_one = sorted(df_one[group_id].unique())
                for gid in group_unique_one:  # `gid`, not `id`: avoid shadowing the builtin
                    if gid in group_unique:
                        new_gid = str(gid) + "_"
                        # Assign back instead of chained `inplace=True`, which is
                        # deprecated in pandas 2.x and may not modify df_one at all.
                        df_one[group_id] = df_one[group_id].replace(gid, new_gid)
                        group_unique.append(new_gid)
                df = pd.concat([df, df_one], axis=0)
    df = df.drop_duplicates(keep='first')
    group_unique = df[group_id].unique()
    dfs = []  # separate sub-frame per group
    for column in group_unique:
        df_one = df[df[group_id].isin([column])]
        # Keep only one row per timestamp within each group.
        df_one = df_one.drop_duplicates(subset=['time'], keep='first')
        dfs.append(df_one)
    group_len = len(dfs)
    point_train = math.floor(group_len * train_rate / 100)
    point_val = math.floor(group_len * (train_rate + val_rate) / 100)
    assert point_train > 0, \
        f"The number of training groups is 0; train_rate({train_rate}) should be greater."
    assert point_val - point_train > 0, \
        f"The number of validation groups is 0; val_rate({val_rate}) should be greater."
    train_df = pd.concat(dfs[:point_train], axis=0)
    val_df = pd.concat(dfs[point_train:point_val], axis=0)
    df_dict = {'train.csv': train_df, 'val.csv': val_df}
    # NOTE: was `point_val < group_len - 1`, which would drop the last group's
    # rows entirely. With train_rate + val_rate == 100 this branch is currently
    # unreachable (point_val == group_len), but the bound is now correct should
    # the rate constraint ever be relaxed.
    if point_val < group_len:
        test_df = pd.concat(dfs[point_val:], axis=0)
        df_dict.update({'test.csv': test_df})
    for tag in df_dict.keys():
        save_path = osp.join(root_dir, tag)
        if os.path.exists(save_path):
            # Back up any existing annotation file before overwriting it.
            bak_path = save_path + ".bak"
            shutil.move(save_path, bak_path)
            info(
                f"The original annotation file {tag} has been backed up to {bak_path}."
            )
        df_dict[tag].to_csv(save_path, index=False)
    return root_dir
|