  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import shutil
  17. import requests
  18. import tqdm
  19. import time
  20. import hashlib
  21. import tarfile
  22. import zipfile
  23. import filelock
  24. import paddle
  25. from . import logging
  26. DOWNLOAD_RETRY_LIMIT = 3
  27. def md5check(fullname, md5sum=None):
  28. if md5sum is None:
  29. return True
  30. logging.info("File {} md5 checking...".format(fullname))
  31. md5 = hashlib.md5()
  32. with open(fullname, 'rb') as f:
  33. for chunk in iter(lambda: f.read(4096), b""):
  34. md5.update(chunk)
  35. calc_md5sum = md5.hexdigest()
  36. if calc_md5sum != md5sum:
  37. logging.info("File {} md5 check failed, {}(calc) != "
  38. "{}(base)".format(fullname, calc_md5sum, md5sum))
  39. return False
  40. return True
  41. def move_and_merge_tree(src, dst):
  42. """
  43. Move src directory to dst, if dst is already exists,
  44. merge src to dst
  45. """
  46. if not osp.exists(dst):
  47. shutil.move(src, dst)
  48. else:
  49. for fp in os.listdir(src):
  50. src_fp = osp.join(src, fp)
  51. dst_fp = osp.join(dst, fp)
  52. if osp.isdir(src_fp):
  53. if osp.isdir(dst_fp):
  54. move_and_merge_tree(src_fp, dst_fp)
  55. else:
  56. shutil.move(src_fp, dst_fp)
  57. elif osp.isfile(src_fp) and \
  58. not osp.isfile(dst_fp):
  59. shutil.move(src_fp, dst_fp)
def download(url, path, md5sum=None):
    """
    Download from url, save to path.

    url (str): download url
    path (str): directory to download into (created if missing)
    md5sum (str|None): optional expected MD5 of the file; when given,
        an existing or freshly downloaded file failing the check
        triggers a re-download (up to DOWNLOAD_RETRY_LIMIT attempts).

    Returns the full local path of the downloaded file.
    Raises RuntimeError on a non-200 response or when the retry
    limit is exhausted.
    """
    if not osp.exists(path):
        os.makedirs(path)
    # The local file name is the last component of the URL.
    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0
    # Loop until a file exists locally AND passes the md5 check.
    while not (osp.exists(fullname) and md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            logging.debug("{} download failed.".format(fname))
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))
        logging.info("Downloading {} from {}".format(fname, url))
        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))
        # For protecting download interupted, download to
        # tmp_fullname firstly, move tmp_fullname to fullname
        # after download finished
        tmp_fullname = fullname + "_tmp"
        # Content-Length may be absent (chunked transfer); then no
        # progress bar can be shown.
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                download_size = 0
                current_time = time.time()
                for chunk in tqdm.tqdm(
                        req.iter_content(chunk_size=1024),
                        total=(int(total_size) + 1023) // 1024,
                        unit='KB'):
                    f.write(chunk)
                    # NOTE(review): assumes every chunk is exactly
                    # 1024 bytes; the final chunk may be shorter, so
                    # the reported size/speed is approximate.
                    download_size += 1024
                    # Log progress every 512 KB (524288 bytes).
                    if download_size % 524288 == 0:
                        total_size_m = round(
                            int(total_size) / 1024.0 / 1024.0, 2)
                        download_size_m = round(download_size / 1024.0 /
                                                1024.0, 2)
                        # +0.01 guards against division by zero on
                        # very fast intervals.
                        speed = int(524288 /
                                    (time.time() - current_time + 0.01) /
                                    1024.0)
                        current_time = time.time()
                        logging.debug(
                            "Downloading: TotalSize={}M, DownloadSize={}M, Speed={}KB/s"
                            .format(total_size_m, download_size_m, speed))
            else:
                # Unknown total size: just stream chunks to disk.
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        # Atomically promote the finished temp file; a failed md5
        # check on it sends us back around the while loop.
        shutil.move(tmp_fullname, fullname)
        logging.debug("{} download completed.".format(fname))
    return fullname
  117. def decompress(fname):
  118. """
  119. Decompress for zip and tar file
  120. """
  121. logging.info("Decompressing {}...".format(fname))
  122. # For protecting decompressing interupted,
  123. # decompress to fpath_tmp directory firstly, if decompress
  124. # successed, move decompress files to fpath and delete
  125. # fpath_tmp and remove download compress file.
  126. fpath = osp.split(fname)[0]
  127. fpath_tmp = osp.join(fpath, 'tmp')
  128. if osp.isdir(fpath_tmp):
  129. shutil.rmtree(fpath_tmp)
  130. os.makedirs(fpath_tmp)
  131. if fname.find('tar') >= 0 or fname.find('tgz') >= 0:
  132. with tarfile.open(fname) as tf:
  133. def is_within_directory(directory, target):
  134. abs_directory = os.path.abspath(directory)
  135. abs_target = os.path.abspath(target)
  136. prefix = os.path.commonprefix([abs_directory, abs_target])
  137. return prefix == abs_directory
  138. def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
  139. for member in tar.getmembers():
  140. member_path = os.path.join(path, member.name)
  141. if not is_within_directory(path, member_path):
  142. raise Exception("Attempted Path Traversal in Tar File")
  143. tar.extractall(path, members, numeric_owner=numeric_owner)
  144. safe_extract(tf, path=fpath_tmp)
  145. elif fname.find('zip') >= 0:
  146. with zipfile.ZipFile(fname) as zf:
  147. zf.extractall(path=fpath_tmp)
  148. else:
  149. raise TypeError("Unsupport compress file type {}".format(fname))
  150. for f in os.listdir(fpath_tmp):
  151. src_dir = osp.join(fpath_tmp, f)
  152. dst_dir = osp.join(fpath, f)
  153. move_and_merge_tree(src_dir, dst_dir)
  154. shutil.rmtree(fpath_tmp)
  155. logging.debug("{} decompressed.".format(fname))
  156. return dst_dir
  157. def url2dir(url, path):
  158. download(url, path)
  159. if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
  160. fname = osp.split(url)[-1]
  161. savepath = osp.join(path, fname)
  162. return decompress(savepath)
def download_and_decompress(url, path='.'):
    """
    Download (and decompress, for archives) url into path, safe for
    multi-process distributed launches: only local rank 0 downloads
    while the other ranks wait on a lock file.

    Returns the local path of the result: the decompressed directory
    for archive URLs, otherwise the downloaded file path.
    """
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    # if url.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
    # fullname = osp.join(path, fname.split('.')[0])
    if nranks <= 1:
        # Single process: download/decompress directly.
        dst_dir = url2dir(url, path)
        if dst_dir is not None:
            fullname = dst_dir
    else:
        # Multi-process: coordinate through a lock file next to the
        # download target.
        lock_path = fullname + '.lock'
        if not os.path.exists(fullname):
            # Every rank (re)creates/touches the lock file first.
            with open(lock_path, 'w'):
                os.utime(lock_path, None)
            if local_rank == 0:
                dst_dir = url2dir(url, path)
                if dst_dir is not None:
                    fullname = dst_dir
                # Removing the lock releases the waiting ranks.
                os.remove(lock_path)
            else:
                # NOTE(review): non-zero ranks return the archive
                # path, not the decompressed dir, for archive URLs —
                # confirm callers handle both. Also the touch/remove
                # sequence looks race-prone if ranks start far apart.
                while os.path.exists(lock_path):
                    time.sleep(1)
    return fullname