config_reader.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. """根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
  2. import json
  3. import os
  4. from loguru import logger
  5. from magic_pdf.config.constants import MODEL_NAME
  6. from magic_pdf.libs.commons import parse_bucket_key
  7. # 定义配置文件名常量
  8. CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
  9. def read_config():
  10. if os.path.isabs(CONFIG_FILE_NAME):
  11. config_file = CONFIG_FILE_NAME
  12. else:
  13. home_dir = os.path.expanduser('~')
  14. config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
  15. if not os.path.exists(config_file):
  16. raise FileNotFoundError(f'{config_file} not found')
  17. with open(config_file, 'r', encoding='utf-8') as f:
  18. config = json.load(f)
  19. return config
  20. def get_s3_config(bucket_name: str):
  21. """~/magic-pdf.json 读出来."""
  22. config = read_config()
  23. bucket_info = config.get('bucket_info')
  24. if bucket_name not in bucket_info:
  25. access_key, secret_key, storage_endpoint = bucket_info['[default]']
  26. else:
  27. access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
  28. if access_key is None or secret_key is None or storage_endpoint is None:
  29. raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
  30. # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
  31. return access_key, secret_key, storage_endpoint
  32. def get_s3_config_dict(path: str):
  33. access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
  34. return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
  35. def get_bucket_name(path):
  36. bucket, key = parse_bucket_key(path)
  37. return bucket
  38. def get_local_models_dir():
  39. config = read_config()
  40. models_dir = config.get('models-dir')
  41. if models_dir is None:
  42. logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
  43. return '/tmp/models'
  44. else:
  45. return models_dir
  46. def get_local_layoutreader_model_dir():
  47. config = read_config()
  48. layoutreader_model_dir = config.get('layoutreader-model-dir')
  49. if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
  50. home_dir = os.path.expanduser('~')
  51. layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
  52. logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
  53. return layoutreader_at_modelscope_dir_path
  54. else:
  55. return layoutreader_model_dir
  56. def get_device():
  57. config = read_config()
  58. device = config.get('device-mode')
  59. if device is None:
  60. logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
  61. return 'cpu'
  62. else:
  63. return device
  64. def get_table_recog_config():
  65. config = read_config()
  66. table_config = config.get('table-config')
  67. if table_config is None:
  68. logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
  69. return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
  70. else:
  71. return table_config
  72. def get_layout_config():
  73. config = read_config()
  74. layout_config = config.get('layout-config')
  75. if layout_config is None:
  76. logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
  77. return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
  78. else:
  79. return layout_config
  80. def get_formula_config():
  81. config = read_config()
  82. formula_config = config.get('formula-config')
  83. if formula_config is None:
  84. logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
  85. return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
  86. else:
  87. return formula_config
  88. def get_llm_aided_config():
  89. config = read_config()
  90. llm_aided_config = config.get('llm-aided-config')
  91. if llm_aided_config is None:
  92. logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
  93. return None
  94. else:
  95. return llm_aided_config
if __name__ == '__main__':
    # Manual check when run as a script: resolve the S3 credentials configured
    # for the 'llm-raw' bucket. The returned values are not used further here.
    ak, sk, endpoint = get_s3_config('llm-raw')