file_interface.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import logging
  12. import os
  13. from threading import ThreadError
  14. from filelock import FileLock
  15. import yaml
  16. import ruamel.yaml
  17. import chardet
  18. try:
  19. import ujson as json
  20. except:
  21. logging.error("failed to import ujson, using json instead")
  22. import json
  23. from contextlib import contextmanager
  24. @contextmanager
  25. def custom_open(file_path, mode):
  26. """
  27. 自定义打开文件函数
  28. Args:
  29. file_path (str): 文件路径
  30. mode (str): 文件打开模式,'r','w' 或 'a'
  31. Returns:
  32. Any: 返回文件对象
  33. Raises:
  34. FileNotFoundError: 当文件不存在时,raise FileNotFoundError
  35. ValueError: 当 mode 参数不是 'r', 'w' 和 'a' 时,raise ValueError
  36. """
  37. if mode == 'r':
  38. if not os.path.exists(file_path):
  39. raise FileNotFoundError("file {} not found".format(file_path))
  40. file = open(file_path, "r", encoding='utf-8')
  41. try:
  42. file.read()
  43. file.seek(0)
  44. yield file
  45. except UnicodeDecodeError:
  46. file = open(file_path, "r", encoding='gbk')
  47. try:
  48. file.read()
  49. file.seek(0)
  50. yield file
  51. except UnicodeDecodeError:
  52. with open(file_path, 'rb') as f:
  53. encoding = chardet.detect(f.read())['encoding']
  54. file = open(file_path, "r", encoding=encoding)
  55. yield file
  56. finally:
  57. file.close()
  58. elif mode == 'w':
  59. file = open(file_path, "w", encoding='utf-8')
  60. yield file
  61. file.close()
  62. elif mode == 'a':
  63. encoding = 'utf-8'
  64. if os.path.exists(file_path):
  65. file = open(file_path, "r", encoding=encoding)
  66. try:
  67. file.read()
  68. file.seek(0)
  69. except UnicodeDecodeError:
  70. encoding = 'gbk'
  71. file = open(file_path, "r", encoding=encoding)
  72. try:
  73. file.read()
  74. file.seek(0)
  75. except UnicodeDecodeError:
  76. with open(file_path, 'rb') as f:
  77. encoding = chardet.detect(f.read())['encoding']
  78. finally:
  79. file.close()
  80. file = open(file_path, "a", encoding=encoding)
  81. yield file
  82. file.close()
  83. else:
  84. raise ValueError("mode must be 'r', 'w' or 'a', but got {}".format(
  85. mode))
  86. def read_yaml_file(yaml_path: str, to_dict=True):
  87. """read from yaml file"""
  88. try:
  89. with open(yaml_path, "r", encoding='utf-8') as file:
  90. yaml_content = yaml.full_load(file)
  91. except UnicodeDecodeError:
  92. with open(yaml_path, "r", encoding='gbk') as file:
  93. yaml_content = yaml.full_load(file)
  94. yaml_content = dict(yaml_content) if to_dict else yaml_content
  95. return yaml_content
  96. def write_config_file(yaml_dict: dict, yaml_path: str):
  97. """write to config yaml file"""
  98. yaml = ruamel.yaml.YAML()
  99. lock = FileLock(yaml_path + '.lock')
  100. with lock:
  101. with open(yaml_path, "w", encoding='utf-8') as file:
  102. # yaml.safe_dump(yaml_dict, file, sort_keys=False)
  103. yaml.dump(yaml_dict, file)
  104. def update_yaml_file_with_dict(yaml_path, key_values: dict):
  105. """ update yaml file with key_values
  106. key_values is a dict
  107. """
  108. yaml_dict = read_yaml_file(yaml_path)
  109. yaml_dict.update(key_values)
  110. write_config_file(yaml_dict, yaml_path)
  111. def get_yaml_keys(yaml_path):
  112. """get all keys of yaml file"""
  113. yaml_dict = read_yaml_file(yaml_path)
  114. return yaml_dict.keys()
  115. # --------------- markdown ---------------
  116. def generate_markdown_from_dict(metrics):
  117. """ generate_markdown_from_dict """
  118. mk = ""
  119. keys = metrics.keys()
  120. mk += "| ".join(keys())
  121. mk += os.linesep
  122. mk += "|".join([' :----: '])
  123. # ------------------- jsonl ---------------------
  124. def read_jsonl_file(jsonl_path: str):
  125. """read from jsonl file"""
  126. with custom_open(jsonl_path, "r") as file:
  127. jsonl_content = [json.loads(line) for line in file]
  128. return jsonl_content
  129. # --------------- check webui yaml -----------------
  130. def check_dict_keys(to_checked_dict, standard_dict, escape_list=None):
  131. """check if all keys of to_checked_dict is the same as standard_dict, and the value is the same type
  132. Args:
  133. escape_list: if set, will not check the keys in white_list
  134. """
  135. escape_list = [] if escape_list is None else escape_list
  136. for key in standard_dict.keys():
  137. if key not in to_checked_dict:
  138. logging.error(f"key {key} not in yaml file")
  139. return False
  140. if not isinstance(standard_dict[key], type(to_checked_dict[key])):
  141. logging.error(
  142. f"value type of key {key} is not the same as standard: "
  143. f"{type(standard_dict[key])}, {type(to_checked_dict[key])}")
  144. return False
  145. if isinstance(standard_dict[key], dict) and isinstance(
  146. to_checked_dict[key], dict) and key not in escape_list:
  147. return check_dict_keys(to_checked_dict[key], standard_dict[key],
  148. escape_list)
  149. if len(to_checked_dict.keys()) != len(standard_dict.keys()):
  150. logging.error(f"yaml file has extra keys")
  151. return False
  152. return True
  153. def check_dataset_valid(path_list):
  154. """check if dataset valid in path_list for datset_ui"""
  155. if path_list is not None and len(path_list) > 0:
  156. for path in path_list:
  157. if not os.path.exists(path):
  158. return False
  159. return True
  160. else:
  161. return False