|
@@ -2,6 +2,7 @@ import os
|
|
|
from huggingface_hub import snapshot_download as hf_snapshot_download
|
|
from huggingface_hub import snapshot_download as hf_snapshot_download
|
|
|
from modelscope import snapshot_download as ms_snapshot_download
|
|
from modelscope import snapshot_download as ms_snapshot_download
|
|
|
|
|
|
|
|
|
|
+from mineru.utils.config_reader import get_local_models_dir
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
from mineru.utils.enum_class import ModelPath
|
|
|
|
|
|
|
|
def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
|
|
def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
|
|
@@ -15,6 +16,13 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
|
|
|
"""
|
|
"""
|
|
|
model_source = os.getenv('MINERU_MODEL_SOURCE', None)
|
|
model_source = os.getenv('MINERU_MODEL_SOURCE', None)
|
|
|
|
|
|
|
|
|
|
+ if model_source == 'local':
|
|
|
|
|
+ local_models_config = get_local_models_dir()
|
|
|
|
|
+ root_path = local_models_config.get(repo_mode, None)
|
|
|
|
|
+ if not root_path:
|
|
|
|
|
+ raise ValueError(f"Local path for repo_mode '{repo_mode}' is not configured.")
|
|
|
|
|
+ return os.path.join(root_path, relative_path.strip('/'))
|
|
|
|
|
+
|
|
|
# 建立仓库模式到路径的映射
|
|
# 建立仓库模式到路径的映射
|
|
|
repo_mapping = {
|
|
repo_mapping = {
|
|
|
'pipeline': {
|
|
'pipeline': {
|