download.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import os
  16. import shutil
  17. import sys
  18. import tarfile
  19. import time
  20. import zipfile
  21. import requests
  22. lasttime = time.time()
  23. FLUSH_INTERVAL = 0.1
  24. def progress(str, end=False):
  25. global lasttime
  26. if end:
  27. str += "\n"
  28. lasttime = 0
  29. if time.time() - lasttime >= FLUSH_INTERVAL:
  30. sys.stdout.write("\r%s" % str)
  31. lasttime = time.time()
  32. sys.stdout.flush()
  33. def _download_file(url, savepath, print_progress):
  34. if print_progress:
  35. print("Connecting to {}".format(url))
  36. r = requests.get(url, stream=True, timeout=15)
  37. total_length = r.headers.get('content-length')
  38. if total_length is None:
  39. with open(savepath, 'wb') as f:
  40. shutil.copyfileobj(r.raw, f)
  41. else:
  42. with open(savepath, 'wb') as f:
  43. dl = 0
  44. total_length = int(total_length)
  45. starttime = time.time()
  46. if print_progress:
  47. print("Downloading %s" % os.path.basename(savepath))
  48. for data in r.iter_content(chunk_size=4096):
  49. dl += len(data)
  50. f.write(data)
  51. if print_progress:
  52. done = int(50 * dl / total_length)
  53. progress("[%-50s] %.2f%%" %
  54. ('=' * done, float(100 * dl) / total_length))
  55. if print_progress:
  56. progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
  57. def _uncompress_file_zip(filepath, extrapath):
  58. files = zipfile.ZipFile(filepath, 'r')
  59. filelist = files.namelist()
  60. rootpath = filelist[0]
  61. total_num = len(filelist)
  62. for index, file in enumerate(filelist):
  63. files.extract(file, extrapath)
  64. yield total_num, index, rootpath
  65. files.close()
  66. yield total_num, index, rootpath
  67. def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
  68. files = tarfile.open(filepath, mode)
  69. filelist = files.getnames()
  70. total_num = len(filelist)
  71. rootpath = filelist[0]
  72. for index, file in enumerate(filelist):
  73. files.extract(file, extrapath)
  74. yield total_num, index, rootpath
  75. files.close()
  76. yield total_num, index, rootpath
  77. def _uncompress_file(filepath, extrapath, delete_file, print_progress):
  78. if print_progress:
  79. print("Uncompress %s" % os.path.basename(filepath))
  80. if filepath.endswith("zip"):
  81. handler = _uncompress_file_zip
  82. elif filepath.endswith("tgz"):
  83. handler = functools.partial(_uncompress_file_tar, mode="r:*")
  84. else:
  85. handler = functools.partial(_uncompress_file_tar, mode="r")
  86. for total_num, index, rootpath in handler(filepath, extrapath):
  87. if print_progress:
  88. done = int(50 * float(index) / total_num)
  89. progress("[%-50s] %.2f%%" %
  90. ('=' * done, float(100 * index) / total_num))
  91. if print_progress:
  92. progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
  93. if delete_file:
  94. os.remove(filepath)
  95. return rootpath
  96. def download_file_and_uncompress(url,
  97. savepath=None,
  98. extrapath=None,
  99. extraname=None,
  100. print_progress=True,
  101. cover=False,
  102. delete_file=True):
  103. if savepath is None:
  104. savepath = "."
  105. if extrapath is None:
  106. extrapath = "."
  107. savename = url.split("/")[-1]
  108. if not os.path.exists(savepath):
  109. os.makedirs(savepath)
  110. savepath = os.path.join(savepath, savename)
  111. savename = ".".join(savename.split(".")[:-1])
  112. savename = os.path.join(extrapath, savename)
  113. extraname = savename if extraname is None else os.path.join(extrapath,
  114. extraname)
  115. if cover:
  116. if os.path.exists(savepath):
  117. shutil.rmtree(savepath)
  118. if os.path.exists(savename):
  119. shutil.rmtree(savename)
  120. if os.path.exists(extraname):
  121. shutil.rmtree(extraname)
  122. if not os.path.exists(extraname):
  123. if not os.path.exists(savename):
  124. if not os.path.exists(savepath):
  125. _download_file(url, savepath, print_progress)
  126. if (not tarfile.is_tarfile(savepath)) and (
  127. not zipfile.is_zipfile(savepath)):
  128. if not os.path.exists(extraname):
  129. os.makedirs(extraname)
  130. shutil.move(savepath, extraname)
  131. return extraname
  132. savename = _uncompress_file(savepath, extrapath, delete_file,
  133. print_progress)
  134. savename = os.path.join(extrapath, savename)
  135. shutil.move(savename, extraname)
  136. return extraname