import os

from huggingface_hub import snapshot_download as hf_snapshot_download
from modelscope import snapshot_download as ms_snapshot_download

from mineru.utils.config_reader import get_local_models_dir
from mineru.utils.enum_class import ModelPath


def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
    """Resolve a model file or directory to a local path, downloading it if needed.

    The model source is selected by the ``MINERU_MODEL_SOURCE`` environment
    variable (case-insensitive, surrounding whitespace ignored):

    - ``'local'``: nothing is downloaded; ``relative_path`` is joined onto the
      root directory configured for ``repo_mode`` by ``get_local_models_dir()``.
    - ``'huggingface'`` (also used when the variable is unset or empty) or
      ``'modelscope'``: the matching ``snapshot_download`` is called for the
      repository of ``repo_mode``, restricted to entries matching
      ``relative_path`` or ``relative_path + "/*"``. The location of
      ``relative_path`` inside the returned snapshot directory is returned.

    :param relative_path: File or directory path relative to the repository
        root; leading/trailing ``/`` characters are ignored.
    :param repo_mode: Repository mode, ``'pipeline'`` or ``'vlm'``.
    :return: Local path of the requested file or directory.
    :raises ValueError: If no local root is configured for ``repo_mode`` in
        local mode, if ``repo_mode`` is unsupported, or if
        ``MINERU_MODEL_SOURCE`` names an unknown source.
    """
    # An unset or empty variable falls back to Hugging Face. Previously it
    # raised "unknown source: None", contradicting the intended default.
    model_source = (os.getenv('MINERU_MODEL_SOURCE') or '').strip().lower() or 'huggingface'

    if model_source == 'local':
        local_models_config = get_local_models_dir()
        root_path = local_models_config.get(repo_mode, None)
        if not root_path:
            raise ValueError(f"Local path for repo_mode '{repo_mode}' is not configured.")
        return os.path.join(root_path, relative_path.strip('/'))

    # Repository id for each (repo_mode, model_source) pair.
    repo_mapping = {
        'pipeline': {
            'huggingface': ModelPath.pipeline_root_hf,
            'modelscope': ModelPath.pipeline_root_modelscope,
        },
        'vlm': {
            'huggingface': ModelPath.vlm_root_hf,
            'modelscope': ModelPath.vlm_root_modelscope,
        },
    }
    # Download function for each supported remote source.
    downloaders = {
        'huggingface': hf_snapshot_download,
        'modelscope': ms_snapshot_download,
    }

    if repo_mode not in repo_mapping:
        raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")
    if model_source not in downloaders:
        raise ValueError(f"未知的仓库类型: {model_source}")

    repo = repo_mapping[repo_mode][model_source]
    snapshot_download = downloaders[model_source]

    relative_path = relative_path.strip('/')
    # Match the path itself (a file) and everything beneath it (a directory),
    # so only the requested part of the repository is fetched.
    cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path + "/*"])
    return os.path.join(cache_dir, relative_path)


if __name__ == '__main__':
    # Demo: resolve one file and one directory from the configured source.
    path1 = get_file_from_repos("models/README.md")
    print("本地文件绝对路径:", path1)
    path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
    print("本地文件绝对路径:", path2)