浏览代码

fix ts show (#3536)

Sunflower7788 8 月之前
父节点
当前提交
fad90cde5c

+ 1 - 0
paddlex/inference/models/ts_classification/predictor.py

@@ -127,6 +127,7 @@ class TSClsPredictor(BasicPredictor):
         return {
             "input_path": batch_data,
             "input_ts": batch_raw_ts,
+            "input_ts_data": batch_raw_ts,
             "classification": batch_ts_preds,
             "target_cols": [self.config["info_params"]["target_cols"]]
         }

+ 1 - 1
paddlex/inference/models/ts_classification/result.py

@@ -65,7 +65,7 @@ class TSClsResult(BaseTSResult):
     def _to_img(self) -> Image.Image:
         """apply"""
         classification = self["classification"]
-        ts_input = pd.read_csv(self["input_path"])
+        ts_input = self["input_ts_data"]
         return {"res": visualize(classification, ts_input, self["target_cols"])}
     
     def _to_csv(self) -> Any:

+ 1 - 0
paddlex/inference/models/ts_forecasting/predictor.py

@@ -152,5 +152,6 @@ class TSFcPredictor(BasicPredictor):
         return {
             "input_path": batch_data,
             "input_ts": batch_raw_ts,
+            "cutoff_ts": batch_raw_ts,
             "forecast": batch_ts_preds,
         }

+ 1 - 1
paddlex/inference/models/ts_forecasting/result.py

@@ -70,7 +70,7 @@ class TSFcResult(BaseTSResult):
     def _to_img(self) -> Image.Image:
         """apply"""
         forecast = self["forecast"]
-        ts_input = pd.read_csv(self["input_path"])
+        ts_input = self["cutoff_ts"]
         return {"res": visualize(forecast, ts_input)}
     
     def _to_csv(self) -> Any: