|
|
@@ -16,6 +16,7 @@ from typing import Any, Union, Dict, List, Tuple
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import os
|
|
|
+import copy
|
|
|
|
|
|
from ....modules.ts_classification.model_list import MODELS
|
|
|
from ...common.batch_sampler import TSBatchSampler
|
|
|
@@ -109,6 +110,7 @@ class TSClsPredictor(BasicPredictor):
|
|
|
Dict[str, Any]: A dictionary containing the paths to the input data, the raw input time series, and the classification results.
|
|
|
"""
|
|
|
batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
|
|
|
+ batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
|
|
|
|
|
|
if "TSNormalize" in self.preprocessors:
|
|
|
batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_raw_ts)
|
|
|
@@ -127,7 +129,7 @@ class TSClsPredictor(BasicPredictor):
|
|
|
return {
|
|
|
"input_path": batch_data,
|
|
|
"input_ts": batch_raw_ts,
|
|
|
- "input_ts_data": batch_raw_ts,
|
|
|
+ "input_ts_data": batch_raw_ts_ori,
|
|
|
"classification": batch_ts_preds,
|
|
|
"target_cols": [self.config["info_params"]["target_cols"]]
|
|
|
}
|