download.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import sys
  13. import time
  14. import shutil
  15. import tarfile
  16. import zipfile
  17. import tempfile
  18. import requests
  19. __all__ = ['download', 'extract', 'download_and_extract']
  class _ProgressPrinter(object):
      """Rate-limited writer for a single-line progress display on stdout.

      Each accepted message is written after a carriage return so that it
      overwrites the previous one. Messages arriving sooner than
      `flush_interval` seconds after the last accepted one are dropped, which
      keeps tight download/extract loops from flooding the terminal.
      """

      def __init__(self, flush_interval=0.1):
          """
          Args:
              flush_interval: Minimum number of seconds between two writes.
          """
          super().__init__()
          # None means "nothing written yet", so the first message always shows.
          self._last_time = None
          self._flush_intvl = flush_interval

      def print(self, str_, end=False):
          """Write `str_` over the current line unless it arrives too early.

          Args:
              str_: Text to display.
              end: When true, a newline is appended and the message is written
                  unconditionally, terminating the progress line.
          """
          if end:
              str_ += '\n'
              # Reset the throttle so the closing message is never dropped.
              self._last_time = None
          # A monotonic clock keeps the throttle correct even if the wall clock
          # is adjusted while a long transfer is running.
          now = time.monotonic()
          if (self._last_time is None or
                  now - self._last_time >= self._flush_intvl):
              sys.stdout.write(f"\r{str_}")
              sys.stdout.flush()
              self._last_time = now
  def _download(url, save_path, print_progress):
      """Stream the resource at `url` into the file `save_path`.

      The body is first written to `<save_path>.part` and only renamed to
      `save_path` once the transfer has finished, so an interrupted download
      never leaves a truncated file that a later call could mistake for a
      complete one.

      Args:
          url: Address of the resource to fetch.
          save_path: Destination file path; its directory must already exist.
          print_progress: When true, print status messages and, if the server
              reports a content length, a progress bar.

      Raises:
          requests.HTTPError: If the server answers with an error status.
          OSError: If the destination cannot be written.
      """
      if print_progress:
          print(f"Connecting to {url} ...")

      part_path = f"{save_path}.part"
      try:
          with requests.get(url, stream=True, timeout=15) as r:
              r.raise_for_status()

              total_length = r.headers.get('content-length')
              if total_length is not None:
                  try:
                      total_length = int(total_length)
                  except ValueError:
                      # Malformed header: fall back to the unknown-size path.
                      total_length = None

              printer = None
              if print_progress and total_length is not None:
                  printer = _ProgressPrinter()
                  print(f"Downloading {os.path.basename(save_path)} ...")

              with open(part_path, 'wb') as f:
                  dl = 0
                  # iter_content is used whether or not the size is known so
                  # that both paths apply the same content decoding.
                  for data in r.iter_content(chunk_size=4096):
                      if not data:
                          continue
                      f.write(data)
                      dl += len(data)
                      if printer is not None and total_length > 0:
                          # The decoded size may exceed the advertised length
                          # (compressed transfer), so cap the bar at 100%.
                          ratio = min(dl / total_length, 1.0)
                          done = int(50 * ratio)
                          printer.print(
                              f"[{'=' * done:<50s}] {100 * ratio:.2f}%")

              if printer is not None:
                  printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)

          # The file is closed at this point, so the rename is safe everywhere.
          os.replace(part_path, save_path)
      finally:
          # Only present if the transfer or the rename failed.
          if os.path.exists(part_path):
              os.remove(part_path)
  def _extract_zip_file(file_path, extd_dir):
      """Extract a zip archive member by member, reporting progress.

      Args:
          file_path: Path of the zip archive.
          extd_dir: Directory the members are extracted into.

      Yields:
          A `(total_num, index)` tuple after each member has been extracted,
          where `total_num` is the number of members in the archive and
          `index` is the zero-based position of the member just written.
      """
      with zipfile.ZipFile(file_path, 'r') as f:
          file_list = f.namelist()
          total_num = len(file_list)
          for index, file in enumerate(file_list):
              f.extract(file, extd_dir)
              yield total_num, index
  def _extract_tar_file(file_path, extd_dir):
      """Extract a (possibly compressed) tar archive member by member.

      Args:
          file_path: Path of the tar archive; the compression is auto-detected.
          extd_dir: Directory the members are extracted into.

      Yields:
          A `(total_num, index)` tuple after each member has been extracted,
          where `total_num` is the number of members in the archive and
          `index` is the zero-based position of the member just written.
      """
      # Interpreters that ship extraction filters get the 'data' filter, which
      # rejects members that would be written outside `extd_dir`. Older
      # interpreters do not accept the argument, so it is passed conditionally.
      extract_kwargs = {}
      if hasattr(tarfile, 'data_filter'):
          extract_kwargs['filter'] = 'data'

      with tarfile.open(file_path, 'r:*') as f:
          # Passing TarInfo objects avoids a name lookup over the whole member
          # list for every single entry.
          members = f.getmembers()
          total_num = len(members)
          for index, member in enumerate(members):
              f.extract(member, extd_dir, **extract_kwargs)
              yield total_num, index
  def _extract(file_path, extd_dir, print_progress):
      """Extract a zip or tar archive into `extd_dir`.

      The archive type is detected from the file content, not from its name.

      Args:
          file_path: Path of the archive.
          extd_dir: Directory the archive is extracted into; created if needed.
          print_progress: When true, print a status line and a progress bar.

      Raises:
          RuntimeError: If the file is neither a zip nor a tar archive.
      """
      if print_progress:
          printer = _ProgressPrinter()
          print(f"Extracting {os.path.basename(file_path)}")

      if zipfile.is_zipfile(file_path):
          handler = _extract_zip_file
      elif tarfile.is_tarfile(file_path):
          handler = _extract_tar_file
      else:
          raise RuntimeError("Unsupported file format.")

      # Guarantee the destination exists even when the archive has no members,
      # so callers can always list it afterwards.
      os.makedirs(extd_dir, exist_ok=True)

      for total_num, index in handler(file_path, extd_dir):
          if print_progress:
              # `index` is zero-based, so the bar stays below 100% until the
              # closing line printed after the loop.
              done = int(50 * float(index) / total_num)
              printer.print(
                  f"[{'=' * done:<50s}] {float(100 * index) / total_num:.2f}%")
      if print_progress:
          printer.print(f"[{'=' * 50:<50s}] {100:.2f}%", end=True)
  def _remove_if_exists(path):
      """Delete the file, directory tree, or symlink at `path`, if there is one.

      Symbolic links are unlinked, never followed: a link to a directory is
      removed without touching the directory it points to, and a dangling link
      is removed as well. A missing path is silently ignored.

      Args:
          path: Filesystem path to remove.
      """
      # lexists also reports dangling symlinks, which exists() would miss.
      if not os.path.lexists(path):
          return
      # rmtree refuses to operate on a symlink, so links go through os.remove.
      if os.path.isdir(path) and not os.path.islink(path):
          shutil.rmtree(path)
      else:
          os.remove(path)
  def download(url, save_path, print_progress=True, overwrite=False):
      """Download `url` to `save_path` unless the file is already present.

      Args:
          url: Address of the resource to fetch.
          save_path: Destination file path. Missing parent directories are
              created.
          print_progress: When true, print status messages and a progress bar.
          overwrite: When true, remove whatever exists at `save_path` first so
              that the resource is always fetched again.
      """
      # A bare file name has an empty dirname, which os.makedirs rejects.
      parent_dir = os.path.dirname(save_path)
      if parent_dir:
          os.makedirs(parent_dir, exist_ok=True)
      if overwrite:
          _remove_if_exists(save_path)
      if not os.path.exists(save_path):
          _download(url, save_path, print_progress=print_progress)
  def extract(file_path, extd_dir, print_progress=True):
      """Extract the zip or tar archive `file_path` into `extd_dir`.

      Args:
          file_path: Path of the archive.
          extd_dir: Directory the archive is extracted into.
          print_progress: When true, print a status line and a progress bar.

      Returns:
          Whatever the internal `_extract` helper returns.
      """
      return _extract(file_path, extd_dir, print_progress=print_progress)
  def download_and_extract(url,
                           save_dir,
                           dst_name,
                           print_progress=True,
                           overwrite=False,
                           no_interm_dir=True):
      """Download an archive and place its content at `save_dir/dst_name`.

      Nothing is done when `save_dir/dst_name` already exists and `overwrite`
      is false.

      Args:
          url: Address of the archive.
          save_dir: Directory that will hold the result; created if missing.
          dst_name: Name of the resulting file or directory inside `save_dir`.
          print_progress: When true, print status messages and progress bars.
          overwrite: When true, remove an existing `save_dir/dst_name` first.
          no_interm_dir: When true, the result is the archive's single
              top-level entry (or, if the archive has several, the entry named
              `dst_name`). When false, the result is a directory holding
              everything the archive contained.

      Raises:
          FileNotFoundError: If `no_interm_dir` is true and the expected entry
              is not present in the extracted archive.
      """
      # NOTE: `url` MUST come from a trusted source; this function itself does
      # nothing to guard against malicious archives (see CVE-2007-4559).
      os.makedirs(save_dir, exist_ok=True)
      dst_path = os.path.join(save_dir, dst_name)
      if overwrite:
          _remove_if_exists(dst_path)
      if os.path.exists(dst_path):
          return

      with tempfile.TemporaryDirectory() as td:
          # The archive and the extracted tree live in separate subdirectories
          # so that an archive that happens to be called "extract" cannot
          # collide with the extraction directory.
          arc_dir = os.path.join(td, 'archive')
          tmp_extd_dir = os.path.join(td, 'extract')
          os.makedirs(arc_dir)
          os.makedirs(tmp_extd_dir)

          # Drop any query string or fragment, and fall back to a fixed name
          # when the URL ends with a slash. The name is only used for status
          # messages; the archive type is detected from the file content.
          url_path = url.split('?', 1)[0].split('#', 1)[0]
          arc_name = url_path.rstrip('/').split('/')[-1] or 'archive'
          arc_file_path = os.path.join(arc_dir, arc_name)

          _download(url, arc_file_path, print_progress=print_progress)
          _extract(arc_file_path, tmp_extd_dir, print_progress=print_progress)

          if no_interm_dir:
              file_names = os.listdir(tmp_extd_dir)
              if len(file_names) == 1:
                  file_name = file_names[0]
              else:
                  file_name = dst_name
              extd_file = os.path.join(tmp_extd_dir, file_name)
              if not os.path.exists(extd_file):
                  raise FileNotFoundError(
                      f"{file_name!r} was not found in the archive "
                      f"downloaded from {url}")
          else:
              extd_file = tmp_extd_dir

          # Move straight from the temporary tree to the final location so that
          # no unrelated entry of `save_dir` is ever touched on the way.
          shutil.move(extd_file, dst_path)