convert_dataset.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pandas as pd
  16. from .....utils.errors import ConvertFailedError
  17. def check_src_dataset(root_dir):
  18. """check src dataset format validity"""
  19. err_msg_prefix = f"数据格式转换失败!当前仅支持后续为'.xlsx/.xls'格式的数据转换。"
  20. for dst_anno, src_anno in [("train.xlsx", "train.xls"), ("val.xlsx", "val.xls")]:
  21. src_anno_path = os.path.join(root_dir, src_anno)
  22. dst_anno_path = os.path.join(root_dir, dst_anno)
  23. if not os.path.exists(src_anno_path) and not os.path.exists(dst_anno_path):
  24. if "train" in dst_anno:
  25. raise ConvertFailedError(
  26. message=f"{err_msg_prefix}保证{src_anno_path}或{dst_anno_path}文件存在。"
  27. )
  28. continue
  29. def convert_excel_dataset(input_dir):
  30. """
  31. 将excel标注的数据集转换为PaddleX需要的格式
  32. Args:
  33. input_dir (str): 输入的目录,包含多个json格式的Labelme标注文件
  34. Returns:
  35. str: 返回一个字符串表示转换的结果,“转换成功”表示转换没有问题。
  36. Raises:
  37. 该函数目前没有特定的异常抛出。
  38. """
  39. # read excel file
  40. for dst_anno, src_anno in [("train.xlsx", "train.xls"), ("val.xlsx", "val.xls")]:
  41. src_anno_path = os.path.join(input_dir, src_anno)
  42. dst_anno_path = os.path.join(input_dir, dst_anno)
  43. if os.path.exists(src_anno_path):
  44. excel_file = pd.read_excel(src_anno_path)
  45. output_csv_dir = os.path.join(input_dir, src_anno.replace(".xlsx", ".csv"))
  46. excel_file.to_csv(output_csv_dir, index=False)
  47. if os.path.exists(dst_anno_path):
  48. excel_file = pd.read_excel(dst_anno_path)
  49. output_csv_dir = os.path.join(input_dir, dst_anno.replace(".xlsx", ".csv"))
  50. excel_file.to_csv(output_csv_dir, index=False)
  51. def convert(input_dir):
  52. """convert dataset to coco format"""
  53. # check format validity
  54. check_src_dataset(input_dir)
  55. convert_excel_dataset(input_dir)
  56. return input_dir