@@ -14,7 +14,7 @@
 
 import json
 import shutil
-import paddle
+import lazy_paddle as paddle
 from pathlib import Path
 
 from ..base import BaseTrainer, BaseTrainDeamon
@@ -48,34 +48,32 @@ class MLClsTrainer(BaseTrainer):
         return ClsTrainDeamon(config)
 
     def update_config(self):
-        """update training config
-        """
+        """update training config"""
         if self.train_config.log_interval:
             self.pdx_config.update_log_interval(self.train_config.log_interval)
         if self.train_config.eval_interval:
-            self.pdx_config.update_eval_interval(
-                self.train_config.eval_interval)
+            self.pdx_config.update_eval_interval(self.train_config.eval_interval)
         if self.train_config.save_interval:
-            self.pdx_config.update_save_interval(
-                self.train_config.save_interval)
+            self.pdx_config.update_save_interval(self.train_config.save_interval)
 
-        self.pdx_config.update_dataset(self.global_config.dataset_dir,
-                                       "MLClsDataset")
+        self.pdx_config.update_dataset(self.global_config.dataset_dir, "MLClsDataset")
         if self.train_config.num_classes is not None:
             self.pdx_config.update_num_classes(self.train_config.num_classes)
-        if self.train_config.pretrain_weight_path and self.train_config.pretrain_weight_path != "":
+        if (
+            self.train_config.pretrain_weight_path
+            and self.train_config.pretrain_weight_path != ""
+        ):
             self.pdx_config.update_pretrained_weights(
-                self.train_config.pretrain_weight_path)
+                self.train_config.pretrain_weight_path
+            )
 
-        label_dict_path = Path(self.global_config.dataset_dir).joinpath(
-            "label.txt")
+        label_dict_path = Path(self.global_config.dataset_dir).joinpath("label.txt")
         if label_dict_path.exists():
             self.dump_label_dict(label_dict_path)
         if self.train_config.batch_size is not None:
             self.pdx_config.update_batch_size(self.train_config.batch_size)
         if self.train_config.learning_rate is not None:
-            self.pdx_config.update_learning_rate(
-                self.train_config.learning_rate)
+            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
         if self.train_config.epochs_iters is not None:
             self.pdx_config._update_epochs(self.train_config.epochs_iters)
         if self.train_config.warmup_steps is not None: