zhangyubo0722 1 vuosi sitten
vanhempi
commit
52bd378a6f

+ 0 - 1
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -42,7 +42,6 @@ class TSADTrainer(BaseTrainer):
 training!"
 
         self.make_tar_file()
-        self.deamon.stop()
 
     def make_tar_file(self):
         """make tar file to package the training outputs"""

+ 0 - 1
paddlex/modules/ts_classification/trainer.py

@@ -42,7 +42,6 @@ class TSCLSTrainer(BaseTrainer):
 training!"
 
         self.make_tar_file()
-        self.deamon.stop()
 
     def make_tar_file(self):
         """make tar file to package the training outputs"""

+ 0 - 1
paddlex/modules/ts_forecast/trainer.py

@@ -42,7 +42,6 @@ class TSFCTrainer(BaseTrainer):
 training!"
 
         self.make_tar_file()
-        self.deamon.stop()
 
     def make_tar_file(self):
         """make tar file to package the training outputs"""

+ 4 - 2
paddlex/repo_apis/PaddleTS_api/ts_base/config.py

@@ -18,7 +18,7 @@ from urllib.parse import urlparse
 import ruamel.yaml
 
 from ...base import BaseConfig
-from ....utils.misc import abspath
+from ....utils.misc import abspath, convert_and_remove_types
 
 
 class BaseTSConfig(BaseConfig):
@@ -59,7 +59,9 @@ class BaseTSConfig(BaseConfig):
         """
         yaml = ruamel.yaml.YAML()
         with open(config_file_path, "w", encoding="utf-8") as f:
-            yaml.dump(self.dict, f)
+            dict_to_dump = self.dict
+            dict_to_dump = convert_and_remove_types(dict_to_dump)
+            yaml.dump(dict_to_dump, f)
 
     def update_epochs(self, epochs: int):
         """update epochs setting

+ 23 - 0
paddlex/utils/misc.py

@@ -15,6 +15,7 @@
 
 import os
 import threading
+import numpy as np
 from abc import ABCMeta
 from .errors import (
     raise_class_not_found_error,
@@ -24,6 +25,28 @@ from .errors import (
 from .logging import *
 
 
+def convert_and_remove_types(data):
+    if isinstance(data, dict):
+        return {
+            k: convert_and_remove_types(v)
+            for k, v in data.items()
+            if not isinstance(v, type)
+        }
+    elif isinstance(data, list):
+        return [convert_and_remove_types(v) for v in data]
+    elif isinstance(data, np.ndarray):
+        return data.tolist()
+    elif isinstance(data, (np.float32, np.float64)):
+        return float(data)
+    elif isinstance(data, (np.int32, np.int64)):
+        return int(data)
+    elif isinstance(data, np.bool_):
+        return bool(data)
+    elif isinstance(data, (np.str_, np.unicode_)):
+        return str(data)
+    return data
+
+
 def abspath(path: str):
     """get absolute path