config_reader.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. """
  2. 根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
  3. """
  4. import json
  5. import os
  6. from loguru import logger
  7. from magic_pdf.libs.commons import parse_bucket_key
  8. # 定义配置文件名常量
  9. CONFIG_FILE_NAME = "magic-pdf.json"
  def read_config():
      """Load the JSON configuration file from the user's home directory.

      The file is named by CONFIG_FILE_NAME and is looked up directly under
      the directory that ``os.path.expanduser("~")`` resolves to.

      Returns:
          The object produced by ``json.load`` for the file's content.

      Raises:
          FileNotFoundError: The configuration file does not exist.
      """
      config_path = os.path.join(os.path.expanduser("~"), CONFIG_FILE_NAME)

      if os.path.exists(config_path):
          with open(config_path, "r", encoding="utf-8") as fp:
              return json.load(fp)

      raise FileNotFoundError(f"{config_path} not found")
  def get_s3_config(bucket_name: str):
      """Return the (access key, secret key, endpoint) triple for a bucket.

      The values come from the "bucket_info" section of the configuration
      returned by read_config() (i.e. ~/magic-pdf.json). A bucket that has no
      entry of its own falls back to the "[default]" entry.

      Args:
          bucket_name: Name of the bucket to look up.

      Returns:
          A tuple ``(access_key, secret_key, storage_endpoint)``.

      Raises:
          FileNotFoundError: Propagated from read_config() when the
              configuration file is missing.
          KeyError: Neither ``bucket_name`` nor "[default]" is present in
              "bucket_info".
          Exception: "bucket_info" is missing from the configuration, or one
              of the three values is None.
      """
      config = read_config()
      bucket_info = config.get("bucket_info")
      if bucket_info is None:
          # dict.get() yields None for a missing key; without this guard the
          # membership test below fails with an opaque TypeError.
          raise Exception(f"'bucket_info' not found in {CONFIG_FILE_NAME}")

      entry_name = bucket_name if bucket_name in bucket_info else "[default]"
      if entry_name not in bucket_info:
          # Same exception type as the bare lookup would raise, but with a
          # message that tells the user what to add to the config file.
          raise KeyError(
              f"neither '{bucket_name}' nor '[default]' found in "
              f"'bucket_info' of {CONFIG_FILE_NAME}"
          )

      access_key, secret_key, storage_endpoint = bucket_info[entry_name]
      if access_key is None or secret_key is None or storage_endpoint is None:
          raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")

      return access_key, secret_key, storage_endpoint
  def get_s3_config_dict(path: str):
      """Return the S3 settings for the bucket referenced by ``path`` as a dict.

      The bucket name is obtained with get_bucket_name() and resolved through
      get_s3_config().

      Args:
          path: Path from which the bucket name is extracted.

      Returns:
          A dict with the keys "ak", "sk" and "endpoint".
      """
      bucket = get_bucket_name(path)
      ak, sk, endpoint = get_s3_config(bucket)
      return dict(ak=ak, sk=sk, endpoint=endpoint)
  def get_bucket_name(path):
      """Return the bucket part of ``path`` as split by parse_bucket_key().

      Args:
          path: Path handed unchanged to parse_bucket_key().

      Returns:
          The first element of the (bucket, key) pair from parse_bucket_key().
      """
      # The key half of the pair is not needed here.
      bucket_name, _key = parse_bucket_key(path)
      return bucket_name
  def get_local_models_dir():
      """Return the "models-dir" value from the configuration.

      Returns:
          The configured value, or "/tmp/models" (after logging a warning)
          when the configuration has no "models-dir" entry or it is None.
      """
      configured_dir = read_config().get("models-dir")
      if configured_dir is not None:
          return configured_dir

      logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
      return "/tmp/models"
  def get_device():
      """Return the "device-mode" value from the configuration.

      Returns:
          The configured value, or "cpu" (after logging a warning) when the
          configuration has no "device-mode" entry or it is None.
      """
      device_mode = read_config().get("device-mode")
      if device_mode is not None:
          return device_mode

      logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
      return "cpu"
  def get_table_recog_config():
      """Return the "table-config" section of the configuration.

      Returns:
          The value stored under "table-config", or None when the key is
          absent.
      """
      return read_config().get("table-config")
  if __name__ == "__main__":
      # Ad-hoc check when run as a script: resolve the S3 settings for the
      # "llm-raw" bucket. get_s3_config() raises if they cannot be resolved.
      ak, sk, endpoint = get_s3_config(bucket_name="llm-raw")