will-jl944 4 years ago
parent
current commit
62a1b636ec

+ 0 - 1
dygraph/paddlex/cv/__init__.py

@@ -13,6 +13,5 @@
 # limitations under the License.
 
 from . import models
-from . import nets
 from . import transforms
 from . import datasets

+ 8 - 1
dygraph/paddlex/cv/models/base.py

@@ -22,7 +22,6 @@ import yaml
 import json
 import paddle
 from paddle.io import DataLoader, DistributedBatchSampler
-from paddle.jit import to_static
 from paddleslim.analysis import flops
 from paddleslim import L1NormFilterPruner, FPGMFilterPruner
 import paddlex
@@ -193,6 +192,7 @@ class BaseModel:
                    save_interval_epochs=1,
                    log_interval_steps=10,
                    save_dir='output',
+                   ema=None,
                    early_stop=False,
                    early_stop_patience=5,
                    use_vdl=True):
@@ -268,6 +268,8 @@ class BaseModel:
 
                 train_avg_metrics.update(outputs)
                 outputs['lr'] = lr
+                if ema is not None:
+                    ema.update(self.net)
                 step_time_toc = time.time()
                 train_step_time.update(step_time_toc - step_time_tic)
                 step_time_tic = step_time_toc
@@ -305,6 +307,9 @@ class BaseModel:
             self.completed_epochs += 1
 
             # Every save_interval_epochs epochs, evaluate on the validation set and save the model
+            if ema is not None:
+                weight = self.net.state_dict()
+                self.net.set_dict(ema.apply())
             eval_epoch_tic = time.time()
             if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
@@ -337,6 +342,8 @@ class BaseModel:
                     if eval_dataset is not None and early_stop:
                         if earlystop(current_accuracy):
                             break
+            if ema is not None:
+                self.net.set_dict(weight)
 
     def analyze_sensitivity(self,
                             dataset,
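
In short, the train_loop changes update the EMA shadow after every optimizer step, temporarily swap the smoothed weights in before evaluation/saving, and restore the raw weights afterwards so training itself is unaffected. A minimal sketch of that swap pattern (evaluate_fn and the deep copy of the raw state are illustrative additions, not part of this diff):

    import copy
    import paddle

    def evaluate_with_ema(net, ema, evaluate_fn):
        # Keep the raw training weights so they can be restored after evaluation.
        raw_weights = copy.deepcopy(net.state_dict())
        net.set_dict(ema.apply())      # evaluate/save with the bias-corrected shadow weights
        results = evaluate_fn(net)
        net.set_dict(raw_weights)      # put the raw weights back before the next epoch
        return results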

+ 9 - 0
dygraph/paddlex/cv/models/detector.py

@@ -31,6 +31,7 @@ from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResiz
 from paddlex.cv.transforms import arrange_transforms
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric
+from .utils.ema import ExponentialMovingAverage
 from paddlex.utils.checkpoint import det_pretrain_weights_dict
 
 __all__ = [
@@ -132,6 +133,7 @@ class BaseDetector(BaseModel):
               lr_decay_epochs=(216, 243),
               lr_decay_gamma=0.1,
               metric=None,
+              use_ema=False,
               early_stop=False,
               early_stop_patience=5,
               use_vdl=True):
@@ -157,6 +159,7 @@ class BaseDetector(BaseModel):
             lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
             metric({'VOC', 'COCO', None}, optional):
                 Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
+            use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
             early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
             early_stop_patience(int, optional): Early stop patience. Defaults to 5.
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
@@ -224,6 +227,11 @@ class BaseDetector(BaseModel):
         self.net_initialize(
             pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
 
+        if use_ema:
+            ema = ExponentialMovingAverage(
+                decay=.9998, model=self.net, use_thres_step=True)
+        else:
+            ema = None
         # start train loop
         self.train_loop(
             num_epochs=num_epochs,
@@ -233,6 +241,7 @@ class BaseDetector(BaseModel):
             save_interval_epochs=save_interval_epochs,
             log_interval_steps=log_interval_steps,
             save_dir=save_dir,
+            ema=ema,
             early_stop=early_stop,
             early_stop_patience=early_stop_patience,
             use_vdl=use_vdl)
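
With this change a caller only needs to pass use_ema=True to train(); the detector then builds an ExponentialMovingAverage with decay 0.9998 and threshold-based warm-up and forwards it to train_loop. A hedged usage sketch, assuming the public paddlex.det / paddlex.datasets / paddlex.transforms aliases and placeholder dataset paths:

    import paddlex as pdx
    from paddlex import transforms as T

    # Placeholder transforms and dataset paths for illustration only.
    train_transforms = T.Compose([T.Resize(target_size=608), T.Normalize()])
    train_dataset = pdx.datasets.VOCDetection(
        data_dir='dataset',
        file_list='dataset/train_list.txt',
        label_list='dataset/labels.txt',
        transforms=train_transforms)

    model = pdx.det.YOLOv3(num_classes=len(train_dataset.labels))
    model.train(
        num_epochs=270,
        train_dataset=train_dataset,
        train_batch_size=8,
        save_dir='output/yolov3_ema',
        use_ema=True)   # new flag from this commit: maintain EMA weights during training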

+ 48 - 0
dygraph/paddlex/cv/models/utils/ema.py

@@ -0,0 +1,48 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+
+class ExponentialMovingAverage(object):
+    def __init__(self, decay, model, use_thres_step=False):
+        self.step = 0
+        self.decay = decay
+        self.shadow = dict()
+        for k, v in model.state_dict().items():
+            self.shadow[k] = paddle.zeros_like(v)
+        self.use_thres_step = use_thres_step
+
+    def update(self, model):
+        if self.use_thres_step:
+            decay = min(self.decay, (1 + self.step) / (10 + self.step))
+        else:
+            decay = self.decay
+        self._decay = decay
+        model_dict = model.state_dict()
+        for k, v in self.shadow.items():
+            v = decay * v + (1 - decay) * model_dict[k]
+            v.stop_gradient = True
+            self.shadow[k] = v
+        self.step += 1
+
+    def apply(self):
+        if self.step == 0:
+            return self.shadow
+        state_dict = dict()
+        for k, v in self.shadow.items():
+            v = v / (1 - self._decay**self.step)
+            v.stop_gradient = True
+            state_dict[k] = v
+        return state_dict
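
The shadow update follows shadow_t = d * shadow_{t-1} + (1 - d) * w_t, and apply() divides by (1 - d**step) to correct the zero-initialization bias; with use_thres_step=True the first update uses d = min(0.9998, 1/10) = 0.1, so the shadow is 0.9 * w and apply() returns 0.9 * w / (1 - 0.1) = w exactly. A tiny standalone sanity check (the Linear layer is illustrative, and the import path assumes this dygraph package is installed as paddlex):

    import paddle
    from paddlex.cv.models.utils.ema import ExponentialMovingAverage

    net = paddle.nn.Linear(2, 2)
    ema = ExponentialMovingAverage(decay=0.9998, model=net, use_thres_step=True)

    for step in range(5):
        # ... an optimizer step on `net` would normally go here ...
        ema.update(net)          # warm-up decay: min(0.9998, (1 + step) / (10 + step))

    smoothed = ema.apply()       # bias-corrected shadow state_dict
    net.set_dict(smoothed)       # evaluate or export with the smoothed parameters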