Pārlūkot izejas kodu

deal with res5 head weights explictly while loading pretrained weights

will-jl944 4 gadi atpakaļ
vecāks
revīzija
315ab5c6d9
1 mainītis faili ar 16 papildinājumiem un 8 dzēšanām
  1. 16 8
      dygraph/paddlex/utils/checkpoint.py

+ 16 - 8
dygraph/paddlex/utils/checkpoint.py

@@ -381,21 +381,29 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
             use_color=True)
 
         if os.path.exists(pretrain_weights):
-            para_state_dict = paddle.load(pretrain_weights)
+            param_state_dict = paddle.load(pretrain_weights)
             model_state_dict = model.state_dict()
-            keys = model_state_dict.keys()
+            # hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
+            # while res5 module is located in bbox_head.head. Replace the prefix of
+            # res5 with 'bbox_head.head' to load pretrain weights correctly.
+            for k in list(param_state_dict.keys()):
+                if 'backbone.res5' in k:
+                    new_k = k.replace('backbone', 'bbox_head.head')
+                    if new_k in model_state_dict:
+                        value = param_state_dict.pop(k)
+                        param_state_dict[new_k] = value
             num_params_loaded = 0
-            for k in keys:
-                if k not in para_state_dict:
+            for k in model_state_dict:
+                if k not in param_state_dict:
                     logging.warning("{} is not in pretrained model".format(k))
-                elif list(para_state_dict[k].shape) != list(model_state_dict[k]
-                                                            .shape):
+                elif list(param_state_dict[k].shape) != list(model_state_dict[
+                        k].shape):
                     logging.warning(
                         "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
-                        .format(k, para_state_dict[k].shape, model_state_dict[
+                        .format(k, param_state_dict[k].shape, model_state_dict[
                             k].shape))
                 else:
-                    model_state_dict[k] = para_state_dict[k]
+                    model_state_dict[k] = param_state_dict[k]
                     num_params_loaded += 1
             model.set_state_dict(model_state_dict)
             logging.info("There are {}/{} variables loaded into {}.".format(