models_download.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import json
  2. import os
  3. import sys
  4. import click
  5. import requests
  6. from mineru.utils.enum_class import ModelPath
  7. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  8. def download_json(url):
  9. """下载JSON文件"""
  10. response = requests.get(url)
  11. response.raise_for_status()
  12. return response.json()
  13. def download_and_modify_json(url, local_filename, modifications):
  14. """下载JSON并修改内容"""
  15. if os.path.exists(local_filename):
  16. data = json.load(open(local_filename))
  17. config_version = data.get('config_version', '0.0.0')
  18. if config_version < '1.3.0':
  19. data = download_json(url)
  20. else:
  21. data = download_json(url)
  22. # 修改内容
  23. for key, value in modifications.items():
  24. if key in data:
  25. if isinstance(data[key], dict):
  26. # 如果是字典,合并新值
  27. data[key].update(value)
  28. else:
  29. # 否则直接替换
  30. data[key] = value
  31. # 保存修改后的内容
  32. with open(local_filename, 'w', encoding='utf-8') as f:
  33. json.dump(data, f, ensure_ascii=False, indent=4)
  34. def configure_model(model_dir, model_type):
  35. """配置模型"""
  36. json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/mineru.template.json'
  37. config_file_name = 'mineru.json'
  38. home_dir = os.path.expanduser('~')
  39. config_file = os.path.join(home_dir, config_file_name)
  40. json_mods = {
  41. 'models-dir': {
  42. f'{model_type}': model_dir
  43. }
  44. }
  45. download_and_modify_json(json_url, config_file, json_mods)
  46. print(f'The configuration file has been successfully configured, the path is: {config_file}')
  47. @click.command()
  48. def download_models():
  49. """下载MinerU模型文件。
  50. 支持从ModelScope或HuggingFace下载pipeline或VLM模型。
  51. """
  52. # 交互式输入下载来源
  53. source = click.prompt(
  54. "Please select the model download source: ",
  55. type=click.Choice(['huggingface', 'modelscope']),
  56. default='huggingface'
  57. )
  58. os.environ['MINERU_MODEL_SOURCE'] = source
  59. # 交互式输入模型类型
  60. model_type = click.prompt(
  61. "Please select the model type to download: ",
  62. type=click.Choice(['pipeline', 'vlm']),
  63. default='pipeline'
  64. )
  65. click.echo(f"Downloading {model_type} model from {source}...")
  66. try:
  67. download_finish_path = ""
  68. if model_type == 'pipeline':
  69. for model_path in [ModelPath.doclayout_yolo, ModelPath.yolo_v8_mfd, ModelPath.unimernet_small, ModelPath.pytorch_paddle, ModelPath.layout_reader, ModelPath.slanet_plus]:
  70. click.echo(f"Downloading model: {model_path}")
  71. download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode=model_type)
  72. elif model_type == 'vlm':
  73. download_finish_path = auto_download_and_get_model_root_path("/", repo_mode=model_type)
  74. click.echo(f"Models downloaded successfully to: {download_finish_path}")
  75. configure_model(download_finish_path, model_type)
  76. except Exception as e:
  77. click.echo(f"Download failed: {str(e)}", err=True)
  78. sys.exit(1)
  79. if __name__ == '__main__':
  80. download_models()