file_interface.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. from threading import ThreadError
  17. from filelock import FileLock
  18. import yaml
  19. import ruamel.yaml
  20. import chardet
  21. try:
  22. import ujson as json
  23. except:
  24. logging.error("failed to import ujson, using json instead")
  25. import json
  26. from contextlib import contextmanager
  27. @contextmanager
  28. def custom_open(file_path, mode):
  29. """
  30. 自定义打开文件函数
  31. Args:
  32. file_path (str): 文件路径
  33. mode (str): 文件打开模式,'r','w' 或 'a'
  34. Returns:
  35. Any: 返回文件对象
  36. Raises:
  37. FileNotFoundError: 当文件不存在时,raise FileNotFoundError
  38. ValueError: 当 mode 参数不是 'r', 'w' 和 'a' 时,raise ValueError
  39. """
  40. if mode == 'r':
  41. if not os.path.exists(file_path):
  42. raise FileNotFoundError("file {} not found".format(file_path))
  43. file = open(file_path, "r", encoding='utf-8')
  44. try:
  45. file.read()
  46. file.seek(0)
  47. yield file
  48. except UnicodeDecodeError:
  49. file = open(file_path, "r", encoding='gbk')
  50. try:
  51. file.read()
  52. file.seek(0)
  53. yield file
  54. except UnicodeDecodeError:
  55. with open(file_path, 'rb') as f:
  56. encoding = chardet.detect(f.read())['encoding']
  57. file = open(file_path, "r", encoding=encoding)
  58. yield file
  59. finally:
  60. file.close()
  61. elif mode == 'w':
  62. file = open(file_path, "w", encoding='utf-8')
  63. yield file
  64. file.close()
  65. elif mode == 'a':
  66. encoding = 'utf-8'
  67. if os.path.exists(file_path):
  68. file = open(file_path, "r", encoding=encoding)
  69. try:
  70. file.read()
  71. file.seek(0)
  72. except UnicodeDecodeError:
  73. encoding = 'gbk'
  74. file = open(file_path, "r", encoding=encoding)
  75. try:
  76. file.read()
  77. file.seek(0)
  78. except UnicodeDecodeError:
  79. with open(file_path, 'rb') as f:
  80. encoding = chardet.detect(f.read())['encoding']
  81. finally:
  82. file.close()
  83. file = open(file_path, "a", encoding=encoding)
  84. yield file
  85. file.close()
  86. else:
  87. raise ValueError("mode must be 'r', 'w' or 'a', but got {}".format(
  88. mode))
  89. # --------------- yaml ---------------
  90. def read_yaml_file(yaml_path: str, to_dict=True):
  91. """read from yaml file"""
  92. try:
  93. with open(yaml_path, "r", encoding='utf-8') as file:
  94. yaml_content = yaml.full_load(file)
  95. except UnicodeDecodeError:
  96. with open(yaml_path, "r", encoding='gbk') as file:
  97. yaml_content = yaml.full_load(file)
  98. yaml_content = dict(yaml_content) if to_dict else yaml_content
  99. return yaml_content
  100. def write_config_file(yaml_dict: dict, yaml_path: str):
  101. """write to config yaml file"""
  102. yaml = ruamel.yaml.YAML()
  103. lock = FileLock(yaml_path + '.lock')
  104. with lock:
  105. with open(yaml_path, "w", encoding='utf-8') as file:
  106. # yaml.safe_dump(yaml_dict, file, sort_keys=False)
  107. yaml.dump(yaml_dict, file)
  108. def update_yaml_file_with_dict(yaml_path, key_values: dict):
  109. """ update yaml file with key_values
  110. key_values is a dict
  111. """
  112. yaml_dict = read_yaml_file(yaml_path)
  113. yaml_dict.update(key_values)
  114. write_config_file(yaml_dict, yaml_path)
  115. def get_yaml_keys(yaml_path):
  116. """get all keys of yaml file"""
  117. yaml_dict = read_yaml_file(yaml_path)
  118. return yaml_dict.keys()
  119. # --------------- markdown ---------------
  120. def generate_markdown_from_dict(metrics):
  121. """ generate_markdown_from_dict """
  122. mk = ""
  123. keys = metrics.keys()
  124. mk += "| ".join(keys())
  125. mk += os.linesep
  126. mk += "|".join([' :----: '])
  127. # ------------------- jsonl ---------------------
  128. def read_jsonl_file(jsonl_path: str):
  129. """read from jsonl file"""
  130. with custom_open(jsonl_path, "r") as file:
  131. jsonl_content = [json.loads(line) for line in file]
  132. return jsonl_content
  133. def write_json_file(content, jsonl_path: str, ensure_ascii=False, **kwargs):
  134. """write to json file"""
  135. with custom_open(jsonl_path, "w") as file:
  136. json.dump(content, file, ensure_ascii=ensure_ascii, **kwargs)
  137. # --------------- check webui yaml -----------------
  138. def check_dict_keys(to_checked_dict, standard_dict, escape_list=None):
  139. """check if all keys of to_checked_dict is the same as standard_dict, and the value is the same type
  140. Args:
  141. escape_list: if set, will not check the keys in white_list
  142. """
  143. escape_list = [] if escape_list is None else escape_list
  144. for key in standard_dict.keys():
  145. if key not in to_checked_dict:
  146. logging.error(f"key {key} not in yaml file")
  147. return False
  148. if not isinstance(standard_dict[key], type(to_checked_dict[key])):
  149. logging.error(
  150. f"value type of key {key} is not the same as standard: "
  151. f"{type(standard_dict[key])}, {type(to_checked_dict[key])}")
  152. return False
  153. if isinstance(standard_dict[key], dict) and isinstance(
  154. to_checked_dict[key], dict) and key not in escape_list:
  155. return check_dict_keys(to_checked_dict[key], standard_dict[key],
  156. escape_list)
  157. if len(to_checked_dict.keys()) != len(standard_dict.keys()):
  158. logging.error(f"yaml file has extra keys")
  159. return False
  160. return True
  161. def check_dataset_valid(path_list):
  162. """check if dataset valid in path_list for datset_ui"""
  163. if path_list is not None and len(path_list) > 0:
  164. for path in path_list:
  165. if not os.path.exists(path):
  166. return False
  167. return True
  168. else:
  169. return False