config_reader.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
"""
Return the S3 (AK, SK, endpoint) triple that corresponds to a bucket name.
"""
  4. import json
  5. import os
  6. from loguru import logger
  7. from magic_pdf.libs.commons import parse_bucket_key
# Constant holding the configuration file name.
CONFIG_FILE_NAME = "magic-pdf.json"
  10. def read_config():
  11. home_dir = os.path.expanduser("~")
  12. config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
  13. if not os.path.exists(config_file):
  14. raise FileNotFoundError(f"{config_file} not found")
  15. with open(config_file, "r", encoding="utf-8") as f:
  16. config = json.load(f)
  17. return config
  18. def get_s3_config(bucket_name: str):
  19. """
  20. ~/magic-pdf.json 读出来
  21. """
  22. config = read_config()
  23. bucket_info = config.get("bucket_info")
  24. if bucket_name not in bucket_info:
  25. access_key, secret_key, storage_endpoint = bucket_info["[default]"]
  26. else:
  27. access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
  28. if access_key is None or secret_key is None or storage_endpoint is None:
  29. raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
  30. # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
  31. return access_key, secret_key, storage_endpoint
  32. def get_s3_config_dict(path: str):
  33. access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
  34. return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
  35. def get_bucket_name(path):
  36. bucket, key = parse_bucket_key(path)
  37. return bucket
  38. def get_local_models_dir():
  39. config = read_config()
  40. models_dir = config.get("models-dir")
  41. if models_dir is None:
  42. logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
  43. return "/tmp/models"
  44. else:
  45. return models_dir
  46. def get_device():
  47. config = read_config()
  48. device = config.get("device-mode")
  49. if device is None:
  50. logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
  51. return "cpu"
  52. else:
  53. return device
  54. def get_table_recog_config():
  55. config = read_config()
  56. table_config = config.get("table-config")
  57. if table_config is None:
  58. logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
  59. return json.loads('{"is_table_recog_enable": false, "max_time": 400}')
  60. else:
  61. return table_config
  62. if __name__ == "__main__":
  63. ak, sk, endpoint = get_s3_config("llm-raw")