# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
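
'''Utility helpers for PaddleSeg training: a temporary-directory context
manager, pretrained-weight loading, and checkpoint-resume functions.'''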

import contextlib
import math
import os
import tempfile
from urllib.parse import urlparse, unquote

import filelock
import paddle

from paddlex.paddleseg.utils import logger, seg_env
from paddlex.paddleseg.utils.download import download_file_and_uncompress


@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
    '''Generate a temporary directory'''
    directory = seg_env.TMP_HOME if not directory else directory
    with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
        yield _dir


def load_entire_model(model, pretrained):
    '''Load all pretrained parameters into ``model`` if a source is given.'''
    if pretrained is not None:
        load_pretrained_model(model, pretrained)
    else:
        logger.warning(
            'No entire-model pretrained weights were provided for {}; it will '
            'be trained from scratch or rely only on its pretrained backbone.'.
            format(model.__class__.__name__))


def load_pretrained_model(model, pretrained_model):
    '''Load pretrained weights into ``model`` from a local ``.pdparams`` file
    or a downloadable archive, skipping parameters whose names or shapes do
    not match.'''
    if pretrained_model is not None:
        logger.info('Loading pretrained model from {}'.format(
            pretrained_model))

        # Download and uncompress the pretrained model if a URL is given.
        if urlparse(pretrained_model).netloc:
            pretrained_model = unquote(pretrained_model)
            savename = pretrained_model.split('/')[-1]
            if not savename.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
                # Not an archive name; fall back to the parent path segment.
                savename = pretrained_model.split('/')[-2]
            else:
                savename = savename.split('.')[0]

            with generate_tempdir() as _dir:
                with filelock.FileLock(
                        os.path.join(seg_env.TMP_HOME, savename)):
                    pretrained_model = download_file_and_uncompress(
                        pretrained_model,
                        savepath=_dir,
                        extrapath=seg_env.PRETRAINED_MODEL_HOME,
                        extraname=savename)
                    pretrained_model = os.path.join(pretrained_model,
                                                    'model.pdparams')

        if os.path.exists(pretrained_model):
            para_state_dict = paddle.load(pretrained_model)
            model_state_dict = model.state_dict()
            keys = model_state_dict.keys()
            num_params_loaded = 0
            for k in keys:
                if k not in para_state_dict:
                    logger.warning("{} is not in pretrained model".format(k))
                elif list(para_state_dict[k].shape) != list(
                        model_state_dict[k].shape):
                    logger.warning(
                        "[SKIP] Shape of pretrained params {} doesn't match. "
                        "(Pretrained: {}, Actual: {})".format(
                            k, para_state_dict[k].shape,
                            model_state_dict[k].shape))
                else:
                    model_state_dict[k] = para_state_dict[k]
                    num_params_loaded += 1
            model.set_dict(model_state_dict)
            logger.info("There are {}/{} variables loaded into {}.".format(
                num_params_loaded,
                len(model_state_dict), model.__class__.__name__))
        else:
            raise ValueError(
                'The pretrained model directory is not found: {}'.format(
                    pretrained_model))
    else:
        logger.info(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))


def resume(model, optimizer, resume_model):
    '''Restore model and optimizer states from a checkpoint directory and
    return the iteration number to resume from.'''
    if resume_model is not None:
        logger.info('Resume model from {}'.format(resume_model))
        if os.path.exists(resume_model):
            resume_model = os.path.normpath(resume_model)
            ckpt_path = os.path.join(resume_model, 'model.pdparams')
            para_state_dict = paddle.load(ckpt_path)
            ckpt_path = os.path.join(resume_model, 'model.pdopt')
            opti_state_dict = paddle.load(ckpt_path)
            model.set_state_dict(para_state_dict)
            optimizer.set_state_dict(opti_state_dict)

            # The checkpoint directory name is expected to end with the
            # iteration number (e.g. ``iter_1000``).
            resume_iter = int(resume_model.split('_')[-1])
            return resume_iter
        else:
            raise ValueError(
                'Directory of the model needed to resume is not found: {}'.
                format(resume_model))
    else:
        logger.info('No model needed to resume.')
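

# Usage sketch (illustrative): how these helpers are typically wired into a
# training script. ``ExampleModel``, ``example_optimizer`` and the paths below
# are hypothetical placeholders, not part of this API.
#
#     model = ExampleModel()
#     load_pretrained_model(model, 'https://example.com/model.tar.gz')
#     start_iter = resume(model, example_optimizer, 'output/iter_1000')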