|
@@ -381,21 +381,29 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
|
|
|
use_color=True)
|
|
use_color=True)
|
|
|
|
|
|
|
|
if os.path.exists(pretrain_weights):
|
|
if os.path.exists(pretrain_weights):
|
|
|
- para_state_dict = paddle.load(pretrain_weights)
|
|
|
|
|
|
|
+ param_state_dict = paddle.load(pretrain_weights)
|
|
|
model_state_dict = model.state_dict()
|
|
model_state_dict = model.state_dict()
|
|
|
- keys = model_state_dict.keys()
|
|
|
|
|
|
|
+ # hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
|
|
|
|
|
+ # while res5 module is located in bbox_head.head. Replace the prefix of
|
|
|
|
|
+ # res5 with 'bbox_head.head' to load pretrain weights correctly.
|
|
|
|
|
+ for k in list(param_state_dict.keys()):
|
|
|
|
|
+ if 'backbone.res5' in k:
|
|
|
|
|
+ new_k = k.replace('backbone', 'bbox_head.head')
|
|
|
|
|
+ if new_k in model_state_dict:
|
|
|
|
|
+ value = param_state_dict.pop(k)
|
|
|
|
|
+ param_state_dict[new_k] = value
|
|
|
num_params_loaded = 0
|
|
num_params_loaded = 0
|
|
|
- for k in keys:
|
|
|
|
|
- if k not in para_state_dict:
|
|
|
|
|
|
|
+ for k in model_state_dict:
|
|
|
|
|
+ if k not in param_state_dict:
|
|
|
logging.warning("{} is not in pretrained model".format(k))
|
|
logging.warning("{} is not in pretrained model".format(k))
|
|
|
- elif list(para_state_dict[k].shape) != list(model_state_dict[k]
|
|
|
|
|
- .shape):
|
|
|
|
|
|
|
+ elif list(param_state_dict[k].shape) != list(model_state_dict[
|
|
|
|
|
+ k].shape):
|
|
|
logging.warning(
|
|
logging.warning(
|
|
|
"[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
|
|
"[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
|
|
|
- .format(k, para_state_dict[k].shape, model_state_dict[
|
|
|
|
|
|
|
+ .format(k, param_state_dict[k].shape, model_state_dict[
|
|
|
k].shape))
|
|
k].shape))
|
|
|
else:
|
|
else:
|
|
|
- model_state_dict[k] = para_state_dict[k]
|
|
|
|
|
|
|
+ model_state_dict[k] = param_state_dict[k]
|
|
|
num_params_loaded += 1
|
|
num_params_loaded += 1
|
|
|
model.set_state_dict(model_state_dict)
|
|
model.set_state_dict(model_state_dict)
|
|
|
logging.info("There are {}/{} variables loaded into {}.".format(
|
|
logging.info("There are {}/{} variables loaded into {}.".format(
|