download.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import logging
import os
import os.path as osp
import shutil
import tarfile
import time
import zipfile
from typing import Optional

import requests
import tqdm

from .utils import hub_env as hubenv
from .utils.hub_model_server import model_server

DOWNLOAD_RETRY_LIMIT = 3

def md5check(fullname, md5sum=None):
    """Return True if fullname matches md5sum (or if md5sum is None)."""
    if md5sum is None:
        return True

    logging.info("File {} md5 checking...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logging.info(
            "File {} md5 check failed, {}(calc) != "
            "{}(base)".format(fullname, calc_md5sum, md5sum)
        )
        return False
    return True
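
# Hedged usage sketch for md5check(): the path and digest below are
# illustrative placeholders, not real artifacts.
#
#     if not md5check("/tmp/models/model.tgz",
#                     md5sum="0123456789abcdef0123456789abcdef"):
#         logging.warning("md5 mismatch; re-download the file")
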
def move_and_merge_tree(src, dst):
    """
    Move the src directory to dst. If dst already exists,
    merge src into dst.
    """
    if not osp.exists(dst):
        shutil.move(src, dst)
    else:
        if not osp.isdir(src):
            shutil.move(src, dst)
            return
        for fp in os.listdir(src):
            src_fp = osp.join(src, fp)
            dst_fp = osp.join(dst, fp)
            if osp.isdir(src_fp):
                if osp.isdir(dst_fp):
                    move_and_merge_tree(src_fp, dst_fp)
                else:
                    shutil.move(src_fp, dst_fp)
            elif osp.isfile(src_fp) and not osp.isfile(dst_fp):
                shutil.move(src_fp, dst_fp)
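
# Hedged sketch for move_and_merge_tree(): both paths are hypothetical.
# Files already present under dst are kept; only missing entries move over.
#
#     move_and_merge_tree("/tmp/extract_tmp/model_dir", "/opt/models/model_dir")
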
def download(url, path, rename=None, md5sum=None, show_progress=False):
    """
    Download from url and save to path.

    Args:
        url (str): download url
        path (str): directory to save the file to
        rename (str): optional new filename for the downloaded file
        md5sum (str): optional md5 checksum to verify the file against
        show_progress (bool): whether to display a tqdm progress bar
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    if rename is not None:
        fullname = osp.join(path, rename)
    retry_cnt = 0
    while not (osp.exists(fullname) and md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            logging.debug("{} download failed.".format(fname))
            raise RuntimeError(
                "Download from {} failed. Retry limit reached".format(url)
            )

        logging.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise RuntimeError(
                "Downloading from {} failed with code "
                "{}!".format(url, req.status_code)
            )

        # To guard against interrupted downloads, write to tmp_fullname
        # first, then move it to fullname once the download completes.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get("content-length")
        with open(tmp_fullname, "wb") as f:
            if total_size and show_progress:
                for chunk in tqdm.tqdm(
                    req.iter_content(chunk_size=1024),
                    total=(int(total_size) + 1023) // 1024,
                    unit="KB",
                ):
                    f.write(chunk)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)
        logging.debug("{} download completed.".format(fname))

    return fullname
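
# Hedged usage sketch for download(): the URL and directory are illustrative
# placeholders, not a real hosted artifact.
#
#     fullname = download(
#         "https://example.com/models/model.tgz",
#         "/tmp/models",
#         md5sum=None,          # pass a known digest to enable verification
#         show_progress=True,
#     )
#     print(fullname)           # -> /tmp/models/model.tgz
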
def decompress(fname):
    """
    Decompress a zip or tar file.
    """
    logging.info("Decompressing {}...".format(fname))

    # To guard against interrupted decompression, extract into the
    # fpath_tmp directory first; on success, move the extracted files
    # to fpath and delete fpath_tmp.
    fpath = osp.split(fname)[0]
    fpath_tmp = osp.join(fpath, "tmp")
    if osp.isdir(fpath_tmp):
        shutil.rmtree(fpath_tmp)
    os.makedirs(fpath_tmp)

    if fname.find(".tar") >= 0 or fname.find(".tgz") >= 0:
        with tarfile.open(fname) as tf:

            def is_within_directory(directory, target):
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)
                prefix = os.path.commonprefix([abs_directory, abs_target])
                return prefix == abs_directory

            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
                # Reject members that would escape the extraction directory
                # (tar path traversal, CVE-2007-4559).
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")
                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(tf, path=fpath_tmp)
    elif fname.find(".zip") >= 0:
        with zipfile.ZipFile(fname) as zf:
            zf.extractall(path=fpath_tmp)
    else:
        raise TypeError("Unsupported compress file type {}".format(fname))

    for f in os.listdir(fpath_tmp):
        src_dir = osp.join(fpath_tmp, f)
        dst_dir = osp.join(fpath, f)
        move_and_merge_tree(src_dir, dst_dir)
    shutil.rmtree(fpath_tmp)
    logging.debug("{} decompressed.".format(fname))
    return dst_dir
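
# Hedged sketch for decompress(): the archive path is hypothetical. Note that
# the return value is the last top-level entry moved out of the archive, so it
# is only a meaningful "model directory" for single-directory archives.
#
#     dst_dir = decompress("/tmp/models/model.tgz")
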
def url2dir(url, path, rename=None):
    """
    Download an archive from url into path and decompress it,
    returning the extracted directory.
    """
    full_name = download(url, path, rename, show_progress=True)
    print("File is downloaded, now extracting...")
    if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(".zip") > 0:
        return decompress(full_name)
def download_and_decompress(url, path=".", rename=None):
    """
    Download from url into path; if it is an archive, decompress it in place.
    """
    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    # if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
    #     fullname = osp.join(path, fname.split('.')[0])

    # nranks is hardcoded to 0 (single process), so only the first branch
    # runs; the else branch is a lock-file protocol kept for multi-process
    # settings, where rank 0 downloads while the other ranks wait for the
    # lock file to disappear.
    nranks = 0
    if nranks <= 1:
        dst_dir = url2dir(url, path, rename)
        if dst_dir is not None:
            fullname = dst_dir
    else:
        lock_path = fullname + ".lock"
        if not os.path.exists(fullname):
            with open(lock_path, "w"):
                os.utime(lock_path, None)
            if nranks == 0:
                dst_dir = url2dir(url, path, rename)
                if dst_dir is not None:
                    fullname = dst_dir
                os.remove(lock_path)
            else:
                while os.path.exists(lock_path):
                    time.sleep(1)
    return fullname
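
# Hedged sketch for download_and_decompress(): the URL is an illustrative
# placeholder. The call downloads the archive into path and unpacks it there.
#
#     extracted = download_and_decompress(
#         "https://example.com/models/model.tgz", path="/tmp/models"
#     )
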
def get_model_list(category: Optional[str] = None):
    """
    Get information on all pre-trained models supported by fd.download_model.

    Args:
        category(str): model category; if None, list models in all categories.

    Returns:
        results(dict): a dictionary mapping each category to a list of model
            information entries.
    """
    result = model_server.get_model_list()
    if result["status"] != 0:
        raise ValueError(
            "Failed to get pretrained models information from hub model server."
        )
    result = result["data"]
    if category is None:
        return result
    elif category in result:
        return {category: result[category]}
    else:
        raise ValueError(
            "No pretrained model in category {} can be downloaded now.".format(category)
        )
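
# Hedged sketch for get_model_list(): "detection" is an assumed category name;
# real category names come from the hub model server.
#
#     models = get_model_list(category="detection")
#     for info in models["detection"]:
#         print(info)
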
def download_model(
    name: str,
    path: Optional[str] = None,
    format: Optional[str] = None,
    version: Optional[str] = None,
):
    """
    Download a pre-trained model for the UltraInfer inference engine.

    Args:
        name(str): model name
        path(str): local path for saving the model. If not set, defaults to
            hubenv.MODEL_HOME
        format(str): UltraInfer model format
        version(str): UltraInfer model version
    """
    result = model_server.search_model(name, format, version)
    if path is None:
        path = hubenv.MODEL_HOME
    if result:
        url = result[0]["url"]
        format = result[0]["format"]
        version = result[0]["version"]
        fullpath = download(url, path, show_progress=True)
        model_server.stat_model(name, format, version)
        if format == "paddle":
            if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(".zip") > 0:
                archive_path = fullpath
                fullpath = decompress(fullpath)
                try:
                    # Rename the extracted directory to the model name and
                    # remove the downloaded archive.
                    os.rename(fullpath, os.path.join(os.path.dirname(fullpath), name))
                    fullpath = os.path.join(os.path.dirname(fullpath), name)
                    os.remove(archive_path)
                except FileExistsError:
                    pass
        print("Successfully downloaded model at path: {}".format(fullpath))
        return fullpath
    else:
        print("ERROR: Could not find a model named {}".format(name))