# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from collections import OrderedDict

from tqdm import tqdm

from paddlex.ppcls.utils import logger

__all__ = ['get_weights_path_from_url']

WEIGHTS_HOME = osp.expanduser("~/.paddleclas/weights")

DOWNLOAD_RETRY_LIMIT = 3


def is_url(path):
    """
    Whether the given path is a URL.

    Args:
        path (str): the path (or URL) to check.
    """
    return path.startswith('http://') or path.startswith('https://')
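
# Quick examples of is_url behaviour:
#   is_url('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams')  # True
#   is_url('~/.paddleclas/weights/resnet18.pdparams')                     # False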


def get_weights_path_from_url(url, md5sum=None):
    """Get weights path from WEIGHTS_HOME; if it does not exist,
    download it from url.

    Args:
        url (str): download url
        md5sum (str): md5 sum of download package

    Returns:
        str: a local path to save downloaded weights.

    Examples:
        .. code-block:: python

            from paddle.utils.download import get_weights_path_from_url

            resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
            local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)

    """
    path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
    return path


def _map_path(url, root_dir):
    # Parse the path the downloaded file will occupy under root_dir.
    fname = osp.split(url)[-1]
    fpath = fname
    return osp.join(root_dir, fpath)


def _get_unique_endpoints(trainer_endpoints):
    # Sorting avoids each card seeing a different ordering from its
    # environment variables.
    trainer_endpoints.sort()
    ips = set()
    unique_endpoints = set()
    for endpoint in trainer_endpoints:
        ip = endpoint.split(":")[0]
        if ip in ips:
            continue
        ips.add(ip)
        unique_endpoints.add(endpoint)
    logger.info("unique_endpoints {}".format(unique_endpoints))
    return unique_endpoints
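
# For example, with trainer_endpoints ['10.1.1.2:6170', '10.1.1.1:6170',
# '10.1.1.1:6171'], sorting yields ['10.1.1.1:6170', '10.1.1.1:6171',
# '10.1.1.2:6170'] and the result is {'10.1.1.1:6170', '10.1.1.2:6170'}:
# one endpoint per IP, so exactly one process per machine downloads.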


def get_path_from_url(url,
                      root_dir,
                      md5sum=None,
                      check_exist=True,
                      decompress=True):
    """ Download from the given url to root_dir.
    If the file or directory specified by url already exists under
    root_dir, return the path directly; otherwise download it from
    url, decompress it, and return the path.

    Args:
        url (str): download url
        root_dir (str): root dir for downloading, it should be
                        WEIGHTS_HOME or DATASET_HOME
        md5sum (str): md5 sum of download package

    Returns:
        str: a local path to save downloaded models & weights & datasets.
    """
    from paddle.distributed import ParallelEnv

    assert is_url(url), "downloading from {}: not a valid URL".format(url)
    # Parse the path the download will be saved (and decompressed) to
    # under root_dir.
    fullpath = _map_path(url, root_dir)
    # Mainly used to avoid redundant downloads in the multi-machine
    # case: each distinct IP downloads the data once, and every other
    # process on the same machine waits for that copy to appear.
    unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                             .trainer_endpoints[:])
    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
        logger.info("Found {}".format(fullpath))
    else:
        if ParallelEnv().current_endpoint in unique_endpoints:
            fullpath = _download(url, root_dir, md5sum)
        else:
            while not os.path.exists(fullpath):
                time.sleep(1)

    if ParallelEnv().current_endpoint in unique_endpoints:
        if decompress and (tarfile.is_tarfile(fullpath) or
                           zipfile.is_zipfile(fullpath)):
            fullpath = _decompress(fullpath)

    return fullpath
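
# A minimal usage sketch (illustrative): fetch the resnet18 weights from
# the docstring example above into WEIGHTS_HOME, reusing the cached copy
# on subsequent calls.
#
#   url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
#   local_path = get_path_from_url(url, WEIGHTS_HOME)
#   # the first call downloads (with retries and optional md5 check);
#   # later calls find the file and return the same local path.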


def _download(url, path, md5sum=None):
    """
    Download from url, save to path.

    Args:
        url (str): download url
        path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {} from {}".format(fname, url))

        try:
            req = requests.get(url, stream=True)
        except Exception as e:  # requests.exceptions.ConnectionError
            logger.info(
                "Downloading {} from {} failed {} times with exception {}".
                format(fname, url, retry_cnt + 1, str(e)))
            time.sleep(1)
            continue

        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against an interrupted download, write to
        # tmp_fullname first and move tmp_fullname to fullname only
        # after the download has finished.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                # The progress bar counts 1 KiB chunks; round the total up.
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname


def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    logger.info("File {} md5 checking...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logger.info("File {} md5 check failed, {}(calc) != "
                    "{}(base)".format(fullname, calc_md5sum, md5sum))
        return False
    return True
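
# The expected md5sum can be produced the same way the check computes it,
# e.g. (illustrative, for a hypothetical local file):
#
#   import hashlib
#   md5 = hashlib.md5()
#   with open('resnet18.pdparams', 'rb') as f:
#       for chunk in iter(lambda: f.read(4096), b""):
#           md5.update(chunk)
#   print(md5.hexdigest())  # pass this string as md5sum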


def _decompress(fname):
    """
    Decompress zip and tar files.
    """
    logger.info("Decompressing {}...".format(fname))

    # To protect against an interrupted decompression, decompress to
    # an fpath_tmp directory first; if decompression succeeds, move
    # the decompressed files to fpath, then delete fpath_tmp and
    # remove the downloaded compressed file.
    if tarfile.is_tarfile(fname):
        uncompressed_path = _uncompress_file_tar(fname)
    elif zipfile.is_zipfile(fname):
        uncompressed_path = _uncompress_file_zip(fname)
    else:
        raise TypeError("Unsupported compressed file type {}".format(fname))

    return uncompressed_path


def _uncompress_file_zip(filepath):
    files = zipfile.ZipFile(filepath, 'r')
    file_list = files.namelist()

    file_dir = os.path.dirname(filepath)

    if _is_a_single_file(file_list):
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    elif _is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)
        if not os.path.exists(uncompressed_path):
            os.makedirs(uncompressed_path)
        for item in file_list:
            files.extract(item, os.path.join(file_dir, rootpath))

    files.close()

    return uncompressed_path


def _uncompress_file_tar(filepath, mode="r:*"):
    files = tarfile.open(filepath, mode)
    file_list = files.getnames()

    file_dir = os.path.dirname(filepath)

    if _is_a_single_file(file_list):
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    elif _is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)
        if not os.path.exists(uncompressed_path):
            os.makedirs(uncompressed_path)
        for item in file_list:
            files.extract(item, os.path.join(file_dir, rootpath))

    files.close()

    return uncompressed_path


def _is_a_single_file(file_list):
    # A single top-level file: exactly one entry whose name contains no
    # path separator. str.find returns -1 when os.sep is absent, so a
    # negative index means the entry is a bare file name.
    if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
        return True
    return False
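
# Illustrative archive layouts:
#   ['model.pdparams']                        -> a single file
#   ['resnet18/', 'resnet18/model.pdparams']  -> a single directory
#   ['a/x.txt', 'b/y.txt']                    -> neither (extracted into a
#                                                directory named after the
#                                                archive, see above)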


def _is_a_single_dir(file_list):
    new_file_list = []
    for file_path in file_list:
        # Normalize '/' and '\\' separators to os.sep before comparing.
        if '/' in file_path:
            file_path = file_path.replace('/', os.sep)
        elif '\\' in file_path:
            file_path = file_path.replace('\\', os.sep)
        new_file_list.append(file_path)

    # All entries must share the same top-level directory component.
    file_name = new_file_list[0].split(os.sep)[0]
    for i in range(1, len(new_file_list)):
        if file_name != new_file_list[i].split(os.sep)[0]:
            return False
    return True
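

if __name__ == '__main__':
    # Minimal smoke test (illustrative only): resolves, downloads if
    # necessary, and prints the cached weights path. Requires network
    # access and the paddle runtime (for ParallelEnv).
    url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
    print(get_weights_path_from_url(url))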