split_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import os.path as osp
import shutil

import pandas as pd

from .....utils.logging import info


def split_dataset(root_dir, train_rate, val_rate, group_id="group_id"):
    """Split the time-series dataset under `root_dir` into train/val (and
    optionally test) CSV files, keeping whole groups together."""
    assert (
        train_rate + val_rate == 100
    ), f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
    assert (
        train_rate > 0 and val_rate > 0
    ), f"Both train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
    # Merge any existing train.csv/val.csv into one DataFrame, renaming group
    # ids from the second file that collide with ids from the first.
    tags = ["train.csv", "val.csv"]
    df = pd.DataFrame()
    group_unique = None
    for tag in tags:
        if os.path.exists(osp.join(root_dir, tag)):
            df_one = pd.read_csv(osp.join(root_dir, tag))
            cols = df_one.columns.values.tolist()
            assert (
                group_id in cols
            ), f"The group_id column '{group_id}' is not found in the columns of {tag}."
            if df.empty:
                df = df_one
                group_unique = sorted(df[group_id].unique())
            else:
                group_unique_one = sorted(df_one[group_id].unique())
                for gid in group_unique_one:
                    if gid in group_unique:
                        df_one[group_id].replace(gid, str(gid) + "_", inplace=True)
                        group_unique.append(str(gid) + "_")
                df = pd.concat([df, df_one], axis=0)
    df = df.drop_duplicates(keep="first")
    # Separate the rows into one DataFrame per group, de-duplicated by time.
    group_unique = df[group_id].unique()
    dfs = []
    for column in group_unique:
        df_one = df[df[group_id].isin([column])]
        df_one = df_one.drop_duplicates(subset=["time"], keep="first")
        dfs.append(df_one)
    # Assign whole groups to the splits according to the requested percentages.
    group_len = len(dfs)
    point_train = math.floor(group_len * train_rate / 100)
    point_val = math.floor(group_len * (train_rate + val_rate) / 100)
    assert (
        point_train > 0
    ), f"The train split is empty; train_rate({train_rate}) should be greater!"
    assert (
        point_val - point_train > 0
    ), f"The val split is empty; val_rate({val_rate}) should be greater!"
    train_df = pd.concat(dfs[:point_train], axis=0)
    val_df = pd.concat(dfs[point_train:point_val], axis=0)
    df_dict = {"train.csv": train_df, "val.csv": val_df}
    if point_val < group_len - 1:
        test_df = pd.concat(dfs[point_val:], axis=0)
        df_dict.update({"test.csv": test_df})
    # Write each split to disk, backing up any file that would be overwritten.
    for tag in df_dict:
        save_path = osp.join(root_dir, tag)
        if os.path.exists(save_path):
            bak_path = save_path + ".bak"
            shutil.move(save_path, bak_path)
            info(
                f"The original annotation file {tag} has been backed up to {bak_path}."
            )
        df_dict[tag].to_csv(save_path, index=False)
    return root_dir
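

# Usage sketch: a minimal example of calling split_dataset. The import path is
# hypothetical and depends on where this module sits in its package; root_dir
# is assumed to contain a train.csv (and optionally a val.csv) with "time" and
# "group_id" columns.
#
#     from dataset_src.split_dataset import split_dataset  # hypothetical path
#
#     # Writes train.csv/val.csv back into root_dir with an 80/20 group split,
#     # backing up the original files as *.bak.
#     split_dataset("path/to/dataset", train_rate=80, val_rate=20)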