|
|
@@ -50,8 +50,11 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
|
|
|
self.global_config = config.Global
|
|
|
self.train_config = config.Train
|
|
|
self.benchmark_config = config.get("Benchmark", None)
|
|
|
+ config_path = self.train_config.get("basic_config_path", None)
|
|
|
|
|
|
- self.pdx_config, self.pdx_model = build_model(self.global_config.model)
|
|
|
+ self.pdx_config, self.pdx_model = build_model(
|
|
|
+ self.global_config.model, config_path=config_path
|
|
|
+ )
|
|
|
|
|
|
def train(self, *args, **kwargs):
|
|
|
"""execute model training"""
|