download.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import time
  17. import shutil
  18. import tarfile
  19. import zipfile
  20. import tempfile
  21. import requests
  22. __all__ = ['download', 'extract', 'download_and_extract']
  23. class _ProgressPrinter(object):
  24. """ ProgressPrinter """
  25. def __init__(self, flush_interval=0.1):
  26. super().__init__()
  27. self._last_time = 0
  28. self._flush_intvl = flush_interval
  29. def print(self, str_, end=False):
  30. """ print """
  31. if end:
  32. str_ += '\n'
  33. self._last_time = 0
  34. if time.time() - self._last_time >= self._flush_intvl:
  35. sys.stdout.write(f"\r{str_}")
  36. self._last_time = time.time()
  37. sys.stdout.flush()
  38. def _download(url, save_path, print_progress):
  39. if print_progress:
  40. print(f"Connecting to {url} ...")
  41. with requests.get(url, stream=True, timeout=15) as r:
  42. r.raise_for_status()
  43. total_length = r.headers.get('content-length')
  44. if total_length is None:
  45. with open(save_path, 'wb') as f:
  46. shutil.copyfileobj(r.raw, f)
  47. else:
  48. with open(save_path, 'wb') as f:
  49. dl = 0
  50. total_length = int(total_length)
  51. if print_progress:
  52. printer = _ProgressPrinter()
  53. print(f"Downloading {os.path.basename(save_path)} ...")
  54. for data in r.iter_content(chunk_size=4096):
  55. dl += len(data)
  56. f.write(data)
  57. if print_progress:
  58. done = int(50 * dl / total_length)
  59. printer.print(
  60. f"[{'=' * done:<50s}] {float(100 * dl) / total_length:.2f}%"
  61. )
  62. if print_progress:
  63. printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  64. def _extract_zip_file(file_path, extd_dir):
  65. """ extract zip file """
  66. with zipfile.ZipFile(file_path, 'r') as f:
  67. file_list = f.namelist()
  68. total_num = len(file_list)
  69. for index, file in enumerate(file_list):
  70. f.extract(file, extd_dir)
  71. yield total_num, index
  72. def _extract_tar_file(file_path, extd_dir):
  73. """ extract tar file """
  74. try:
  75. with tarfile.open(file_path, 'r:*') as f:
  76. file_list = f.getnames()
  77. total_num = len(file_list)
  78. for index, file in enumerate(file_list):
  79. try:
  80. f.extract(file, extd_dir)
  81. except KeyError:
  82. print(f"File {file} not found in the archive.")
  83. yield total_num, index
  84. except Exception as e:
  85. print(f"An error occurred: {e}")
  86. def _extract(file_path, extd_dir, print_progress):
  87. """ extract """
  88. if print_progress:
  89. printer = _ProgressPrinter()
  90. print(f"Extracting {os.path.basename(file_path)}")
  91. if zipfile.is_zipfile(file_path):
  92. handler = _extract_zip_file
  93. elif tarfile.is_tarfile(file_path):
  94. handler = _extract_tar_file
  95. else:
  96. raise RuntimeError("Unsupported file format.")
  97. for total_num, index in handler(file_path, extd_dir):
  98. if print_progress:
  99. done = int(50 * float(index) / total_num)
  100. printer.print(
  101. f"[{'=' * done:<50s}] {float(100 * index) / total_num:.2f}%")
  102. if print_progress:
  103. printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  104. def _remove_if_exists(path):
  105. """ remove """
  106. if os.path.exists(path):
  107. if os.path.isdir(path):
  108. shutil.rmtree(path)
  109. else:
  110. os.remove(path)
  111. def download(url, save_path, print_progress=True, overwrite=False):
  112. """ download """
  113. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  114. if overwrite:
  115. _remove_if_exists(save_path)
  116. if not os.path.exists(save_path):
  117. _download(url, save_path, print_progress=print_progress)
  118. def extract(file_path, extd_dir, print_progress=True):
  119. """ extract """
  120. return _extract(file_path, extd_dir, print_progress=print_progress)
  121. def download_and_extract(url,
  122. save_dir,
  123. dst_name,
  124. print_progress=True,
  125. overwrite=False,
  126. no_interm_dir=True):
  127. """ download and extract """
  128. # NOTE: `url` MUST come from a trusted source, since we do not provide a solution
  129. # to secure against CVE-2007-4559.
  130. os.makedirs(save_dir, exist_ok=True)
  131. dst_path = os.path.join(save_dir, dst_name)
  132. if overwrite:
  133. _remove_if_exists(dst_path)
  134. if not os.path.exists(dst_path):
  135. with tempfile.TemporaryDirectory() as td:
  136. arc_file_path = os.path.join(td, url.split('/')[-1])
  137. extd_dir = os.path.splitext(arc_file_path)[0]
  138. _download(url, arc_file_path, print_progress=print_progress)
  139. tmp_extd_dir = os.path.join(td, 'extract')
  140. _extract(arc_file_path, tmp_extd_dir, print_progress=print_progress)
  141. if no_interm_dir:
  142. file_names = os.listdir(tmp_extd_dir)
  143. if len(file_names) == 1:
  144. file_name = file_names[0]
  145. else:
  146. file_name = dst_name
  147. sp = os.path.join(tmp_extd_dir, file_name)
  148. if not os.path.exists(sp):
  149. raise FileNotFoundError
  150. dp = os.path.join(save_dir, file_name)
  151. if os.path.isdir(sp):
  152. shutil.copytree(sp, dp, symlinks=True)
  153. else:
  154. shutil.copyfile(sp, dp)
  155. extd_file = dp
  156. else:
  157. shutil.copytree(tmp_extd_dir, extd_dir)
  158. extd_file = extd_dir
  159. if not os.path.exists(dst_path) or not os.path.samefile(extd_file,
  160. dst_path):
  161. shutil.move(extd_file, dst_path)