models_download_utils.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
import hashlib
import os
from typing import List, Optional, Tuple, Union

import requests
from huggingface_hub import hf_hub_download, model_info
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from mineru.utils.enum_class import ModelPath
  8. def _sha256sum(path, chunk_size=8192):
  9. h = hashlib.sha256()
  10. with open(path, "rb") as f:
  11. while True:
  12. chunk = f.read(chunk_size)
  13. if not chunk:
  14. break
  15. h.update(chunk)
  16. return h.hexdigest()
  17. def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> Union[str, str]:
  18. """
  19. 支持文件或目录的可靠下载。
  20. - 如果输入文件: 返回本地文件绝对路径
  21. - 如果输入目录: 返回本地缓存下与 relative_path 同结构的相对路径字符串
  22. :param repo_mode: 指定仓库模式,'pipeline' 或 'vlm'
  23. :param relative_path: 文件或目录相对路径
  24. :return: 本地文件绝对路径或相对路径
  25. """
  26. model_source = os.getenv('MINERU_MODEL_SOURCE', None)
  27. # 建立仓库模式到路径的映射
  28. repo_mapping = {
  29. 'pipeline': {
  30. 'huggingface': ModelPath.pipeline_root_hf,
  31. 'modelscope': ModelPath.pipeline_root_modelscope,
  32. 'default': ModelPath.pipeline_root_hf
  33. },
  34. 'vlm': {
  35. 'huggingface': ModelPath.vlm_root_hf,
  36. 'modelscope': ModelPath.vlm_root_modelscope,
  37. 'default': ModelPath.vlm_root_hf
  38. }
  39. }
  40. if repo_mode not in repo_mapping:
  41. raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")
  42. # 如果没有指定model_source或值不是'modelscope',则使用默认值
  43. repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])
  44. input_clean = relative_path.strip('/')
  45. # 获取huggingface云端仓库文件树
  46. try:
  47. # 获取仓库信息,包含文件元数据
  48. info = model_info(repo, files_metadata=True)
  49. # 构建文件字典
  50. siblings_dict = {f.rfilename: f for f in info.siblings}
  51. except Exception as e:
  52. siblings_dict = {}
  53. print(f"[Warn] 获取 Huggingface 仓库结构失败,错误: {e}")
  54. # 1. 文件还是目录拓展
  55. if input_clean in siblings_dict and not siblings_dict[input_clean].rfilename.endswith("/"):
  56. is_file = True
  57. all_paths = [input_clean]
  58. else:
  59. is_file = False
  60. all_paths = [k for k in siblings_dict if k.startswith(input_clean + "/") and not k.endswith("/")]
  61. # 若获取不到siblings(如 Huggingface 失败,直接按输入处理)
  62. if not all_paths:
  63. is_file = os.path.splitext(input_clean)[1] != ""
  64. all_paths = [input_clean] if is_file else []
  65. cache_home = str(HUGGINGFACE_HUB_CACHE)
  66. # 判断主逻辑
  67. output_files = []
  68. # ---- Huggingface 分支 ----
  69. hf_ok = False
  70. for relpath in all_paths:
  71. ok = False
  72. if relpath in siblings_dict:
  73. meta = siblings_dict[relpath]
  74. sha256 = ""
  75. if meta.lfs:
  76. sha256 = meta.lfs.sha256
  77. try:
  78. # 不允许下载线上文件,只寻找本地文件
  79. file_path = hf_hub_download(repo_id=repo, filename=relpath, local_files_only=True)
  80. if sha256 and os.path.exists(file_path):
  81. if _sha256sum(file_path) == sha256:
  82. ok = True
  83. output_files.append(file_path)
  84. except Exception as e:
  85. print(f"[Info] Huggingface {relpath} 获取失败: {e}")
  86. if not hf_ok:
  87. file_path = hf_hub_download(repo_id=repo, filename=relpath, force_download=False)
  88. print("file_path = ", file_path)
  89. if sha256 and _sha256sum(file_path) != sha256:
  90. raise ValueError(f"Huggingface下载后校验失败: {relpath}")
  91. ok = True
  92. output_files.append(file_path)
  93. hf_ok = hf_ok and ok
  94. # ---- ModelScope 分支 ----
  95. for relpath in all_paths:
  96. if hf_ok:
  97. break
  98. if "/" in repo:
  99. org_name, model_name = repo.split("/", 1)
  100. else:
  101. org_name, model_name = "modelscope", repo
  102. # 目录结构: 缓存/home/modelscope-fallback/org/model/相对路径
  103. target_dir = os.path.join(cache_home, "modelscope-fallback", org_name, model_name, os.path.dirname(relpath))
  104. os.makedirs(target_dir, exist_ok=True)
  105. local_path = os.path.join(target_dir, os.path.basename(relpath))
  106. remote_len = 0
  107. sha256 = ""
  108. try:
  109. get_meta_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo/raw?Revision=master&FilePath={relpath}&Needmeta=true"
  110. resp = requests.get(get_meta_url, timeout=15)
  111. if resp.ok:
  112. remote_len = resp.json()["Data"]["MetaContent"]["Size"]
  113. sha256 = resp.json()["Data"]["MetaContent"]["Sha256"]
  114. except Exception as e:
  115. print(f"[Info] modelscope {relpath} 获取失败: {e}")
  116. ok_local = False
  117. if remote_len > 0 and os.path.exists(local_path):
  118. if sha256 == _sha256sum(local_path):
  119. output_files.append(local_path)
  120. ok_local = True
  121. if not ok_local:
  122. try:
  123. modelscope_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo?Revision=master&FilePath={relpath}"
  124. with requests.get(modelscope_url, stream=True, timeout=30) as resp:
  125. resp.raise_for_status()
  126. with open(local_path, 'wb') as f:
  127. for chunk in resp.iter_content(1024*1024):
  128. if chunk:
  129. f.write(chunk)
  130. if remote_len == 0 or os.path.getsize(local_path) == remote_len:
  131. output_files.append(local_path)
  132. ok_local = True
  133. except Exception as e:
  134. print(f"[Error] ModelScope下载失败: {relpath} {e}")
  135. if not output_files:
  136. raise FileNotFoundError(f"{relative_path} 在 Huggingface 和 ModelScope 都未能获取")
  137. if is_file:
  138. return output_files[0]
  139. else:
  140. # 输入是文件,只返回路径字符串
  141. return os.path.dirname(os.path.abspath(output_files[0]))
  142. if __name__ == '__main__':
  143. path1 = get_file_from_repos("models/README.md")
  144. print("本地文件绝对路径:", path1)
  145. path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
  146. print("本地文件绝对路径:", path2)