model_zoo.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import requests
import shutil
import tarfile
import tqdm
import zipfile

from paddlex.ppcls.arch import similar_architectures
from paddlex.ppcls.utils import logger

__all__ = ['get']

DOWNLOAD_RETRY_LIMIT = 3


class UrlError(Exception):
    """ UrlError
    """

    def __init__(self, url='', code=''):
        message = "Downloading from {} failed with code {}!".format(url, code)
        super(UrlError, self).__init__(message)


class ModelNameError(Exception):
    """ ModelNameError
    """

    def __init__(self, message=''):
        super(ModelNameError, self).__init__(message)


class RetryError(Exception):
    """ RetryError
    """

    def __init__(self, url='', times=''):
        message = "Download from {} failed. Retry({}) limit reached".format(
            url, times)
        super(RetryError, self).__init__(message)


def _get_url(architecture, postfix="pdparams"):
    prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/"
    fname = architecture + "_pretrained." + postfix
    return prefix + fname
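
# Hypothetical example ("ResNet50" is an assumed architecture name, not taken
# from this file): _get_url("ResNet50") returns
# "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams".
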
def _move_and_merge_tree(src, dst):
    """
    Move src directory to dst. If dst already exists,
    merge src into dst.
    """
    if not os.path.exists(dst):
        shutil.move(src, dst)
    elif os.path.isfile(src):
        shutil.move(src, dst)
    else:
        for fp in os.listdir(src):
            src_fp = os.path.join(src, fp)
            dst_fp = os.path.join(dst, fp)
            if os.path.isdir(src_fp):
                if os.path.isdir(dst_fp):
                    _move_and_merge_tree(src_fp, dst_fp)
                else:
                    shutil.move(src_fp, dst_fp)
            elif os.path.isfile(src_fp) and \
                    not os.path.isfile(dst_fp):
                shutil.move(src_fp, dst_fp)


def _download(url, path):
    """
    Download from url, save to path.

    url (str): download url
    path (str): download to given path
    """
    if not os.path.exists(path):
        os.makedirs(path)

    fname = os.path.split(url)[-1]
    fullname = os.path.join(path, fname)
    retry_cnt = 0

    while not os.path.exists(fullname):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RetryError(url, DOWNLOAD_RETRY_LIMIT)

        logger.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise UrlError(url, req.status_code)

        # To protect against interrupted downloads, write to tmp_fullname
        # first and move tmp_fullname to fullname once the download has
        # finished.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                for chunk in tqdm.tqdm(
                        req.iter_content(chunk_size=1024),
                        total=(int(total_size) + 1023) // 1024,
                        unit='KB'):
                    f.write(chunk)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname
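
# Hypothetical example (architecture name and target directory are assumptions):
# _download(_get_url("ResNet50"), "./pretrained") saves the weights to
# ./pretrained/ResNet50_pretrained.pdparams and returns that path.
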
def _decompress(fname):
    """
    Decompress zip and tar files.
    """
    logger.info("Decompressing {}...".format(fname))

    # To protect against interrupted decompression, decompress into the
    # fpath_tmp directory first; if decompression succeeds, move the
    # decompressed files to fpath, delete fpath_tmp and remove the
    # downloaded compressed file.
    fpath = os.path.split(fname)[0]
    fpath_tmp = os.path.join(fpath, 'tmp')
    if os.path.isdir(fpath_tmp):
        shutil.rmtree(fpath_tmp)
    os.makedirs(fpath_tmp)

    if fname.find('tar') >= 0:
        with tarfile.open(fname) as tf:

            def is_within_directory(directory, target):
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)
                prefix = os.path.commonprefix([abs_directory, abs_target])
                return prefix == abs_directory

            def safe_extract(tar, path=".", members=None, *,
                             numeric_owner=False):
                # Reject archive members that would escape the extraction
                # directory (path traversal) before extracting.
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")

                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(tf, path=fpath_tmp)
    elif fname.find('zip') >= 0:
        with zipfile.ZipFile(fname) as zf:
            zf.extractall(path=fpath_tmp)
    else:
        raise TypeError("Unsupported compress file type {}".format(fname))

    fs = os.listdir(fpath_tmp)
    assert len(
        fs
    ) == 1, "There should just be 1 pretrained path in an archive file but got {}.".format(
        len(fs))

    f = fs[0]
    src_dir = os.path.join(fpath_tmp, f)
    dst_dir = os.path.join(fpath, f)
    _move_and_merge_tree(src_dir, dst_dir)

    shutil.rmtree(fpath_tmp)
    os.remove(fname)

    return f


def _get_pretrained():
    with open('./ppcls/utils/pretrained.list') as flist:
        pretrained = [line.strip() for line in flist]

    return pretrained
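
# ./ppcls/utils/pretrained.list is read relative to the current working
# directory and is expected to hold one pretrained model name per line;
# _get_pretrained() simply strips and returns each line.
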
def _check_pretrained_name(architecture):
    assert isinstance(architecture, str), \
        ("the type of architecture({}) should be str".format(architecture))

    pretrained = _get_pretrained()
    similar_names = similar_architectures(architecture, pretrained)
    model_list = ', '.join(similar_names)
    err = "{} does not exist! Maybe you want: [{}]" \
        "".format(architecture, model_list)

    if architecture not in similar_names:
        raise ModelNameError(err)


def list_models():
    pretrained = _get_pretrained()
    msg = "All available pretrained models are as follows: {}".format(
        pretrained)
    logger.info(msg)
    return


def get(architecture, path, decompress=False, postfix="pdparams"):
    """
    Get the pretrained model.
    """
    _check_pretrained_name(architecture)
    url = _get_url(architecture, postfix=postfix)
    fname = _download(url, path)
    if postfix == "tar" and decompress:
        _decompress(fname)
    logger.info("Download of {} finished.".format(fname))
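

if __name__ == "__main__":
    # Minimal usage sketch. It assumes ./ppcls/utils/pretrained.list is
    # reachable from the current working directory and that "ResNet50" is one
    # of the listed names; both are assumptions, not guaranteed by this file.
    list_models()
    # Downloads ResNet50_pretrained.pdparams into ./pretrained/.
    get("ResNet50", "./pretrained")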