# download.py
  1. import os
  2. import os.path as osp
  3. import shutil
  4. import requests
  5. import tqdm
  6. import time
  7. import hashlib
  8. import tarfile
  9. import zipfile
  10. from .utils import logging
  11. DOWNLOAD_RETRY_LIMIT = 3
  12. def md5check(fullname, md5sum=None):
  13. if md5sum is None:
  14. return True
  15. logging.info("File {} md5 checking...".format(fullname))
  16. md5 = hashlib.md5()
  17. with open(fullname, 'rb') as f:
  18. for chunk in iter(lambda: f.read(4096), b""):
  19. md5.update(chunk)
  20. calc_md5sum = md5.hexdigest()
  21. if calc_md5sum != md5sum:
  22. logging.info("File {} md5 check failed, {}(calc) != "
  23. "{}(base)".format(fullname, calc_md5sum, md5sum))
  24. return False
  25. return True
  26. def move_and_merge_tree(src, dst):
  27. """
  28. Move src directory to dst, if dst is already exists,
  29. merge src to dst
  30. """
  31. if not osp.exists(dst):
  32. shutil.move(src, dst)
  33. else:
  34. for fp in os.listdir(src):
  35. src_fp = osp.join(src, fp)
  36. dst_fp = osp.join(dst, fp)
  37. if osp.isdir(src_fp):
  38. if osp.isdir(dst_fp):
  39. move_and_merge_tree(src_fp, dst_fp)
  40. else:
  41. shutil.move(src_fp, dst_fp)
  42. elif osp.isfile(src_fp) and \
  43. not osp.isfile(dst_fp):
  44. shutil.move(src_fp, dst_fp)
  45. def download(url, path, md5sum=None):
  46. """
  47. Download from url, save to path.
  48. url (str): download url
  49. path (str): download to given path
  50. """
  51. if not osp.exists(path):
  52. os.makedirs(path)
  53. fname = osp.split(url)[-1]
  54. fullname = osp.join(path, fname)
  55. retry_cnt = 0
  56. while not (osp.exists(fullname) and md5check(fullname, md5sum)):
  57. if retry_cnt < DOWNLOAD_RETRY_LIMIT:
  58. retry_cnt += 1
  59. else:
  60. logging.debug("{} download failed.".format(fname))
  61. raise RuntimeError("Download from {} failed. "
  62. "Retry limit reached".format(url))
  63. logging.info("Downloading {} from {}".format(fname, url))
  64. req = requests.get(url, stream=True)
  65. if req.status_code != 200:
  66. raise RuntimeError("Downloading from {} failed with code "
  67. "{}!".format(url, req.status_code))
  68. # For protecting download interupted, download to
  69. # tmp_fullname firstly, move tmp_fullname to fullname
  70. # after download finished
  71. tmp_fullname = fullname + "_tmp"
  72. total_size = req.headers.get('content-length')
  73. with open(tmp_fullname, 'wb') as f:
  74. if total_size:
  75. download_size = 0
  76. current_time = time.time()
  77. for chunk in tqdm.tqdm(
  78. req.iter_content(chunk_size=1024),
  79. total=(int(total_size) + 1023) // 1024,
  80. unit='KB'):
  81. f.write(chunk)
  82. download_size += 1024
  83. if download_size % 524288 == 0:
  84. total_size_m = round(
  85. int(total_size) / 1024.0 / 1024.0, 2)
  86. download_size_m = round(download_size / 1024.0 /
  87. 1024.0, 2)
  88. speed = int(524288 /
  89. (time.time() - current_time + 0.01) /
  90. 1024.0)
  91. current_time = time.time()
  92. logging.debug(
  93. "Downloading: TotalSize={}M, DownloadSize={}M, Speed={}KB/s"
  94. .format(total_size_m, download_size_m, speed))
  95. else:
  96. for chunk in req.iter_content(chunk_size=1024):
  97. if chunk:
  98. f.write(chunk)
  99. shutil.move(tmp_fullname, fullname)
  100. logging.debug("{} download completed.".format(fname))
  101. return fullname
  102. def decompress(fname):
  103. """
  104. Decompress for zip and tar file
  105. """
  106. logging.info("Decompressing {}...".format(fname))
  107. # For protecting decompressing interupted,
  108. # decompress to fpath_tmp directory firstly, if decompress
  109. # successed, move decompress files to fpath and delete
  110. # fpath_tmp and remove download compress file.
  111. fpath = osp.split(fname)[0]
  112. fpath_tmp = osp.join(fpath, 'tmp')
  113. if osp.isdir(fpath_tmp):
  114. shutil.rmtree(fpath_tmp)
  115. os.makedirs(fpath_tmp)
  116. if fname.find('tar') >= 0 or fname.find('tgz') >= 0:
  117. with tarfile.open(fname) as tf:
  118. def is_within_directory(directory, target):
  119. abs_directory = os.path.abspath(directory)
  120. abs_target = os.path.abspath(target)
  121. prefix = os.path.commonprefix([abs_directory, abs_target])
  122. return prefix == abs_directory
  123. def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
  124. for member in tar.getmembers():
  125. member_path = os.path.join(path, member.name)
  126. if not is_within_directory(path, member_path):
  127. raise Exception("Attempted Path Traversal in Tar File")
  128. tar.extractall(path, members, numeric_owner=numeric_owner)
  129. safe_extract(tf, path=fpath_tmp)
  130. elif fname.find('zip') >= 0:
  131. with zipfile.ZipFile(fname) as zf:
  132. zf.extractall(path=fpath_tmp)
  133. else:
  134. raise TypeError("Unsupport compress file type {}".format(fname))
  135. for f in os.listdir(fpath_tmp):
  136. src_dir = osp.join(fpath_tmp, f)
  137. dst_dir = osp.join(fpath, f)
  138. move_and_merge_tree(src_dir, dst_dir)
  139. shutil.rmtree(fpath_tmp)
  140. logging.debug("{} decompressed.".format(fname))
  141. def download_and_decompress(url, path='.'):
  142. download(url, path)
  143. fname = osp.split(url)[-1]
  144. decompress(osp.join(path, fname))