| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import os
- from huggingface_hub import snapshot_download as hf_snapshot_download
- from modelscope import snapshot_download as ms_snapshot_download
- from mineru.utils.enum_class import ModelPath
def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
    """
    Download a file or directory from the model repository and return its local path.

    The repository is selected from ``repo_mode`` and the ``MINERU_MODEL_SOURCE``
    environment variable ('huggingface' or 'modelscope'). When the variable is
    unset or empty, Hugging Face is used.

    Only ``relative_path`` itself and everything beneath it are downloaded, so the
    same call works for a single file and for a whole directory. In both cases the
    returned path is ``<snapshot cache dir>/<relative_path>``, i.e. it mirrors the
    repository layout inside the local snapshot cache.

    :param relative_path: file or directory path relative to the repository root;
        leading and trailing '/' are ignored.
    :param repo_mode: repository mode, 'pipeline' or 'vlm'.
    :return: local path of the downloaded file or directory.
    :raises ValueError: if ``repo_mode`` is not 'pipeline' / 'vlm', or if
        ``MINERU_MODEL_SOURCE`` is set to something other than
        'huggingface' / 'modelscope'.
    """
    # An unset or empty variable means "use the default source" (Hugging Face).
    # Without this fallback the source validation below would reject None.
    model_source = os.getenv('MINERU_MODEL_SOURCE') or 'huggingface'

    # Repository id for each (repo mode, model source) pair.
    repo_mapping = {
        'pipeline': {
            'huggingface': ModelPath.pipeline_root_hf,
            'modelscope': ModelPath.pipeline_root_modelscope,
        },
        'vlm': {
            'huggingface': ModelPath.vlm_root_hf,
            'modelscope': ModelPath.vlm_root_modelscope,
        },
    }
    if repo_mode not in repo_mapping:
        raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")

    # Download function for each supported model source.
    downloaders = {
        'huggingface': hf_snapshot_download,
        'modelscope': ms_snapshot_download,
    }
    if model_source not in downloaders:
        raise ValueError(f"未知的仓库类型: {model_source}")
    snapshot_download = downloaders[model_source]
    repo = repo_mapping[repo_mode][model_source]

    relative_path = relative_path.strip('/')
    # The first pattern matches the path itself (file case); the second matches
    # everything beneath it (directory case).
    cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path + "/*"])
    return os.path.join(cache_dir, relative_path)
if __name__ == '__main__':
    # Smoke test: fetch one single file and one whole directory, then print
    # where each ended up in the local snapshot cache.
    for sample_path in ("models/README.md", "models/OCR/paddleocr_torch/"):
        local_path = get_file_from_repos(sample_path)
        print("本地文件绝对路径:", local_path)
|