split.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .dataset_split import *
  15. from paddlex.utils import logging
  16. def dataset_split(dataset_dir, dataset_format, val_value, test_value,
  17. save_dir):
  18. logging.info("Dataset split starts...")
  19. if dataset_format == "coco":
  20. train_num, val_num, test_num = split_coco_dataset(
  21. dataset_dir, val_value, test_value, save_dir)
  22. elif dataset_format == "voc":
  23. train_num, val_num, test_num = split_voc_dataset(
  24. dataset_dir, val_value, test_value, save_dir)
  25. elif dataset_format == "seg":
  26. train_num, val_num, test_num = split_seg_dataset(
  27. dataset_dir, val_value, test_value, save_dir)
  28. elif dataset_format == "imagenet":
  29. train_num, val_num, test_num = split_imagenet_dataset(
  30. dataset_dir, val_value, test_value, save_dir)
  31. else:
  32. raise Exception("Dataset format {} is not supported.".format(
  33. dataset_format))
  34. logging.info("Dataset split done.")
  35. logging.info("Train samples: {}".format(train_num))
  36. logging.info("Eval samples: {}".format(val_num))
  37. logging.info("Test samples: {}".format(test_num))
  38. logging.info("Split files saved in {}".format(save_dir))