models_download.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import json
  2. import os
  3. import sys
  4. import click
  5. import requests
  6. from loguru import logger
  7. from mineru.utils.enum_class import ModelPath
  8. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  def download_json(url, timeout=30):
      """Download ``url`` and return its body decoded as JSON.

      Args:
          url: Address of the JSON document to fetch.
          timeout: Seconds passed to ``requests.get`` as its timeout (applies to
              connecting and to each read). Without a timeout the request can
              block indefinitely on an unresponsive server. Pass ``None`` to
              wait forever (the previous behaviour).

      Returns:
          The decoded JSON value from ``response.json()``.

      Raises:
          requests.HTTPError: if the server answers with a 4xx/5xx status
              (raised by ``raise_for_status``).
          requests.Timeout: if the server does not respond within ``timeout``.
      """
      response = requests.get(url, timeout=timeout)
      response.raise_for_status()
      return response.json()
  def _version_tuple(version):
      """Turn a dotted version string into a tuple of ints that compares numerically.

      Each dot-separated component contributes its leading decimal digits
      (``'0rc1'`` -> 0; a component with no leading digits -> 0). The tuple is
      right-padded with zeros to at least three components, so ``'1.3'`` and
      ``'1.3.0'`` compare equal. Non-string inputs are converted with ``str``.
      """
      numbers = []
      for piece in str(version).split('.'):
          digits = ''
          for ch in piece:
              if not ch.isdecimal():
                  break
              digits += ch
          numbers.append(int(digits) if digits else 0)
      numbers.extend([0] * (3 - len(numbers)))
      return tuple(numbers)


  def download_and_modify_json(url, local_filename, modifications):
      """Ensure a local JSON config exists, apply ``modifications`` and save it.

      If ``local_filename`` exists and its ``config_version`` (default
      ``'0.0.0'``) is at least 1.3.0, its contents are kept. Otherwise (the file
      is missing, or its version is older) the JSON at ``url`` is downloaded and
      used instead. Then, for each key in ``modifications`` that already exists
      in the data, a dict value is merged with ``dict.update`` and any other
      value is replaced. Keys absent from the data are ignored. The result is
      written back to ``local_filename`` as UTF-8 JSON indented by 4 spaces.

      Args:
          url: URL of the template JSON, fetched with ``download_json``.
          local_filename: Path of the config file to read and (over)write.
          modifications: Mapping of top-level keys to their new values.
      """
      needs_template = True
      if os.path.exists(local_filename):
          # Read with the encoding used for writing below (the file is written
          # with ensure_ascii=False, so the platform default may not decode it);
          # 'utf-8-sig' additionally tolerates a BOM added by some editors.
          # The context manager closes the handle, which json.load(open(...))
          # leaked.
          with open(local_filename, encoding='utf-8-sig') as f:
              data = json.load(f)
          config_version = data.get('config_version', '0.0.0')
          # Compare numerically: a plain string comparison would treat
          # '1.10.0' (and even '1.3') as older than '1.3.0'.
          needs_template = _version_tuple(config_version) < _version_tuple('1.3.0')
      if needs_template:
          data = download_json(url)

      # Apply the modifications.
      for key, value in modifications.items():
          if key in data:
              if isinstance(data[key], dict):
                  # Merge into an existing dict so sibling entries are preserved.
                  data[key].update(value)
              else:
                  # Otherwise replace the value outright.
                  data[key] = value

      # Save the modified content.
      with open(local_filename, 'w', encoding='utf-8') as f:
          json.dump(data, f, ensure_ascii=False, indent=4)
  def configure_model(model_dir, model_type):
      """Record ``model_dir`` as the models directory for ``model_type`` in the user config.

      The config file is placed in the user's home directory. Its file name comes
      from the ``MINERU_TOOLS_CONFIG_JSON`` environment variable (default
      ``mineru.json``). ``download_and_modify_json`` creates or refreshes the file
      from the MinerU template URL when needed, then merges
      ``{model_type: model_dir}`` into its ``models-dir`` section.
      """
      template_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/mineru.template.json'
      config_name = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
      config_path = os.path.join(os.path.expanduser('~'), config_name)

      models_dir_entry = {f'{model_type}': model_dir}
      download_and_modify_json(template_url, config_path, {'models-dir': models_dir_entry})

      logger.info(f'The configuration file has been successfully configured, the path is: {config_path}')
  def download_pipeline_models():
      """Download every model used by the pipeline backend and register its location.

      Each model is fetched with ``auto_download_and_get_model_root_path`` in
      ``repo_mode='pipeline'``. The path returned by the last call is logged and
      stored in the user config under ``models-dir['pipeline']``.
      """
      pipeline_models = (
          ModelPath.doclayout_yolo,
          ModelPath.yolo_v8_mfd,
          ModelPath.unimernet_small,
          ModelPath.pytorch_paddle,
          ModelPath.layout_reader,
          ModelPath.slanet_plus,
          ModelPath.unet_structure,
          ModelPath.paddle_table_cls,
          ModelPath.paddle_orientation_classification,
      )

      # NOTE(review): only the last returned path is kept; this presumably
      # relies on every call returning the same repository root — confirm.
      root_path = ""
      for model in pipeline_models:
          logger.info(f"Downloading model: {model}")
          root_path = auto_download_and_get_model_root_path(model, repo_mode='pipeline')

      logger.info(f"Pipeline models downloaded successfully to: {root_path}")
      configure_model(root_path, "pipeline")
  def download_vlm_models():
      """Download the VLM models and store their location under ``models-dir['vlm']``."""
      # "/" is passed as the model path in 'vlm' repo mode; presumably this
      # selects the whole repository — TODO confirm against the helper.
      vlm_root = auto_download_and_get_model_root_path("/", repo_mode='vlm')
      logger.info(f"VLM models downloaded successfully to: {vlm_root}")
      configure_model(vlm_root, "vlm")
  @click.command()
  @click.option(
      '-s',
      '--source',
      'model_source',
      type=click.Choice(['huggingface', 'modelscope']),
      help="""
      The source of the model repository.
      """,
      default=None,
  )
  @click.option(
      '-m',
      '--model_type',
      'model_type',
      type=click.Choice(['pipeline', 'vlm', 'all']),
      help="""
      The type of the model to download.
      """,
      default=None,
  )
  def download_models(model_source, model_type):
      """Download MinerU model files.
      Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
      """
      # (The docstring above doubles as the click --help text, so it is kept verbatim.)
      #
      # An already-set MINERU_MODEL_SOURCE environment variable takes precedence
      # over --source. Ask interactively only when neither is provided:
      # previously the prompt ran even when the variable was set, and its answer
      # was then silently discarded.
      env_source = os.getenv('MINERU_MODEL_SOURCE', None)
      if env_source is None:
          if model_source is None:
              model_source = click.prompt(
                  "Please select the model download source: ",
                  type=click.Choice(['huggingface', 'modelscope']),
                  default='huggingface'
              )
          os.environ['MINERU_MODEL_SOURCE'] = model_source
      elif model_source is not None and model_source != env_source:
          # Keep the existing precedence, but don't ignore an explicit flag silently.
          logger.warning(
              f"MINERU_MODEL_SOURCE is already set to '{env_source}'; "
              f"ignoring --source '{model_source}'."
          )

      # Ask for the model type interactively if it was not given on the command line.
      if model_type is None:
          model_type = click.prompt(
              "Please select the model type to download: ",
              type=click.Choice(['pipeline', 'vlm', 'all']),
              default='all'
          )

      logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
      try:
          if model_type == 'pipeline':
              download_pipeline_models()
          elif model_type == 'vlm':
              download_vlm_models()
          elif model_type == 'all':
              download_pipeline_models()
              download_vlm_models()
          else:
              # Defensive: click.Choice already restricts model_type to the values above.
              click.echo(f"Unsupported model type: {model_type}", err=True)
              sys.exit(1)
      except Exception as e:
          # Top-level CLI boundary: log the traceback and exit with a non-zero status.
          # (SystemExit from the branch above is not an Exception subclass and passes through.)
          logger.exception(f"An error occurred while downloading models: {str(e)}")
          sys.exit(1)
  if __name__ == '__main__':
      # Runs only when this file is executed directly, not when it is imported.
      # Calling the click command with no arguments makes it read the command line.
      download_models()