download.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import shutil
  15. import tarfile
  16. import tempfile
  17. import zipfile
  18. from pathlib import Path
  19. from urllib.request import urlopen
  20. def _download(url, save_path):
  21. with urlopen(url) as r:
  22. with open(save_path, "wb") as file:
  23. shutil.copyfileobj(r, file)
  24. def _extract_zip_file(file_path, extd_dir):
  25. with zipfile.ZipFile(file_path, "r") as f:
  26. file_list = f.namelist()
  27. for file in file_list:
  28. f.extract(file, extd_dir)
  29. def _extract_tar_file(file_path, extd_dir):
  30. with tarfile.open(file_path, "r:*") as f:
  31. file_list = f.getnames()
  32. for file in file_list:
  33. f.extract(file, extd_dir)
  34. def _extract(file_path, extd_dir):
  35. if zipfile.is_zipfile(file_path):
  36. handler = _extract_zip_file
  37. elif tarfile.is_tarfile(file_path):
  38. handler = _extract_tar_file
  39. else:
  40. raise ValueError("Unsupported file format")
  41. handler(file_path, extd_dir)
  42. def _remove_if_exists(path):
  43. if path.exists():
  44. if path.is_dir():
  45. shutil.rmtree(path)
  46. else:
  47. path.unlink()
  48. def download(url, save_path, overwrite=False):
  49. save_path.parent.mkdir(exist_ok=True)
  50. if overwrite:
  51. _remove_if_exists(save_path)
  52. if not save_path.exists():
  53. _download(url, save_path)
  54. def extract(file_path, extd_dir):
  55. return _extract(file_path, extd_dir)
  56. def download_and_extract(url, save_dir, dst_name, overwrite=False, no_interm_dir=True):
  57. save_dir = Path(save_dir)
  58. save_dir.mkdir(exist_ok=True)
  59. dst_path = save_dir / dst_name
  60. if overwrite:
  61. _remove_if_exists(dst_path)
  62. if not dst_path.exists():
  63. with tempfile.TemporaryDirectory() as td:
  64. td = Path(td)
  65. arc_file_path = td / url.split("/")[-1]
  66. extd_dir = arc_file_path.stem
  67. _download(url, arc_file_path)
  68. tmp_extd_dir = td / "extracted"
  69. _extract(arc_file_path, tmp_extd_dir)
  70. if no_interm_dir:
  71. paths = list(tmp_extd_dir.iterdir())
  72. if len(paths) == 1:
  73. sp = paths[0]
  74. else:
  75. sp = tmp_extd_dir / dst_name
  76. if not sp.exists():
  77. raise FileNotFoundError
  78. dp = save_dir / sp.name
  79. if sp.is_dir():
  80. shutil.copytree(sp, dp)
  81. else:
  82. shutil.copyfile(sp, dp)
  83. extd_file = dp
  84. else:
  85. shutil.copytree(tmp_extd_dir, extd_dir)
  86. extd_file = extd_dir
  87. if not dst_path.exists() or not extd_file.samefile(dst_path):
  88. shutil.move(extd_file, dst_path)