Bläddra i källkod

support to disable train deamon for benchmark

gaotingquan 1 år sedan
förälder
incheckning
68f743c10d
2 ändrade filer med 18 tillägg och 10 borttagningar
  1. 15 8
      paddlex/modules/base/trainer/train_deamon.py
  2. 3 2
      paddlex/modules/base/trainer/trainer.py

+ 15 - 8
paddlex/modules/base/trainer/train_deamon.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -48,11 +48,13 @@ class BaseTrainDeamon(ABC):
     update_interval = 600
     last_k = 5
 
-    def __init__(self, global_config):
+    def __init__(self, config):
         """ init """
-        self.global_config = global_config
+        self.global_config = config.Global
+        self.disable_deamon = config.get("Benchmark", {}).get(
+            "disable_deamon", False)
         self.init_pre_hook()
-        self.output = global_config.output
+        self.output = self.global_config.output
         self.train_outputs = self.get_train_outputs()
         self.save_paths = self.get_save_paths()
         self.results = self.init_train_result()
@@ -145,7 +147,8 @@ class BaseTrainDeamon(ABC):
         self.exit = False
         self.thread = threading.Thread(target=self.run)
         self.thread.daemon = True
-        self.thread.start()
+        if not self.disable_deamon:
+            self.thread.start()
 
     def stop_hook(self):
         """ hook befor stop """
@@ -332,10 +335,14 @@ class BaseTrainDeamon(ABC):
             inference_config = export_save_dir.joinpath("inference.yml")
             if not inference_config.exists():
                 inference_config = ""
-            use_pir = hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
-            pdmodel = export_save_dir.joinpath("inference.json") if use_pir else export_save_dir.joinpath("inference.pdmodel")
+            use_pir = hasattr(paddle.framework,
+                              "use_pir_api") and paddle.framework.use_pir_api()
+            pdmodel = export_save_dir.joinpath(
+                "inference.json") if use_pir else export_save_dir.joinpath(
+                    "inference.pdmodel")
             pdiparams = export_save_dir.joinpath("inference.pdiparams")
-            pdiparams_info = "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info") 
+            pdiparams_info = "" if use_pir else export_save_dir.joinpath(
+                "inference.pdiparams.info")
         else:
             inference_config = ""
             pdmodel = ""

+ 3 - 2
paddlex/modules/base/trainer/trainer.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -45,11 +45,12 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
             config (AttrDict):  PaddleX pipeline config, which is loaded from pipeline yaml file.
         """
         super().__init__()
+        self.config = config
         self.global_config = config.Global
         self.train_config = config.Train
         self.benchmark_config = config.get('Benchmark', None)
 
-        self.deamon = self.build_deamon(self.global_config)
+        self.deamon = self.build_deamon(self.config)
         self.pdx_config, self.pdx_model = build_model(self.global_config.model)
 
     def train(self, *args, **kwargs):