|
|
@@ -1,5 +1,5 @@
|
|
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
|
|
-#
|
|
|
+#
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
# You may obtain a copy of the License at
|
|
|
@@ -48,11 +48,13 @@ class BaseTrainDeamon(ABC):
|
|
|
update_interval = 600
|
|
|
last_k = 5
|
|
|
|
|
|
- def __init__(self, global_config):
|
|
|
+ def __init__(self, config):
|
|
|
""" init """
|
|
|
- self.global_config = global_config
|
|
|
+ self.global_config = config.Global
|
|
|
+ self.disable_deamon = config.get("Benchmark", {}).get(
|
|
|
+ "disable_deamon", False)
|
|
|
self.init_pre_hook()
|
|
|
- self.output = global_config.output
|
|
|
+ self.output = self.global_config.output
|
|
|
self.train_outputs = self.get_train_outputs()
|
|
|
self.save_paths = self.get_save_paths()
|
|
|
self.results = self.init_train_result()
|
|
|
@@ -145,7 +147,8 @@ class BaseTrainDeamon(ABC):
|
|
|
self.exit = False
|
|
|
self.thread = threading.Thread(target=self.run)
|
|
|
self.thread.daemon = True
|
|
|
- self.thread.start()
|
|
|
+ if not self.disable_deamon:
|
|
|
+ self.thread.start()
|
|
|
|
|
|
def stop_hook(self):
|
|
|
""" hook befor stop """
|
|
|
@@ -332,10 +335,14 @@ class BaseTrainDeamon(ABC):
|
|
|
inference_config = export_save_dir.joinpath("inference.yml")
|
|
|
if not inference_config.exists():
|
|
|
inference_config = ""
|
|
|
- use_pir = hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
|
|
|
- pdmodel = export_save_dir.joinpath("inference.json") if use_pir else export_save_dir.joinpath("inference.pdmodel")
|
|
|
+ use_pir = hasattr(paddle.framework,
|
|
|
+ "use_pir_api") and paddle.framework.use_pir_api()
|
|
|
+ pdmodel = export_save_dir.joinpath(
|
|
|
+ "inference.json") if use_pir else export_save_dir.joinpath(
|
|
|
+ "inference.pdmodel")
|
|
|
pdiparams = export_save_dir.joinpath("inference.pdiparams")
|
|
|
- pdiparams_info = "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
|
|
|
+ pdiparams_info = "" if use_pir else export_save_dir.joinpath(
|
|
|
+ "inference.pdiparams.info")
|
|
|
else:
|
|
|
inference_config = ""
|
|
|
pdmodel = ""
|