config_reader.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import json
  3. import os
  4. import torch
  5. from loguru import logger
  6. # 定义配置文件名常量
  7. CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
  8. def read_config():
  9. if os.path.isabs(CONFIG_FILE_NAME):
  10. config_file = CONFIG_FILE_NAME
  11. else:
  12. home_dir = os.path.expanduser('~')
  13. config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
  14. if not os.path.exists(config_file):
  15. logger.warning(f'{config_file} not found, using default configuration')
  16. return None
  17. else:
  18. with open(config_file, 'r', encoding='utf-8') as f:
  19. config = json.load(f)
  20. return config
  21. def get_s3_config(bucket_name: str):
  22. """~/magic-pdf.json 读出来."""
  23. config = read_config()
  24. bucket_info = config.get('bucket_info')
  25. if bucket_name not in bucket_info:
  26. access_key, secret_key, storage_endpoint = bucket_info['[default]']
  27. else:
  28. access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
  29. if access_key is None or secret_key is None or storage_endpoint is None:
  30. raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
  31. # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
  32. return access_key, secret_key, storage_endpoint
  33. def get_s3_config_dict(path: str):
  34. access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
  35. return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
  36. def get_bucket_name(path):
  37. bucket, key = parse_bucket_key(path)
  38. return bucket
  39. def parse_bucket_key(s3_full_path: str):
  40. """
  41. 输入 s3://bucket/path/to/my/file.txt
  42. 输出 bucket, path/to/my/file.txt
  43. """
  44. s3_full_path = s3_full_path.strip()
  45. if s3_full_path.startswith("s3://"):
  46. s3_full_path = s3_full_path[5:]
  47. if s3_full_path.startswith("/"):
  48. s3_full_path = s3_full_path[1:]
  49. bucket, key = s3_full_path.split("/", 1)
  50. return bucket, key
  51. def get_local_models_dir():
  52. config = read_config()
  53. models_dir = config.get('models-dir')
  54. if models_dir is None:
  55. logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
  56. return '/tmp/models'
  57. else:
  58. return models_dir
  59. def get_local_layoutreader_model_dir():
  60. config = read_config()
  61. layoutreader_model_dir = config.get('layoutreader-model-dir')
  62. if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
  63. home_dir = os.path.expanduser('~')
  64. layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
  65. logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
  66. return layoutreader_at_modelscope_dir_path
  67. else:
  68. return layoutreader_model_dir
  69. def get_device():
  70. device_mode = os.getenv('MINERU_DEVICE_MODE', None)
  71. if device_mode is not None:
  72. return device_mode
  73. else:
  74. if torch.cuda.is_available():
  75. return "cuda"
  76. if torch.backends.mps.is_available():
  77. return "mps"
  78. return "cpu"
  79. def get_table_recog_config():
  80. table_enable = os.getenv('MINERU_TABLE_ENABLE', None)
  81. if table_enable is not None:
  82. return json.loads(f'{{"enable": {table_enable}}}')
  83. else:
  84. logger.warning(f"not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.")
  85. return json.loads(f'{{"enable": true}}')
  86. def get_formula_config():
  87. formula_enable = os.getenv('MINERU_FORMULA_ENABLE', None)
  88. if formula_enable is not None:
  89. return json.loads(f'{{"enable": {formula_enable}}}')
  90. else:
  91. logger.warning(f"not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.")
  92. return json.loads(f'{{"enable": true}}')
  93. def get_latex_delimiter_config():
  94. config = read_config()
  95. latex_delimiter_config = config.get('latex-delimiter-config')
  96. if latex_delimiter_config is None:
  97. logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
  98. return None
  99. else:
  100. return latex_delimiter_config
  101. def get_llm_aided_config():
  102. config = read_config()
  103. llm_aided_config = config.get('llm-aided-config')
  104. if llm_aided_config is None:
  105. logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
  106. return None
  107. else:
  108. return llm_aided_config