| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import requests
- import shutil
- import tarfile
- import tqdm
- import zipfile
- from paddlex.ppcls.arch import similar_architectures
- from paddlex.ppcls.utils import logger
- __all__ = ['get']
- DOWNLOAD_RETRY_LIMIT = 3
class UrlError(Exception):
    """Raised when a download URL answers with an unexpected status code."""

    def __init__(self, url='', code=''):
        # Named fields keep the template readable; the rendered text is
        # what callers see as str(exception).
        template = "Downloading from {url} failed with code {code}!"
        super().__init__(template.format(url=url, code=code))
class ModelNameError(Exception):
    """Raised when a requested architecture name is not recognised."""

    def __init__(self, message=''):
        # The message is forwarded unchanged to the base Exception.
        super().__init__(message)
class RetryError(Exception):
    """Raised when a download has used up its allowed number of attempts."""

    def __init__(self, url='', times=''):
        template = "Download from {url} failed. Retry({times}) limit reached"
        super().__init__(template.format(url=url, times=times))
- def _get_url(architecture, postfix="pdparams"):
- prefix = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/"
- fname = architecture + "_pretrained." + postfix
- return prefix + fname
- def _move_and_merge_tree(src, dst):
- """
- Move src directory to dst, if dst is already exists,
- merge src to dst
- """
- if not os.path.exists(dst):
- shutil.move(src, dst)
- elif os.path.isfile(src):
- shutil.move(src, dst)
- else:
- for fp in os.listdir(src):
- src_fp = os.path.join(src, fp)
- dst_fp = os.path.join(dst, fp)
- if os.path.isdir(src_fp):
- if os.path.isdir(dst_fp):
- _move_and_merge_tree(src_fp, dst_fp)
- else:
- shutil.move(src_fp, dst_fp)
- elif os.path.isfile(src_fp) and \
- not os.path.isfile(dst_fp):
- shutil.move(src_fp, dst_fp)
def _download(url, path):
    """
    Download from url, save to path.

    url (str): download url
    path (str): directory to download into; created if missing

    Returns the full path of the downloaded file. If the file is already
    present at that path, nothing is downloaded.

    Raises UrlError when the server answers with a status other than 200,
    and RetryError when the loop has already run DOWNLOAD_RETRY_LIMIT times
    and the target file still does not exist.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)

    fname = os.path.split(url)[-1]
    fullname = os.path.join(path, fname)
    # To protect against an interrupted download, write to tmp_fullname
    # first and move it to fullname only after the download has finished.
    tmp_fullname = fullname + "_tmp"

    retry_cnt = 0
    while not os.path.exists(fullname):
        if retry_cnt >= DOWNLOAD_RETRY_LIMIT:
            raise RetryError(url, DOWNLOAD_RETRY_LIMIT)
        retry_cnt += 1

        logger.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        try:
            if req.status_code != 200:
                raise UrlError(url, req.status_code)

            total_size = req.headers.get('content-length')
            chunks = req.iter_content(chunk_size=1024)
            if total_size:
                # Show a progress bar only when the server reports a size.
                chunks = tqdm.tqdm(
                    chunks,
                    total=(int(total_size) + 1023) // 1024,
                    unit='KB')

            with open(tmp_fullname, 'wb') as f:
                for chunk in chunks:
                    if chunk:
                        f.write(chunk)
        finally:
            # A streamed response holds its connection until it is closed;
            # release it on success and on every failure path.
            req.close()

        shutil.move(tmp_fullname, fullname)

    return fullname
def _is_within_directory(directory, target):
    """Return True when target resolves to directory or a path beneath it."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        # commonpath compares whole path components; commonprefix compares
        # characters and would accept ".../tmp_evil" as inside ".../tmp".
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Raised e.g. for paths on different drives: not inside directory.
        return False


def _safe_extract_tar(tar, path):
    """Extract an open tar archive into path, rejecting escaping members."""
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not _is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")
    tar.extractall(path)


def _decompress(fname):
    """
    Decompress a zip or tar file next to where it is stored.

    fname (str): path of the archive file

    The archive must contain exactly one top-level entry. That entry is
    merged into the archive's directory, after which the temporary
    extraction directory and the archive itself are removed.

    Returns the name of the single top-level entry.

    Raises TypeError for a file that is neither a tar nor a zip archive,
    AssertionError when the archive does not hold exactly one top-level
    entry, and Exception when a tar member would be written outside the
    extraction directory.
    """
    logger.info("Decompressing {}...".format(fname))

    # To protect against an interrupted decompression, extract into the
    # fpath_tmp directory first; once that succeeded, move the extracted
    # files to fpath, then delete fpath_tmp and the downloaded archive.
    fpath = os.path.split(fname)[0]
    fpath_tmp = os.path.join(fpath, 'tmp')
    if os.path.isdir(fpath_tmp):
        shutil.rmtree(fpath_tmp)
    os.makedirs(fpath_tmp)

    # Detect the format from the file content rather than from substrings
    # of the path, which also match directory or model names.
    if tarfile.is_tarfile(fname):
        with tarfile.open(fname) as tf:
            _safe_extract_tar(tf, fpath_tmp)
    elif zipfile.is_zipfile(fname):
        with zipfile.ZipFile(fname) as zf:
            zf.extractall(path=fpath_tmp)
    else:
        raise TypeError("Unsupport compress file type {}".format(fname))

    fs = os.listdir(fpath_tmp)
    if len(fs) != 1:
        # Raised explicitly (instead of `assert`) so the check also runs
        # when Python is started with -O.
        raise AssertionError(
            "There should just be 1 pretrained path in an archive file but got {}.".
            format(len(fs)))

    f = fs[0]
    src_dir = os.path.join(fpath_tmp, f)
    dst_dir = os.path.join(fpath, f)
    _move_and_merge_tree(src_dir, dst_dir)

    shutil.rmtree(fpath_tmp)
    os.remove(fname)

    return f
- def _get_pretrained():
- with open('./ppcls/utils/pretrained.list') as flist:
- pretrained = [line.strip() for line in flist]
- return pretrained
- def _check_pretrained_name(architecture):
- assert isinstance(architecture, str), \
- ("the type of architecture({}) should be str". format(architecture))
- pretrained = _get_pretrained()
- similar_names = similar_architectures(architecture, pretrained)
- model_list = ', '.join(similar_names)
- err = "{} is not exist! Maybe you want: [{}]" \
- "".format(architecture, model_list)
- if architecture not in similar_names:
- raise ModelNameError(err)
def list_models():
    """Log the names read from the pretrained model list file."""
    logger.info("All avialable pretrained models are as follows: {}".format(
        _get_pretrained()))
def get(architecture, path, decompress=False, postfix="pdparams"):
    """
    Get the pretrained model.

    architecture (str): model name; validated before anything is downloaded
    path (str): directory the file is downloaded into
    decompress (bool): unpack the download; only honoured for postfix "tar"
    postfix (str): file extension of the pretrained file to fetch
    """
    _check_pretrained_name(architecture)

    fname = _download(_get_url(architecture, postfix=postfix), path)

    wants_unpack = postfix == "tar" and decompress
    if wants_unpack:
        _decompress(fname)

    logger.info("download {} finished ".format(fname))
|