|
|
@@ -57,8 +57,12 @@ def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipelin
|
|
|
relative_path = relative_path.strip('/')
|
|
|
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
|
|
|
elif repo_mode == 'vlm':
|
|
|
- # VLM 模式下,直接下载整个模型目录
|
|
|
- cache_dir = snapshot_download(repo)
|
|
|
+ # VLM 模式下,根据 relative_path 的不同处理方式
|
|
|
+ if relative_path == "/":
|
|
|
+ cache_dir = snapshot_download(repo)
|
|
|
+ else:
|
|
|
+ relative_path = relative_path.strip('/')
|
|
|
+ cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
|
|
|
|
|
|
if not cache_dir:
|
|
|
raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}")
|