gaotingquan 1 vuosi sitten
vanhempi
commit
272c3bf977

+ 10 - 4
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import json
 import time
@@ -43,9 +42,16 @@ class TSADTrainer(BaseTrainer):
     def train(self):
         """firstly, update and dump train config, then train model
         """
-        rtn = super().train()
+        # XXX: using super().train() instead when the train_hook() is supported.
+        os.makedirs(self.global_config.output, exist_ok=True)
+        self.update_config()
+        self.dump_config()
+        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
+training!"
+
         self.make_tar_file()
-        return rtn
+        self.deamon.stop()
 
     def make_tar_file(self):
         """make tar file to package the training outputs

+ 10 - 4
paddlex/modules/ts_classification/trainer.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import json
 import time
@@ -43,9 +42,16 @@ class TSCLSTrainer(BaseTrainer):
     def train(self):
         """firstly, update and dump train config, then train model
         """
-        rtn = super().train()
+        # XXX: using super().train() instead when the train_hook() is supported.
+        os.makedirs(self.global_config.output, exist_ok=True)
+        self.update_config()
+        self.dump_config()
+        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
+training!"
+
         self.make_tar_file()
-        return rtn
+        self.deamon.stop()
 
     def make_tar_file(self):
         """make tar file to package the training outputs

+ 10 - 4
paddlex/modules/ts_forecast/trainer.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import json
 import time
@@ -43,9 +42,16 @@ class TSFCTrainer(BaseTrainer):
     def train(self):
         """firstly, update and dump train config, then train model
         """
-        rtn = super().train()
+        # XXX: using super().train() instead when the train_hook() is supported.
+        os.makedirs(self.global_config.output, exist_ok=True)
+        self.update_config()
+        self.dump_config()
+        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
+training!"
+
         self.make_tar_file()
-        return rtn
+        self.deamon.stop()
 
     def make_tar_file(self):
         """make tar file to package the training outputs

+ 2 - 2
paddlex/repo_manager/requirements.txt

@@ -13,5 +13,5 @@ premailer
 python-docx
 unstructured
 networkx
-Pillow
-requests <= 2.29
+Pillow==9.5
+requests <= 2.29