|
|
@@ -34,7 +34,10 @@ class TSADTrainer(BaseTrainer):
|
|
|
os.makedirs(self.global_config.output, exist_ok=True)
|
|
|
self.update_config()
|
|
|
self.dump_config()
|
|
|
- train_result = self.pdx_model.train(**self.get_train_kwargs())
|
|
|
+ train_args = self.get_train_kwargs()
|
|
|
+ if self.benchmark_config is not None:
|
|
|
+ train_args.update({"benchmark": self.benchmark_config})
|
|
|
+ train_result = self.pdx_model.train(**train_args)
|
|
|
assert (
|
|
|
train_result.returncode == 0
|
|
|
), f"Encountered an unexpected error({train_result.returncode}) in \
|
|
|
@@ -80,6 +83,8 @@ training!"
|
|
|
self.pdx_config.update_learning_rate(self.train_config.learning_rate)
|
|
|
if self.train_config.epochs_iters is not None:
|
|
|
self.pdx_config.update_epochs(self.train_config.epochs_iters)
|
|
|
+ if self.train_config.get("dy2st", False):
|
|
|
+ self.pdx_config.update_to_static(self.train_config.dy2st)
|
|
|
if self.train_config.log_interval is not None:
|
|
|
self.pdx_config.update_log_interval(self.train_config.log_interval)
|
|
|
if self.global_config.output is not None:
|