Browse Source

support input config path

zhangyubo0722 1 year ago
parent
commit
093e570853

+ 3 - 1
paddlex/modules/base/evaluator.py

@@ -51,7 +51,9 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
         self.global_config = config.Global
         self.eval_config = config.Evaluate
 
-        config_path = self.get_config_path(self.eval_config.weight_path)
+        config_path = self.eval_config.get("basic_config_path", None)
+        if not config_path:
+            config_path = self.get_config_path(self.eval_config.weight_path)
 
         self.pdx_config, self.pdx_model = build_model(
             self.global_config.model, config_path=config_path

+ 3 - 1
paddlex/modules/base/exportor.py

@@ -51,7 +51,9 @@ class BaseExportor(ABC, metaclass=AutoRegisterABCMetaClass):
         self.global_config = config.Global
         self.export_config = config.Export
 
-        config_path = self.get_config_path(self.export_config.weight_path)
+        config_path = self.export_config.get("basic_config_path", None)
+        if not config_path:
+            config_path = self.get_config_path(self.export_config.weight_path)
 
         self.pdx_config, self.pdx_model = build_model(
             self.global_config.model, config_path=config_path

+ 4 - 1
paddlex/modules/base/trainer.py

@@ -50,8 +50,11 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
         self.global_config = config.Global
         self.train_config = config.Train
         self.benchmark_config = config.get("Benchmark", None)
+        config_path = self.train_config.get("basic_config_path", None)
 
-        self.pdx_config, self.pdx_model = build_model(self.global_config.model)
+        self.pdx_config, self.pdx_model = build_model(
+            self.global_config.model, config_path=config_path
+        )
 
     def train(self, *args, **kwargs):
         """execute model training"""