config_reader.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import json
  3. import os
  4. from loguru import logger
  5. try:
  6. import torch
  7. import torch_npu
  8. except ImportError:
  9. pass
  10. # 定义配置文件名常量
  11. CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
  12. def read_config():
  13. if os.path.isabs(CONFIG_FILE_NAME):
  14. config_file = CONFIG_FILE_NAME
  15. else:
  16. home_dir = os.path.expanduser('~')
  17. config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
  18. if not os.path.exists(config_file):
  19. # logger.warning(f'{config_file} not found, using default configuration')
  20. return None
  21. else:
  22. with open(config_file, 'r', encoding='utf-8') as f:
  23. config = json.load(f)
  24. return config
  25. def get_s3_config(bucket_name: str):
  26. """~/magic-pdf.json 读出来."""
  27. config = read_config()
  28. bucket_info = config.get('bucket_info')
  29. if bucket_name not in bucket_info:
  30. access_key, secret_key, storage_endpoint = bucket_info['[default]']
  31. else:
  32. access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
  33. if access_key is None or secret_key is None or storage_endpoint is None:
  34. raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')
  35. # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
  36. return access_key, secret_key, storage_endpoint
  37. def get_s3_config_dict(path: str):
  38. access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
  39. return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
  40. def get_bucket_name(path):
  41. bucket, key = parse_bucket_key(path)
  42. return bucket
  43. def parse_bucket_key(s3_full_path: str):
  44. """
  45. 输入 s3://bucket/path/to/my/file.txt
  46. 输出 bucket, path/to/my/file.txt
  47. """
  48. s3_full_path = s3_full_path.strip()
  49. if s3_full_path.startswith("s3://"):
  50. s3_full_path = s3_full_path[5:]
  51. if s3_full_path.startswith("/"):
  52. s3_full_path = s3_full_path[1:]
  53. bucket, key = s3_full_path.split("/", 1)
  54. return bucket, key
  55. def get_device():
  56. device_mode = os.getenv('MINERU_DEVICE_MODE', None)
  57. if device_mode is not None:
  58. return device_mode
  59. else:
  60. if torch.cuda.is_available():
  61. return "cuda"
  62. elif torch.backends.mps.is_available():
  63. return "mps"
  64. else:
  65. try:
  66. if torch_npu.npu.is_available():
  67. return "npu"
  68. except Exception as e:
  69. pass
  70. return "cpu"
  71. def get_formula_enable(formula_enable):
  72. formula_enable_env = os.getenv('MINERU_FORMULA_ENABLE')
  73. formula_enable = formula_enable if formula_enable_env is None else formula_enable_env.lower() == 'true'
  74. return formula_enable
  75. def get_table_enable(table_enable):
  76. table_enable_env = os.getenv('MINERU_TABLE_ENABLE')
  77. table_enable = table_enable if table_enable_env is None else table_enable_env.lower() == 'true'
  78. return table_enable
  79. def get_latex_delimiter_config():
  80. config = read_config()
  81. if config is None:
  82. return None
  83. latex_delimiter_config = config.get('latex-delimiter-config', None)
  84. if latex_delimiter_config is None:
  85. # logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
  86. return None
  87. else:
  88. return latex_delimiter_config
  89. def get_llm_aided_config():
  90. config = read_config()
  91. if config is None:
  92. return None
  93. llm_aided_config = config.get('llm-aided-config', None)
  94. if llm_aided_config is None:
  95. # logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
  96. return None
  97. else:
  98. return llm_aided_config
  99. def get_local_models_dir():
  100. config = read_config()
  101. if config is None:
  102. return None
  103. models_dir = config.get('models-dir')
  104. if models_dir is None:
  105. logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use None as default")
  106. return models_dir