will-jl944 4 лет назад
Родитель
Сommit
5e5417a490
2 измененных файлов с 6 добавлено и 3 удалено
  1. 5 2
      paddlex/cv/models/detector.py
  2. 1 1
      paddlex/utils/checkpoint.py

+ 5 - 2
paddlex/cv/models/detector.py

@@ -270,7 +270,9 @@ class BaseDetector(BaseModel):
         self.net_initialize(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
-            resume_checkpoint=resume_checkpoint)
+            resume_checkpoint=resume_checkpoint,
+            is_backbone_weights=(pretrain_weights == 'IMAGENET' and
+                                 'ESNet_' in self.backbone_name))
 
         if use_ema:
             ema = ExponentialMovingAverage(
@@ -698,11 +700,12 @@ class PicoDet(BaseDetector):
                 num_classes=num_classes,
                 fpn_stride=[8, 16, 32, 64],
                 prior_prob=0.01,
+                reg_max=7,
+                cell_offset=.5,
                 loss_class=loss_class,
                 loss_dfl=loss_dfl,
                 loss_bbox=loss_bbox,
                 assigner=assigner,
-                reg_max=7,
                 feat_in_chan=neck_out_channels,
                 nms=nms)
             params.update({

+ 1 - 1
paddlex/utils/checkpoint.py

@@ -418,7 +418,7 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
             # hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
             # while res5 module is located in bbox_head.head. Replace the prefix of
             # res5 with 'bbox_head.head' to load pretrain weights correctly.
-            for k in list(param_state_dict.keys()):
+            for k in param_state_dict.keys():
                 if 'backbone.res5' in k:
                     new_k = k.replace('backbone', 'bbox_head.head')
                     if new_k in model_state_dict: