models_download_utils.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. from huggingface_hub import snapshot_download as hf_snapshot_download
  3. from modelscope import snapshot_download as ms_snapshot_download
  4. from mineru.utils.config_reader import get_local_models_dir
  5. from mineru.utils.enum_class import ModelPath
  6. def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipeline') -> str:
  7. """
  8. 支持文件或目录的可靠下载。
  9. - 如果输入文件: 返回本地文件绝对路径
  10. - 如果输入目录: 返回本地缓存下与 relative_path 同结构的相对路径字符串
  11. :param repo_mode: 指定仓库模式,'pipeline' 或 'vlm'
  12. :param relative_path: 文件或目录相对路径
  13. :return: 本地文件绝对路径或相对路径
  14. """
  15. model_source = os.getenv('MINERU_MODEL_SOURCE', "huggingface")
  16. if model_source == 'local':
  17. local_models_config = get_local_models_dir()
  18. root_path = local_models_config.get(repo_mode, None)
  19. if not root_path:
  20. raise ValueError(f"Local path for repo_mode '{repo_mode}' is not configured.")
  21. return root_path
  22. # 建立仓库模式到路径的映射
  23. repo_mapping = {
  24. 'pipeline': {
  25. 'huggingface': ModelPath.pipeline_root_hf,
  26. 'modelscope': ModelPath.pipeline_root_modelscope,
  27. 'default': ModelPath.pipeline_root_hf
  28. },
  29. 'vlm': {
  30. 'huggingface': ModelPath.vlm_root_hf,
  31. 'modelscope': ModelPath.vlm_root_modelscope,
  32. 'default': ModelPath.vlm_root_hf
  33. }
  34. }
  35. if repo_mode not in repo_mapping:
  36. raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")
  37. # 如果没有指定model_source或值不是'modelscope',则使用默认值
  38. repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])
  39. if model_source == "huggingface":
  40. snapshot_download = hf_snapshot_download
  41. elif model_source == "modelscope":
  42. snapshot_download = ms_snapshot_download
  43. else:
  44. raise ValueError(f"未知的仓库类型: {model_source}")
  45. cache_dir = None
  46. if repo_mode == 'pipeline':
  47. relative_path = relative_path.strip('/')
  48. cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
  49. elif repo_mode == 'vlm':
  50. # VLM 模式下,直接下载整个模型目录
  51. cache_dir = snapshot_download(repo)
  52. if not cache_dir:
  53. raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}")
  54. return cache_dir
  55. if __name__ == '__main__':
  56. path1 = "models/README.md"
  57. root = auto_download_and_get_model_root_path(path1)
  58. print("本地文件绝对路径:", os.path.join(root, path1))