|
|
@@ -31,6 +31,7 @@ from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResiz
|
|
|
from paddlex.cv.transforms import arrange_transforms
|
|
|
from .base import BaseModel
|
|
|
from .utils.det_metrics import VOCMetric, COCOMetric
|
|
|
+from .utils.ema import ExponentialMovingAverage
|
|
|
from paddlex.utils.checkpoint import det_pretrain_weights_dict
|
|
|
|
|
|
__all__ = [
|
|
|
@@ -132,6 +133,7 @@ class BaseDetector(BaseModel):
|
|
|
lr_decay_epochs=(216, 243),
|
|
|
lr_decay_gamma=0.1,
|
|
|
metric=None,
|
|
|
+ use_ema=False,
|
|
|
early_stop=False,
|
|
|
early_stop_patience=5,
|
|
|
use_vdl=True):
|
|
|
@@ -157,6 +159,7 @@ class BaseDetector(BaseModel):
|
|
|
lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
|
|
|
metric({'VOC', 'COCO', None}, optional):
|
|
|
Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
|
|
|
+ use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
|
|
|
early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
|
|
|
early_stop_patience(int, optional): Early stop patience. Defaults to 5.
|
|
|
use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
|
|
|
@@ -224,6 +227,11 @@ class BaseDetector(BaseModel):
|
|
|
self.net_initialize(
|
|
|
pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
|
|
|
|
|
|
+ if use_ema:
|
|
|
+ ema = ExponentialMovingAverage(
|
|
|
+ decay=.9998, model=self.net, use_thres_step=True)
|
|
|
+ else:
|
|
|
+ ema = None
|
|
|
# start train loop
|
|
|
self.train_loop(
|
|
|
num_epochs=num_epochs,
|
|
|
@@ -233,6 +241,7 @@ class BaseDetector(BaseModel):
|
|
|
save_interval_epochs=save_interval_epochs,
|
|
|
log_interval_steps=log_interval_steps,
|
|
|
save_dir=save_dir,
|
|
|
+ ema=ema,
|
|
|
early_stop=early_stop,
|
|
|
early_stop_patience=early_stop_patience,
|
|
|
use_vdl=use_vdl)
|