download.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import shutil
  16. import sys
  17. import tarfile
  18. import tempfile
  19. import time
  20. import zipfile
  21. import requests
  22. __all__ = ["download", "extract", "download_and_extract"]
  23. class _ProgressPrinter(object):
  24. """ProgressPrinter"""
  25. def __init__(self, flush_interval=0.1):
  26. super().__init__()
  27. self._last_time = 0
  28. self._flush_intvl = flush_interval
  29. def print(self, str_, end=False):
  30. """print"""
  31. if end:
  32. str_ += "\n"
  33. self._last_time = 0
  34. if time.time() - self._last_time >= self._flush_intvl:
  35. sys.stdout.write(f"\r{str_}")
  36. self._last_time = time.time()
  37. sys.stdout.flush()
  38. def _download(url, save_path, print_progress):
  39. if print_progress:
  40. print(f"Connecting to {url} ...")
  41. with requests.get(url, stream=True, timeout=15) as r:
  42. r.raise_for_status()
  43. total_length = r.headers.get("content-length")
  44. if total_length is None:
  45. with open(save_path, "wb") as f:
  46. shutil.copyfileobj(r.raw, f)
  47. else:
  48. with open(save_path, "wb") as f:
  49. dl = 0
  50. total_length = int(total_length)
  51. if print_progress:
  52. printer = _ProgressPrinter()
  53. print(f"Downloading {os.path.basename(save_path)} ...")
  54. for data in r.iter_content(chunk_size=4096):
  55. dl += len(data)
  56. f.write(data)
  57. if print_progress:
  58. done = int(50 * dl / total_length)
  59. printer.print(
  60. f"[{'=' * done:<50s}] {float(100 * dl) / total_length:.2f}%"
  61. )
  62. if print_progress:
  63. printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  64. def _extract_zip_file(file_path, extd_dir):
  65. """extract zip file"""
  66. with zipfile.ZipFile(file_path, "r") as f:
  67. file_list = f.namelist()
  68. total_num = len(file_list)
  69. for index, file in enumerate(file_list):
  70. f.extract(file, extd_dir)
  71. yield total_num, index
  72. def _extract_tar_file(file_path, extd_dir):
  73. """extract tar file"""
  74. try:
  75. with tarfile.open(file_path, "r:*") as f:
  76. file_list = f.getnames()
  77. total_num = len(file_list)
  78. for index, file in enumerate(file_list):
  79. try:
  80. f.extract(file, extd_dir)
  81. except KeyError:
  82. print(f"File {file} not found in the archive.")
  83. yield total_num, index
  84. except Exception as e:
  85. print(f"An error occurred: {e}")
  86. def _extract(file_path, extd_dir, print_progress):
  87. """extract"""
  88. if print_progress:
  89. printer = _ProgressPrinter()
  90. print(f"Extracting {os.path.basename(file_path)}")
  91. if zipfile.is_zipfile(file_path):
  92. handler = _extract_zip_file
  93. elif tarfile.is_tarfile(file_path):
  94. handler = _extract_tar_file
  95. else:
  96. raise RuntimeError("Unsupported file format.")
  97. for total_num, index in handler(file_path, extd_dir):
  98. if print_progress:
  99. done = int(50 * float(index) / total_num)
  100. printer.print(f"[{'=' * done:<50s}] {float(100 * index) / total_num:.2f}%")
  101. if print_progress:
  102. printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  103. def _remove_if_exists(path):
  104. """remove"""
  105. if os.path.exists(path):
  106. if os.path.isdir(path):
  107. shutil.rmtree(path)
  108. else:
  109. os.remove(path)
  110. def download(url, save_path, print_progress=True, overwrite=False):
  111. """download"""
  112. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  113. if overwrite:
  114. _remove_if_exists(save_path)
  115. if not os.path.exists(save_path):
  116. _download(url, save_path, print_progress=print_progress)
  117. def extract(file_path, extd_dir, print_progress=True):
  118. """extract"""
  119. return _extract(file_path, extd_dir, print_progress=print_progress)
  120. def download_and_extract(
  121. url, save_dir, dst_name, print_progress=True, overwrite=False, no_interm_dir=True
  122. ):
  123. """download and extract"""
  124. # NOTE: `url` MUST come from a trusted source, since we do not provide a solution
  125. # to secure against CVE-2007-4559.
  126. os.makedirs(save_dir, exist_ok=True)
  127. dst_path = os.path.join(save_dir, dst_name)
  128. if overwrite:
  129. _remove_if_exists(dst_path)
  130. if not os.path.exists(dst_path):
  131. with tempfile.TemporaryDirectory() as td:
  132. arc_file_path = os.path.join(td, url.split("/")[-1])
  133. extd_dir = os.path.splitext(arc_file_path)[0]
  134. _download(url, arc_file_path, print_progress=print_progress)
  135. tmp_extd_dir = os.path.join(td, "extract")
  136. _extract(arc_file_path, tmp_extd_dir, print_progress=print_progress)
  137. if no_interm_dir:
  138. file_names = os.listdir(tmp_extd_dir)
  139. if len(file_names) == 1:
  140. file_name = file_names[0]
  141. else:
  142. file_name = dst_name
  143. sp = os.path.join(tmp_extd_dir, file_name)
  144. if not os.path.exists(sp):
  145. raise FileNotFoundError
  146. dp = os.path.join(save_dir, file_name)
  147. if os.path.isdir(sp):
  148. shutil.copytree(sp, dp, symlinks=True)
  149. else:
  150. shutil.copyfile(sp, dp)
  151. extd_file = dp
  152. else:
  153. shutil.copytree(tmp_extd_dir, extd_dir)
  154. extd_file = extd_dir
  155. if not os.path.exists(dst_path) or not os.path.samefile(
  156. extd_file, dst_path
  157. ):
  158. shutil.move(extd_file, dst_path)