瀏覽代碼

initiate weigths before doing pruning in legacy_train

will-jl944 4 年之前
父節點
當前提交
9b956dac06
共有 3 個文件被更改,包括 44 次插入44 次删除
  1. 13 13
      dygraph/paddlex/cls.py
  2. 16 16
      dygraph/paddlex/det.py
  3. 15 15
      dygraph/paddlex/seg.py

+ 13 - 13
dygraph/paddlex/cls.py

@@ -1202,6 +1202,19 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
                   early_stop, early_stop_patience):
     model.labels = train_dataset.labels
 
+    # initiate weights
+    if pretrain_weights is not None and not osp.exists(pretrain_weights):
+        if pretrain_weights not in ['IMAGENET']:
+            logging.warning("Path of pretrain_weights('{}') does not exist!".
+                            format(pretrain_weights))
+            logging.warning("Pretrain_weights is forcibly set to 'IMAGENET'. "
+                            "If don't want to use pretrain weights, "
+                            "set pretrain_weights to be None.")
+            pretrain_weights = 'IMAGENET'
+    pretrained_dir = osp.join(save_dir, 'pretrain')
+    model.net_initialize(
+        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+
     if sensitivities_file is not None:
         dataset = eval_dataset or train_dataset
         inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
@@ -1223,19 +1236,6 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
     else:
         model.optimizer = optimizer
 
-    # initiate weights
-    if pretrain_weights is not None and not osp.exists(pretrain_weights):
-        if pretrain_weights not in ['IMAGENET']:
-            logging.warning("Path of pretrain_weights('{}') does not exist!".
-                            format(pretrain_weights))
-            logging.warning("Pretrain_weights is forcibly set to 'IMAGENET'. "
-                            "If don't want to use pretrain weights, "
-                            "set pretrain_weights to be None.")
-            pretrain_weights = 'IMAGENET'
-    pretrained_dir = osp.join(save_dir, 'pretrain')
-    model.net_initialize(
-        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
-
     model.train_loop(
         num_epochs=num_epochs,
         train_dataset=train_dataset,

+ 16 - 16
dygraph/paddlex/det.py

@@ -480,6 +480,22 @@ def _legacy_train(model,
     train_dataset.batch_transforms = model._compose_batch_transform(
         train_dataset.transforms, mode='train')
 
+    # initiate weights
+    if pretrain_weights is not None and not osp.exists(pretrain_weights):
+        if pretrain_weights not in det_pretrain_weights_dict['_'.join(
+            [model.model_name, model.backbone_name])]:
+            logging.warning("Path of pretrain_weights('{}') does not exist!".
+                            format(pretrain_weights))
+            pretrain_weights = det_pretrain_weights_dict['_'.join(
+                [model.model_name, model.backbone_name])][0]
+            logging.warning("Pretrain_weights is forcibly set to '{}'. "
+                            "If you don't want to use pretrain weights, "
+                            "set pretrain_weights to be None.".format(
+                                pretrain_weights))
+    pretrained_dir = osp.join(save_dir, 'pretrain')
+    model.net_initialize(
+        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+
     if sensitivities_file is not None:
         dataset = eval_dataset or train_dataset
         im_shape = dataset[0]['image'].shape[:2]
@@ -512,22 +528,6 @@ def _legacy_train(model,
     else:
         model.optimizer = optimizer
 
-    # initiate weights
-    if pretrain_weights is not None and not osp.exists(pretrain_weights):
-        if pretrain_weights not in det_pretrain_weights_dict['_'.join(
-            [model.model_name, model.backbone_name])]:
-            logging.warning("Path of pretrain_weights('{}') does not exist!".
-                            format(pretrain_weights))
-            pretrain_weights = det_pretrain_weights_dict['_'.join(
-                [model.model_name, model.backbone_name])][0]
-            logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                            "If you don't want to use pretrain weights, "
-                            "set pretrain_weights to be None.".format(
-                                pretrain_weights))
-    pretrained_dir = osp.join(save_dir, 'pretrain')
-    model.net_initialize(
-        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
-
     if use_ema:
         ema = ExponentialMovingAverage(
             decay=ema_decay, model=model.net, use_thres_step=True)

+ 15 - 15
dygraph/paddlex/seg.py

@@ -387,6 +387,21 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
     if model.losses is None:
         model.losses = model.default_loss()
 
+    # initiate weights
+    if pretrain_weights is not None and not osp.exists(pretrain_weights):
+        if pretrain_weights not in seg_pretrain_weights_dict[model.model_name]:
+            logging.warning("Path of pretrain_weights('{}') does not exist!".
+                            format(pretrain_weights))
+            logging.warning("Pretrain_weights is forcibly set to '{}'. "
+                            "If don't want to use pretrain weights, "
+                            "set pretrain_weights to be None.".format(
+                                seg_pretrain_weights_dict[model.model_name][
+                                    0]))
+            pretrain_weights = seg_pretrain_weights_dict[model.model_name][0]
+    pretrained_dir = osp.join(save_dir, 'pretrain')
+    model.net_initialize(
+        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+
     if sensitivities_file is not None:
         dataset = eval_dataset or train_dataset
         inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
@@ -403,21 +418,6 @@ def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
     else:
         model.optimizer = optimizer
 
-    # initiate weights
-    if pretrain_weights is not None and not osp.exists(pretrain_weights):
-        if pretrain_weights not in seg_pretrain_weights_dict[model.model_name]:
-            logging.warning("Path of pretrain_weights('{}') does not exist!".
-                            format(pretrain_weights))
-            logging.warning("Pretrain_weights is forcibly set to '{}'. "
-                            "If don't want to use pretrain weights, "
-                            "set pretrain_weights to be None.".format(
-                                seg_pretrain_weights_dict[model.model_name][
-                                    0]))
-            pretrain_weights = seg_pretrain_weights_dict[model.model_name][0]
-    pretrained_dir = osp.join(save_dir, 'pretrain')
-    model.net_initialize(
-        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
-
     model.train_loop(
         num_epochs=num_epochs,
         train_dataset=train_dataset,