|
|
@@ -51,9 +51,17 @@ def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipelin
|
|
|
else:
|
|
|
raise ValueError(f"未知的仓库类型: {model_source}")
|
|
|
|
|
|
- relative_path = relative_path.strip('/')
|
|
|
- cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
|
|
|
+ cache_dir = None
|
|
|
|
|
|
+ if repo_mode == 'pipeline':
|
|
|
+ relative_path = relative_path.strip('/')
|
|
|
+ cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
|
|
|
+ elif repo_mode == 'vlm':
|
|
|
+ # VLM 模式下,直接下载整个模型目录
|
|
|
+ cache_dir = snapshot_download(repo)
|
|
|
+
|
|
|
+ if not cache_dir:
|
|
|
+ raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}")
|
|
|
return cache_dir
|
|
|
|
|
|
|