Explorar el Código

add ts amp and static (#2581)

* add ts amp and static

* add ts amp and static

* add ts amp and static
Sunflower7788 hace 11 meses
padre
commit
7c37847467

+ 6 - 1
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -34,7 +34,10 @@ class TSADTrainer(BaseTrainer):
         os.makedirs(self.global_config.output, exist_ok=True)
         self.update_config()
         self.dump_config()
-        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        train_args = self.get_train_kwargs()
+        if self.benchmark_config is not None:
+            train_args.update({"benchmark": self.benchmark_config})
+        train_result = self.pdx_model.train(**train_args)
         assert (
             train_result.returncode == 0
         ), f"Encountered an unexpected error({train_result.returncode}) in \
@@ -80,6 +83,8 @@ training!"
             self.pdx_config.update_learning_rate(self.train_config.learning_rate)
         if self.train_config.epochs_iters is not None:
             self.pdx_config.update_epochs(self.train_config.epochs_iters)
+        if self.train_config.get("dy2st", False):
+            self.pdx_config.update_to_static(self.train_config.dy2st)
         if self.train_config.log_interval is not None:
             self.pdx_config.update_log_interval(self.train_config.log_interval)
         if self.global_config.output is not None:

+ 6 - 1
paddlex/modules/ts_classification/trainer.py

@@ -34,7 +34,10 @@ class TSCLSTrainer(BaseTrainer):
         os.makedirs(self.global_config.output, exist_ok=True)
         self.update_config()
         self.dump_config()
-        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        train_args = self.get_train_kwargs()
+        if self.benchmark_config is not None:
+            train_args.update({"benchmark": self.benchmark_config})
+        train_result = self.pdx_model.train(**train_args)
         assert (
             train_result.returncode == 0
         ), f"Encountered an unexpected error({train_result.returncode}) in \
@@ -77,6 +80,8 @@ training!"
             self.pdx_config.update_epochs(self.train_config.epochs_iters)
         if self.train_config.log_interval is not None:
             self.pdx_config.update_log_interval(self.train_config.log_interval)
+        if self.train_config.get("dy2st", False):
+            self.pdx_config.update_to_static(self.train_config.dy2st)
         if self.global_config.output is not None:
             self.pdx_config.update_save_dir(self.global_config.output)
 

+ 6 - 1
paddlex/modules/ts_forecast/trainer.py

@@ -34,7 +34,10 @@ class TSFCTrainer(BaseTrainer):
         os.makedirs(self.global_config.output, exist_ok=True)
         self.update_config()
         self.dump_config()
-        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        train_args = self.get_train_kwargs()
+        if self.benchmark_config is not None:
+            train_args.update({"benchmark": self.benchmark_config})
+        train_result = self.pdx_model.train(**train_args)
         assert (
             train_result.returncode == 0
         ), f"Encountered an unexpected error({train_result.returncode}) in \
@@ -77,6 +80,8 @@ training!"
             self.pdx_config.update_epochs(self.train_config.epochs_iters)
         if self.train_config.log_interval is not None:
             self.pdx_config.update_log_interval(self.train_config.log_interval)
+        if self.train_config.get("dy2st", False):
+            self.pdx_config.update_to_static(self.train_config.dy2st)
         if self.global_config.output is not None:
             self.pdx_config.update_save_dir(self.global_config.output)
 

+ 4 - 0
paddlex/repo_apis/PaddleTS_api/configs/AutoEncoder_ad.yaml

@@ -4,6 +4,10 @@ epoch: 5
 training: True 
 do_eval: True
 task: anomaly
+to_static_train: False
+use_amp: False
+amp_level: O2
+
 
 dataset: 
   name: TSADDataset

+ 4 - 0
paddlex/repo_apis/PaddleTS_api/configs/DLinear.yaml

@@ -4,6 +4,10 @@ predict_len: 336
 sampling_stride: 1
 do_eval: True
 epoch: 10
+to_static_train: False
+use_amp: False
+amp_level: O2
+
 dataset:
   name: TSDataset
   dataset_root: /data/ 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/DLinear_ad.yaml

@@ -4,6 +4,9 @@ do_eval: True
 epoch: 5
 training: True 
 task: anomaly
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSADDataset

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/NLinear.yaml

@@ -4,6 +4,9 @@ predict_len: 336
 sampling_stride: 1
 do_eval: True
 epoch: 5
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 
 dataset: 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/Nonstationary.yaml

@@ -3,6 +3,9 @@ seq_len: 96
 predict_len: 96
 do_eval: True
 epoch: 5
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 
 dataset: 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/Nonstationary_ad.yaml

@@ -4,6 +4,9 @@ do_eval: True
 epoch: 5
 training: True 
 task: anomaly
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSADDataset

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/PatchTST.yaml

@@ -3,6 +3,9 @@ seq_len: 96
 predict_len: 96
 do_eval: True
 epoch: 5
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSDataset

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/PatchTST_ad.yaml

@@ -4,6 +4,9 @@ epoch: 5
 training: True 
 do_eval: True
 task: anomaly
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSADDataset

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/RLinear.yaml

@@ -4,6 +4,9 @@ predict_len: 336
 do_eval: True
 sampling_stride: 1
 epoch: 10
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 
 dataset: 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/TiDE.yaml

@@ -3,6 +3,9 @@ seq_len: 720
 predict_len: 96
 do_eval: True
 epoch: 2
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSDataset

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/TimesNet.yaml

@@ -4,6 +4,9 @@ predict_len: 96 #
 do_eval: True #
 sampling_stride: 1
 epoch: 10 # max_epochs
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 
 dataset: 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/TimesNet_ad.yaml

@@ -4,6 +4,9 @@ do_eval: True
 epoch: 5
 training: True 
 task: anomaly
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSADDataset

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/configs/TimesNet_cls.yaml

@@ -7,6 +7,9 @@ epoch: 30 # max_epochs
 training: True # 
 eval_metrics:  ['acc', ] 
 task: classification
+to_static_train: False
+use_amp: False
+amp_level: O2
 
 dataset: 
   name: TSCLSDataset

+ 21 - 0
paddlex/repo_apis/PaddleTS_api/ts_base/config.py

@@ -71,6 +71,26 @@ class BaseTSConfig(BaseConfig):
         """
         self.update({"epoch": epochs})
 
+    def update_to_static(self, dy2st: bool):
+        """update config to set dynamic to static mode
+
+        Args:
+            dy2st (bool): whether or not to use the dynamic to static mode.
+        """
+        self.update({"to_static_train": dy2st})
+
+    def update_amp(self, amp: str = "O2"):
+        """update AMP settings
+
+        Args:
+            amp (None | str): the AMP level if it is not None or `OFF`.
+        """
+        _cfg = {
+            "use_amp": True if amp is not None else False,
+            "amp_level": amp,
+        }
+        self.update(_cfg)
+
     def update_weights(self, weight_path: str):
         """update weight path
 
@@ -161,6 +181,7 @@ class BaseTSConfig(BaseConfig):
         """
         self.update({"log_interval": log_interval})
 
+
     def update_dataset(self, dataset_dir: str, dataset_type: str = None):
         """update dataset settings"""
         raise NotImplementedError

+ 5 - 6
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -73,11 +73,11 @@ class TSModel(BaseModel):
         if resume_path:
             raise ValueError("`resume_path` is not supported.")
         # No need to handle `ips`
-        if amp is not None and amp != "OFF":
-            raise ValueError(f"`amp`={amp} is not supported.")
-
-        if dy2st:
-            raise ValueError(f"`dy2st`={dy2st} is not supported.")
+        benchmark = kwargs.pop("benchmark", None)
+        if benchmark is not None:
+            amp = benchmark.get("amp", None)
+            if amp in ["O1", "O2"]:
+                config.update_amp(amp)
 
         if use_vdl:
             raise ValueError(f"`use_vdl`={use_vdl} is not supported.")
@@ -85,7 +85,6 @@ class TSModel(BaseModel):
         if device is not None:
             device_type, _ = parse_device(device)
             cli_args.append(CLIArgument("--device", device_type))
-
         if save_dir is not None:
             save_dir = abspath(save_dir)
         else: