models_download_utils.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import os
  2. from huggingface_hub import snapshot_download as hf_snapshot_download
  3. from modelscope import snapshot_download as ms_snapshot_download
  4. from mineru.utils.enum_class import ModelPath
  5. def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
  6. """
  7. 支持文件或目录的可靠下载。
  8. - 如果输入文件: 返回本地文件绝对路径
  9. - 如果输入目录: 返回本地缓存下与 relative_path 同结构的相对路径字符串
  10. :param repo_mode: 指定仓库模式,'pipeline' 或 'vlm'
  11. :param relative_path: 文件或目录相对路径
  12. :return: 本地文件绝对路径或相对路径
  13. """
  14. model_source = os.getenv('MINERU_MODEL_SOURCE', None)
  15. # 建立仓库模式到路径的映射
  16. repo_mapping = {
  17. 'pipeline': {
  18. 'huggingface': ModelPath.pipeline_root_hf,
  19. 'modelscope': ModelPath.pipeline_root_modelscope,
  20. 'default': ModelPath.pipeline_root_hf
  21. },
  22. 'vlm': {
  23. 'huggingface': ModelPath.vlm_root_hf,
  24. 'modelscope': ModelPath.vlm_root_modelscope,
  25. 'default': ModelPath.vlm_root_hf
  26. }
  27. }
  28. if repo_mode not in repo_mapping:
  29. raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")
  30. # 如果没有指定model_source或值不是'modelscope',则使用默认值
  31. repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])
  32. if model_source == "huggingface":
  33. snapshot_download = hf_snapshot_download
  34. elif model_source == "modelscope":
  35. snapshot_download = ms_snapshot_download
  36. else:
  37. raise ValueError(f"未知的仓库类型: {model_source}")
  38. relative_path = relative_path.strip('/')
  39. cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
  40. return cache_dir + "/" + relative_path
  41. if __name__ == '__main__':
  42. path1 = get_file_from_repos("models/README.md")
  43. print("本地文件绝对路径:", path1)
  44. path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
  45. print("本地文件绝对路径:", path2)