file_interface.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. from threading import ThreadError
  17. from filelock import FileLock
  18. import yaml
  19. import ruamel.yaml
  20. import chardet
  21. try:
  22. import ujson as json
  23. except:
  24. logging.warning("failed to import ujson, using json instead")
  25. import json
  26. from contextlib import contextmanager
  27. @contextmanager
  28. def custom_open(file_path, mode):
  29. """
  30. 自定义打开文件函数
  31. Args:
  32. file_path (str): 文件路径
  33. mode (str): 文件打开模式,'r','w' 或 'a'
  34. Returns:
  35. Any: 返回文件对象
  36. Raises:
  37. FileNotFoundError: 当文件不存在时,raise FileNotFoundError
  38. ValueError: 当 mode 参数不是 'r', 'w' 和 'a' 时,raise ValueError
  39. """
  40. if mode == "r":
  41. if not os.path.exists(file_path):
  42. raise FileNotFoundError("file {} not found".format(file_path))
  43. file = open(file_path, "r", encoding="utf-8")
  44. try:
  45. file.read()
  46. file.seek(0)
  47. yield file
  48. except UnicodeDecodeError:
  49. file = open(file_path, "r", encoding="gbk")
  50. try:
  51. file.read()
  52. file.seek(0)
  53. yield file
  54. except UnicodeDecodeError:
  55. with open(file_path, "rb") as f:
  56. encoding = chardet.detect(f.read())["encoding"]
  57. file = open(file_path, "r", encoding=encoding)
  58. yield file
  59. finally:
  60. file.close()
  61. elif mode == "w":
  62. file = open(file_path, "w", encoding="utf-8")
  63. yield file
  64. file.close()
  65. elif mode == "a":
  66. encoding = "utf-8"
  67. if os.path.exists(file_path):
  68. file = open(file_path, "r", encoding=encoding)
  69. try:
  70. file.read()
  71. file.seek(0)
  72. except UnicodeDecodeError:
  73. encoding = "gbk"
  74. file = open(file_path, "r", encoding=encoding)
  75. try:
  76. file.read()
  77. file.seek(0)
  78. except UnicodeDecodeError:
  79. with open(file_path, "rb") as f:
  80. encoding = chardet.detect(f.read())["encoding"]
  81. finally:
  82. file.close()
  83. file = open(file_path, "a", encoding=encoding)
  84. yield file
  85. file.close()
  86. else:
  87. raise ValueError("mode must be 'r', 'w' or 'a', but got {}".format(mode))
  88. # --------------- yaml ---------------
  89. def read_yaml_file(yaml_path: str, to_dict=True):
  90. """read from yaml file"""
  91. try:
  92. with open(yaml_path, "r", encoding="utf-8") as file:
  93. yaml_content = yaml.full_load(file)
  94. except UnicodeDecodeError:
  95. with open(yaml_path, "r", encoding="gbk") as file:
  96. yaml_content = yaml.full_load(file)
  97. yaml_content = dict(yaml_content) if to_dict else yaml_content
  98. return yaml_content
  99. def write_config_file(yaml_dict: dict, yaml_path: str):
  100. """write to config yaml file"""
  101. yaml = ruamel.yaml.YAML()
  102. lock = FileLock(yaml_path + ".lock")
  103. with lock:
  104. with open(yaml_path, "w", encoding="utf-8") as file:
  105. # yaml.safe_dump(yaml_dict, file, sort_keys=False)
  106. yaml.dump(yaml_dict, file)
  107. def update_yaml_file_with_dict(yaml_path, key_values: dict):
  108. """update yaml file with key_values
  109. key_values is a dict
  110. """
  111. yaml_dict = read_yaml_file(yaml_path)
  112. yaml_dict.update(key_values)
  113. write_config_file(yaml_dict, yaml_path)
  114. def get_yaml_keys(yaml_path):
  115. """get all keys of yaml file"""
  116. yaml_dict = read_yaml_file(yaml_path)
  117. return yaml_dict.keys()
  118. # --------------- markdown ---------------
  119. def generate_markdown_from_dict(metrics):
  120. """generate_markdown_from_dict"""
  121. mk = ""
  122. keys = metrics.keys()
  123. mk += "| ".join(keys())
  124. mk += os.linesep
  125. mk += "|".join([" :----: "])
  126. # ------------------- jsonl ---------------------
  127. def read_jsonl_file(jsonl_path: str):
  128. """read from jsonl file"""
  129. with custom_open(jsonl_path, "r") as file:
  130. jsonl_content = [json.loads(line) for line in file]
  131. return jsonl_content
  132. def write_json_file(content, jsonl_path: str, ensure_ascii=False, **kwargs):
  133. """write to json file"""
  134. with custom_open(jsonl_path, "w") as file:
  135. json.dump(content, file, ensure_ascii=ensure_ascii, **kwargs)
  136. # --------------- check webui yaml -----------------
  137. def check_dict_keys(to_checked_dict, standard_dict, escape_list=None):
  138. """check if all keys of to_checked_dict is the same as standard_dict, and the value is the same type
  139. Args:
  140. escape_list: if set, will not check the keys in white_list
  141. """
  142. escape_list = [] if escape_list is None else escape_list
  143. for key in standard_dict.keys():
  144. if key not in to_checked_dict:
  145. logging.error(f"key {key} not in yaml file")
  146. return False
  147. if not isinstance(standard_dict[key], type(to_checked_dict[key])):
  148. logging.error(
  149. f"value type of key {key} is not the same as standard: "
  150. f"{type(standard_dict[key])}, {type(to_checked_dict[key])}"
  151. )
  152. return False
  153. if (
  154. isinstance(standard_dict[key], dict)
  155. and isinstance(to_checked_dict[key], dict)
  156. and key not in escape_list
  157. ):
  158. return check_dict_keys(
  159. to_checked_dict[key], standard_dict[key], escape_list
  160. )
  161. if len(to_checked_dict.keys()) != len(standard_dict.keys()):
  162. logging.error(f"yaml file has extra keys")
  163. return False
  164. return True
  165. def check_dataset_valid(path_list):
  166. """check if dataset valid in path_list for datset_ui"""
  167. if path_list is not None and len(path_list) > 0:
  168. for path in path_list:
  169. if not os.path.exists(path):
  170. return False
  171. return True
  172. else:
  173. return False