瀏覽代碼

feat: modify get_file_from_repos

Yuefeng Sun 5 月之前
父節點
當前提交
7e20339241
共有 1 個文件被更改,包括 15 次插入113 次删除
  1. 15 113
      mineru/utils/models_download_utils.py

+ 15 - 113
mineru/utils/models_download_utils.py

@@ -1,23 +1,10 @@
 import os
-import hashlib
-import requests
-from typing import List, Union
-from huggingface_hub import hf_hub_download, model_info
-from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub import snapshot_download as hf_snapshot_download
+from modelscope import snapshot_download as ms_snapshot_download
 
 from mineru.utils.enum_class import ModelPath
 
-
-def _sha256sum(path, chunk_size=8192):
-    h = hashlib.sha256()
-    with open(path, "rb") as f:
-        while True:
-            chunk = f.read(chunk_size)
-            if not chunk:
-                break
-            h.update(chunk)
-    return h.hexdigest()
-def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> Union[str, str]:
+def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
     """
     支持文件或目录的可靠下载。
     - 如果输入文件: 返回本地文件绝对路径
@@ -48,104 +35,19 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> Union[str,
     # 如果没有指定model_source或值不是'modelscope',则使用默认值
     repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])
 
-    input_clean = relative_path.strip('/')
-    # 获取huggingface云端仓库文件树
-    try:
-        # 获取仓库信息,包含文件元数据
-        info = model_info(repo, files_metadata=True)
-        # 构建文件字典
-        siblings_dict = {f.rfilename: f for f in info.siblings}
-    except Exception as e:
-        siblings_dict = {}
-        print(f"[Warn] 获取 Huggingface 仓库结构失败,错误: {e}")
-    # 1. 文件还是目录拓展
-    if input_clean in siblings_dict and not siblings_dict[input_clean].rfilename.endswith("/"):
-        is_file = True
-        all_paths = [input_clean]
-    else:
-        is_file = False
-        all_paths = [k for k in siblings_dict if k.startswith(input_clean + "/") and not k.endswith("/")]
-    # 若获取不到siblings(如 Huggingface 失败,直接按输入处理)
-    if not all_paths:
-        is_file = os.path.splitext(input_clean)[1] != ""
-        all_paths = [input_clean] if is_file else []
-    cache_home = str(HUGGINGFACE_HUB_CACHE)
-    # 判断主逻辑
-    output_files = []
-    # ---- Huggingface 分支 ----
-    hf_ok = False
-    for relpath in all_paths:
-        ok = False
-        if relpath in siblings_dict:
-            meta = siblings_dict[relpath]
-            sha256 = ""
-            if meta.lfs:
-                sha256 = meta.lfs.sha256
-            try:
-                # 不允许下载线上文件,只寻找本地文件
-                file_path = hf_hub_download(repo_id=repo, filename=relpath, local_files_only=True)
-                if sha256 and os.path.exists(file_path):
-                    if _sha256sum(file_path) == sha256:
-                        ok = True
-                        output_files.append(file_path)
-            except Exception as e:
-                print(f"[Info] Huggingface {relpath} 获取失败: {e}")
-            if not hf_ok:
-                file_path = hf_hub_download(repo_id=repo, filename=relpath, force_download=False)
-                print("file_path = ", file_path)
-                if sha256 and _sha256sum(file_path) != sha256:
-                    raise ValueError(f"Huggingface下载后校验失败: {relpath}")
-                ok = True
-                output_files.append(file_path)
-            hf_ok = hf_ok and ok
-    # ---- ModelScope 分支 ----
-    for relpath in all_paths:
-        if hf_ok:
-            break
-        if "/" in repo:
-            org_name, model_name = repo.split("/", 1)
-        else:
-            org_name, model_name = "modelscope", repo
-        # 目录结构: 缓存/home/modelscope-fallback/org/model/相对路径
-        target_dir = os.path.join(cache_home, "modelscope-fallback", org_name, model_name, os.path.dirname(relpath))
-        os.makedirs(target_dir, exist_ok=True)
-        local_path = os.path.join(target_dir, os.path.basename(relpath))
-        remote_len = 0
-        sha256 = ""
-        try:
-            get_meta_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo/raw?Revision=master&FilePath={relpath}&Needmeta=true"
-            resp = requests.get(get_meta_url, timeout=15)
-            if resp.ok:
-                remote_len = resp.json()["Data"]["MetaContent"]["Size"]
-                sha256 = resp.json()["Data"]["MetaContent"]["Sha256"]
-        except Exception as e:
-            print(f"[Info] modelscope {relpath} 获取失败: {e}")
-        ok_local = False
-        if remote_len > 0 and os.path.exists(local_path):
-            if sha256 == _sha256sum(local_path):
-                output_files.append(local_path)
-                ok_local = True
-        if not ok_local:
-            try:
-                modelscope_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo?Revision=master&FilePath={relpath}"
-                with requests.get(modelscope_url, stream=True, timeout=30) as resp:
-                    resp.raise_for_status()
-                    with open(local_path, 'wb') as f:
-                        for chunk in resp.iter_content(1024*1024):
-                            if chunk:
-                                f.write(chunk)
-                if remote_len == 0 or os.path.getsize(local_path) == remote_len:
-                    output_files.append(local_path)
-                    ok_local = True
-            except Exception as e:
-                print(f"[Error] ModelScope下载失败: {relpath} {e}")
-    if not output_files:
-        raise FileNotFoundError(f"{relative_path} 在 Huggingface 和 ModelScope 都未能获取")
-    if is_file:
-        return output_files[0]
+
+    if model_source == "huggingface":
+        snapshot_download = hf_snapshot_download
+    elif model_source == "modelscope":
+        snapshot_download = ms_snapshot_download
     else:
-        # 输入是文件,只返回路径字符串
-        return os.path.dirname(os.path.abspath(output_files[0]))
+        raise ValueError(f"未知的仓库类型: {model_source}")
+
+    relative_path = relative_path.strip('/')
+    cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
+
+    return cache_dir + "/" + relative_path
+
 if __name__ == '__main__':
     path1 = get_file_from_repos("models/README.md")
     print("本地文件绝对路径:", path1)