소스 검색

fix errors

FlyingQianMM 5 년 전
부모
커밋
fb0dd18a51
3개의 변경된 파일4개의 추가작업 그리고 6개의 파일을 삭제
  1. 1 2
      paddlex/cv/models/faster_rcnn.py
  2. 1 2
      paddlex/cv/models/mask_rcnn.py
  3. 2 2
      paddlex/utils/utils.py

+ 1 - 2
paddlex/cv/models/faster_rcnn.py

@@ -232,9 +232,8 @@ class FasterRCNN(BaseAPI):
         self.net_initialize(
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
+            fuse_bn=fuse_bn,
             save_dir=save_dir,
-            sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss,
             resume_checkpoint=resume_checkpoint)
         start_epoch = 0
         if resume_checkpoint:

+ 1 - 2
paddlex/cv/models/mask_rcnn.py

@@ -199,9 +199,8 @@ class MaskRCNN(FasterRCNN):
         self.net_initialize(
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
+            fuse_bn=fuse_bn,
             save_dir=save_dir,
-            sensitivities_file=sensitivities_file,
-            eval_metric_loss=eval_metric_loss,
             resume_checkpoint=resume_checkpoint)
         start_epoch = 0
         if resume_checkpoint:

+ 2 - 2
paddlex/utils/utils.py

@@ -300,8 +300,8 @@ def load_pretrain_weights(exe,
                 if pretrained_shape != actual_shape:
                     raise Exception(
                         "Shape of optimizer variable {} doesn't match.(Last: {}, Now: {}), {}"
-                        .format(var.name, opt_dict[var.name].shape,
-                                var.shape), exception_message)
+                        .format(var.name, pretrained_shape,
+                                actual_shape), exception_message)
             optimizer_varname_list = [var.name for var in optimizer_var_list]
             if os.exists(osp.join(weights_dir, 'learning_rate')
                          ) and 'learning_rate' not in optimizer_varname_list: