split_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import os.path as osp
import shutil

import pandas as pd

from .....utils.logging import info
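

# `root_dir` is expected to contain `train.csv` and/or `val.csv`, each with a
# group column (default name 'group_id') and a 'time' column; rows are merged,
# de-duplicated, and re-split by group rather than kept in their source files.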
def split_dataset(root_dir, train_rate, val_rate, group_id='group_id'):
    """Split the dataset under root_dir into train/val subsets, group by group."""
    assert train_rate + val_rate == 100, \
        f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
    assert train_rate > 0 and val_rate > 0, \
        f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
    tags = ['train.csv', 'val.csv']
    df = pd.DataFrame()
    group_unique = None
    for tag in tags:
        if os.path.exists(osp.join(root_dir, tag)):
            df_one = pd.read_csv(osp.join(root_dir, tag))
            cols = df_one.columns.values.tolist()
            assert group_id in cols, f"The default group_id '{group_id}' is not found in the df columns."
            if df.empty:
                df = df_one
                group_unique = sorted(df[group_id].unique())
            else:
                # Rename group ids that collide with ones already collected, so
                # groups from different files stay distinct after merging.
                group_unique_one = sorted(df_one[group_id].unique())
                for gid in group_unique_one:
                    if gid in group_unique:
                        df_one[group_id] = df_one[group_id].replace(gid, str(gid) + "_")
                        group_unique.append(str(gid) + "_")
                df = pd.concat([df, df_one], axis=0)
    df = df.drop_duplicates(keep='first')
    group_unique = df[group_id].unique()
    dfs = []  # separate the merged frame into one sub-frame per group
    for group in group_unique:
        df_one = df[df[group_id].isin([group])]
        df_one = df_one.drop_duplicates(subset=['time'], keep='first')
        dfs.append(df_one)
    group_len = len(dfs)
    point_train = math.floor(group_len * train_rate / 100)
    point_val = math.floor(group_len * (train_rate + val_rate) / 100)
    assert point_train > 0, \
        f"The train split is empty; train_rate({train_rate}) should be greater."
    assert point_val - point_train > 0, \
        f"The val split is empty; val_rate({val_rate}) should be greater."
    train_df = pd.concat(dfs[:point_train], axis=0)
    val_df = pd.concat(dfs[point_train:point_val], axis=0)
    df_dict = {'train.csv': train_df, 'val.csv': val_df}
    if point_val < group_len - 1:
        test_df = pd.concat(dfs[point_val:], axis=0)
        df_dict.update({'test.csv': test_df})
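    # Worked example (hypothetical numbers): with 10 groups, train_rate=80 and
    # val_rate=20, point_train = floor(10 * 80 / 100) = 8 and
    # point_val = floor(10 * 100 / 100) = 10, so 8 groups go to train.csv, 2 to
    # val.csv, and no test.csv is written (with the rates summing to 100,
    # point_val always equals group_len, so the test branch above is inert).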
    for tag in df_dict.keys():
        save_path = osp.join(root_dir, tag)
        if os.path.exists(save_path):
            bak_path = save_path + ".bak"
            shutil.move(save_path, bak_path)
            info(
                f"The original annotation file {tag} has been backed up to {bak_path}."
            )
        df_dict[tag].to_csv(save_path, index=False)
    return root_dir
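

# Example usage (an illustrative sketch; the dataset path and rates below are
# hypothetical):
#
#     split_dataset("datasets/my_ts_dataset", train_rate=80, val_rate=20)
#
# This rewrites train.csv and val.csv under the dataset root, backing up any
# existing copies with a ".bak" suffix, and returns the root directory.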