Sunflower7788 8 luni în urmă
părinte
comite
54252e58c6

+ 3 - 1
paddlex/inference/models/ts_classification/predictor.py

@@ -16,6 +16,7 @@ from typing import Any, Union, Dict, List, Tuple
 import numpy as np
 import pandas as pd
 import os
+import copy
 
 from ....modules.ts_classification.model_list import MODELS
 from ...common.batch_sampler import TSBatchSampler
@@ -109,6 +110,7 @@ class TSClsPredictor(BasicPredictor):
             Dict[str, Any]: A dictionary containing the paths to the input data, the raw input time series, and the classification results.
         """
         batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
 
         if "TSNormalize" in self.preprocessors:
             batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_raw_ts)
@@ -127,7 +129,7 @@ class TSClsPredictor(BasicPredictor):
         return {
             "input_path": batch_data,
             "input_ts": batch_raw_ts,
-            "input_ts_data": batch_raw_ts,
+            "input_ts_data": batch_raw_ts_ori,
             "classification": batch_ts_preds,
             "target_cols": [self.config["info_params"]["target_cols"]]
         }

+ 3 - 1
paddlex/inference/models/ts_forecasting/predictor.py

@@ -16,6 +16,7 @@ from typing import Any, Union, Dict, List, Tuple
 import numpy as np
 import pandas as pd
 import os
+import copy
 
 from ....modules.ts_forecast.model_list import MODELS
 from ...common.batch_sampler import TSBatchSampler
@@ -122,6 +123,7 @@ class TSFcPredictor(BasicPredictor):
         """
 
         batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
         batch_cutoff_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
 
         if "TSNormalize" in self.preprocessors:
@@ -152,6 +154,6 @@ class TSFcPredictor(BasicPredictor):
         return {
             "input_path": batch_data,
             "input_ts": batch_raw_ts,
-            "cutoff_ts": batch_raw_ts,
+            "cutoff_ts": batch_raw_ts_ori,
             "forecast": batch_ts_preds,
         }