download.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import time
  17. import shutil
  18. import tarfile
  19. import zipfile
  20. import tempfile
  21. import requests
  22. __all__ = ['download', 'extract', 'download_and_extract']
  23. class _ProgressPrinter(object):
  24. """ ProgressPrinter """
  25. def __init__(self, flush_interval=0.1):
  26. super().__init__()
  27. self._last_time = 0
  28. self._flush_intvl = flush_interval
  29. def print(self, str_, end=False):
  30. """ print """
  31. if end:
  32. str_ += '\n'
  33. self._last_time = 0
  34. if time.time() - self._last_time >= self._flush_intvl:
  35. sys.stdout.write(f"\r{str_}")
  36. self._last_time = time.time()
  37. sys.stdout.flush()
  38. def _download(url, save_path, print_progress):
  39. if print_progress:
  40. print(f"Connecting to {url} ...")
  41. with requests.get(url, stream=True, timeout=15) as r:
  42. r.raise_for_status()
  43. total_length = r.headers.get('content-length')
  44. if total_length is None:
  45. with open(save_path, 'wb') as f:
  46. shutil.copyfileobj(r.raw, f)
  47. else:
  48. with open(save_path, 'wb') as f:
  49. dl = 0
  50. total_length = int(total_length)
  51. if print_progress:
  52. printer = _ProgressPrinter()
  53. print(f"Downloading {os.path.basename(save_path)} ...")
  54. for data in r.iter_content(chunk_size=4096):
  55. dl += len(data)
  56. f.write(data)
  57. if print_progress:
  58. done = int(50 * dl / total_length)
  59. printer.print(
  60. f"[{'=' * done:<50s}] {float(100 * dl) / total_length:.2f}%"
  61. )
  62. if print_progress:
  63. printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  64. def _extract_zip_file(file_path, extd_dir):
  65. """ extract zip file """
  66. with zipfile.ZipFile(file_path, 'r') as f:
  67. file_list = f.namelist()
  68. total_num = len(file_list)
  69. for index, file in enumerate(file_list):
  70. f.extract(file, extd_dir)
  71. yield total_num, index
  72. def _extract_tar_file(file_path, extd_dir):
  73. """ extract tar file """
  74. with tarfile.open(file_path, 'r:*') as f:
  75. file_list = f.getnames()
  76. total_num = len(file_list)
  77. for index, file in enumerate(file_list):
  78. f.extract(file, extd_dir)
  79. yield total_num, index
  80. def _extract(file_path, extd_dir, print_progress):
  81. """ extract """
  82. if print_progress:
  83. printer = _ProgressPrinter()
  84. print(f"Extracting {os.path.basename(file_path)}")
  85. if zipfile.is_zipfile(file_path):
  86. handler = _extract_zip_file
  87. elif tarfile.is_tarfile(file_path):
  88. handler = _extract_tar_file
  89. else:
  90. raise RuntimeError("Unsupported file format.")
  91. for total_num, index in handler(file_path, extd_dir):
  92. if print_progress:
  93. done = int(50 * float(index) / total_num)
  94. printer.print(
  95. f"[{'=' * done:<50s}] {float(100 * index) / total_num:.2f}%")
  96. if print_progress:
  97. printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  98. def _remove_if_exists(path):
  99. """ remove """
  100. if os.path.exists(path):
  101. if os.path.isdir(path):
  102. shutil.rmtree(path)
  103. else:
  104. os.remove(path)
  105. def download(url, save_path, print_progress=True, overwrite=False):
  106. """ download """
  107. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  108. if overwrite:
  109. _remove_if_exists(save_path)
  110. if not os.path.exists(save_path):
  111. _download(url, save_path, print_progress=print_progress)
  112. def extract(file_path, extd_dir, print_progress=True):
  113. """ extract """
  114. return _extract(file_path, extd_dir, print_progress=print_progress)
  115. def download_and_extract(url,
  116. save_dir,
  117. dst_name,
  118. print_progress=True,
  119. overwrite=False,
  120. no_interm_dir=True):
  121. """ download and extract """
  122. # NOTE: `url` MUST come from a trusted source, since we do not provide a solution
  123. # to secure against CVE-2007-4559.
  124. os.makedirs(save_dir, exist_ok=True)
  125. dst_path = os.path.join(save_dir, dst_name)
  126. if overwrite:
  127. _remove_if_exists(dst_path)
  128. if not os.path.exists(dst_path):
  129. with tempfile.TemporaryDirectory() as td:
  130. arc_file_path = os.path.join(td, url.split('/')[-1])
  131. extd_dir = os.path.splitext(arc_file_path)[0]
  132. _download(url, arc_file_path, print_progress=print_progress)
  133. tmp_extd_dir = os.path.join(td, 'extract')
  134. _extract(arc_file_path, tmp_extd_dir, print_progress=print_progress)
  135. if no_interm_dir:
  136. file_names = os.listdir(tmp_extd_dir)
  137. if len(file_names) == 1:
  138. file_name = file_names[0]
  139. else:
  140. file_name = dst_name
  141. sp = os.path.join(tmp_extd_dir, file_name)
  142. if not os.path.exists(sp):
  143. raise FileNotFoundError
  144. dp = os.path.join(save_dir, file_name)
  145. if os.path.isdir(sp):
  146. shutil.copytree(sp, dp)
  147. else:
  148. shutil.copyfile(sp, dp)
  149. extd_file = dp
  150. else:
  151. shutil.copytree(tmp_extd_dir, extd_dir)
  152. extd_file = extd_dir
  153. if not os.path.exists(dst_path) or not os.path.samefile(extd_file,
  154. dst_path):
  155. shutil.move(extd_file, dst_path)