config_reader.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
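"""
Read magic-pdf settings from ``~/magic-pdf.json``.

Primarily: return the S3 (access key, secret key, endpoint) triple configured
for a given bucket name; the module also exposes the model directories, device
mode and table/layout/formula model settings stored in the same file.
"""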
  4. import json
  5. import os
  6. from loguru import logger
  7. from magic_pdf.libs.Constants import MODEL_NAME
  8. from magic_pdf.libs.commons import parse_bucket_key
  9. # 定义配置文件名常量
  10. CONFIG_FILE_NAME = "magic-pdf.json"
  11. def read_config():
  12. home_dir = os.path.expanduser("~")
  13. config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
  14. if not os.path.exists(config_file):
  15. raise FileNotFoundError(f"{config_file} not found")
  16. with open(config_file, "r", encoding="utf-8") as f:
  17. config = json.load(f)
  18. return config
  19. def get_s3_config(bucket_name: str):
  20. """
  21. ~/magic-pdf.json 读出来
  22. """
  23. config = read_config()
  24. bucket_info = config.get("bucket_info")
  25. if bucket_name not in bucket_info:
  26. access_key, secret_key, storage_endpoint = bucket_info["[default]"]
  27. else:
  28. access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
  29. if access_key is None or secret_key is None or storage_endpoint is None:
  30. raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")
  31. # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
  32. return access_key, secret_key, storage_endpoint
  33. def get_s3_config_dict(path: str):
  34. access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
  35. return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
  36. def get_bucket_name(path):
  37. bucket, key = parse_bucket_key(path)
  38. return bucket
  39. def get_local_models_dir():
  40. config = read_config()
  41. models_dir = config.get("models-dir")
  42. if models_dir is None:
  43. logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
  44. return "/tmp/models"
  45. else:
  46. return models_dir
  47. def get_local_layoutreader_model_dir():
  48. config = read_config()
  49. layoutreader_model_dir = config.get("layoutreader-model-dir")
  50. if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
  51. home_dir = os.path.expanduser("~")
  52. layoutreader_at_modelscope_dir_path = os.path.join(home_dir, ".cache/modelscope/hub/ppaanngggg/layoutreader")
  53. logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
  54. return layoutreader_at_modelscope_dir_path
  55. else:
  56. return layoutreader_model_dir
  57. def get_device():
  58. config = read_config()
  59. device = config.get("device-mode")
  60. if device is None:
  61. logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
  62. return "cpu"
  63. else:
  64. return device
  65. def get_table_recog_config():
  66. config = read_config()
  67. table_config = config.get("table-config")
  68. if table_config is None:
  69. logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
  70. return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
  71. else:
  72. return table_config
  73. def get_layout_config():
  74. config = read_config()
  75. layout_config = config.get("layout-config")
  76. if layout_config is None:
  77. logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
  78. return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
  79. else:
  80. return layout_config
  81. def get_formula_config():
  82. config = read_config()
  83. formula_config = config.get("formula-config")
  84. if formula_config is None:
  85. logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
  86. return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
  87. else:
  88. return formula_config
  89. if __name__ == "__main__":
  90. ak, sk, endpoint = get_s3_config("llm-raw")