Sunflower7788 8 meses atrás
pai
commit
da954c04b4

+ 3 - 2
paddlex/inference/common/result/base_ts_result.py

@@ -13,11 +13,11 @@
 # limitations under the License.
 
 from .base_result import BaseResult
-from .mixin import CSVMixin
+from .mixin import CSVMixin, ImgMixin
 from ...utils.io import CSVWriter
 
 
-class BaseTSResult(BaseResult, CSVMixin):
+class BaseTSResult(BaseResult, CSVMixin, ImgMixin):
     """Base class for times series results."""
 
     INPUT_TS_KEY = "input_ts"
@@ -39,3 +39,4 @@ class BaseTSResult(BaseResult, CSVMixin):
 
         super().__init__(data)
         CSVMixin.__init__(self, "pandas")
+        ImgMixin.__init__(self, "pillow")

+ 43 - 0
paddlex/inference/models/ts_anomaly_detection/result.py

@@ -13,12 +13,55 @@
 # limitations under the License.
 
 from typing import Any
+import io
+import pandas as pd
+import matplotlib.pyplot as plt
+from PIL import Image
+
 from ...common.result import BaseTSResult
 
 
+def visualize(forecast: pd.DataFrame) -> Image.Image:
+    """
+    Visualizes both the time series forecast and actual results, returning them as a Pillow image.
+
+    Args:
+        forecast (pd.DataFrame): The DataFrame containing the forecast data.
+
+    Returns:
+        Image.Image: The visualized result as a Pillow image.
+    """
+    plt.figure(figsize=(12, 6))
+    forecast_columns = forecast.columns
+    index_name = forecast.index.name
+    forecast.index = forecast.index.astype(str)
+
+    plt.step(forecast.index, forecast[forecast_columns[0]], where='post', label='Anomaly', color='red')
+    plt.title('Time Series Anomaly Detection')
+    plt.xlabel('Time')
+    plt.ylabel(forecast_columns[0])
+    plt.legend()
+    plt.grid(True)
+    plt.xticks(ticks=range(0, len(forecast), 10))
+    plt.xticks(rotation=45)
+
+    buf = io.BytesIO()
+    plt.savefig(buf, bbox_inches='tight')
+    buf.seek(0)
+    plt.close()
+    image = Image.open(buf)
+
+    return image
+
+
 class TSAdResult(BaseTSResult):
     """A class representing the result of a time series anomaly detection task."""
 
+    def _to_img(self) -> Image.Image:
+        """apply"""
+        anomaly = self["anomaly"]
+        return {"res": visualize(anomaly)}
+    
     def _to_csv(self) -> Any:
         """
         Converts the anomaly detection results to a CSV format.

+ 1 - 0
paddlex/inference/models/ts_classification/predictor.py

@@ -128,4 +128,5 @@ class TSClsPredictor(BasicPredictor):
             "input_path": batch_data,
             "input_ts": batch_raw_ts,
             "classification": batch_ts_preds,
+            "target_cols": [self.config["info_params"]["target_cols"]]
         }

+ 49 - 0
paddlex/inference/models/ts_classification/result.py

@@ -13,12 +13,61 @@
 # limitations under the License.
 
 from typing import Any
+import io
+import pandas as pd
+import matplotlib.pyplot as plt
+from PIL import Image
+
 from ...common.result import BaseTSResult
 
 
+def visualize(predicted_label, input_ts, target_cols):
+    """
+    Visualize time series data and its prediction results.
+
+    Parameters:
+    - input_ts: A DataFrame containing the input_ts.
+    - predicted_label: A list of predicted class labels.
+
+    Returns:
+    - image: An image object containing the visualization result.
+    """
+    # 设置图形大小
+    plt.figure(figsize=(12, 6))
+    input_ts_columns = input_ts.columns
+    input_ts.index = input_ts.index.astype(str)
+    length = len(input_ts)
+    value = predicted_label.loc[0, 'classid']
+    plt.plot(input_ts.index, input_ts[target_cols[0]], label=f'Predicted classid: {value}', color='blue')
+
+    # 设置图形标题和标签
+    plt.title('Time Series input_ts with Predicted Labels')
+    plt.xlabel('Time')
+    plt.ylabel('Value')
+    plt.legend()
+    plt.grid(True)
+    plt.xticks(ticks=range(0, length, 10))
+    plt.xticks(rotation=45)
+
+    # 保存图像到内存
+    buf = io.BytesIO()
+    plt.savefig(buf, bbox_inches='tight')
+    buf.seek(0)
+    plt.close()
+    image = Image.open(buf)
+
+    return image
+
+
 class TSClsResult(BaseTSResult):
     """A class representing the result of a time series classification task."""
 
+    def _to_img(self) -> Image.Image:
+        """apply"""
+        classification = self["classification"]
+        ts_input = pd.read_csv(self["input_path"])
+        return {"res": visualize(classification, ts_input, self["target_cols"])}
+    
     def _to_csv(self) -> Any:
         """
         Converts the classification results to a CSV format.

+ 54 - 0
paddlex/inference/models/ts_forecasting/result.py

@@ -13,12 +13,66 @@
 # limitations under the License.
 
 from typing import Any
+import io
+import pandas as pd
+import matplotlib.pyplot as plt
+from PIL import Image
+
 from ...common.result import BaseTSResult
 
 
+def visualize(forecast: pd.DataFrame, actual_data: pd.DataFrame) -> Image.Image:
+    """
+    Visualizes both the time series forecast and actual results, returning them as a Pillow image.
+
+    Args:
+        forecast (pd.DataFrame): The DataFrame containing the forecast data.
+        actual_data (pd.Series): The actual observed data for comparison.
+        title (str): The title of the plot.
+
+    Returns:
+        Image.Image: The visualized result as a Pillow image.
+    """
+    plt.figure(figsize=(12, 6))
+    forecast_columns = forecast.columns
+    index_name = forecast.index.name
+    actual_data = actual_data.set_index(index_name)
+
+    actual_data.index = actual_data.index.astype(str)
+    forecast.index = forecast.index.astype(str)
+
+    length = min(len(forecast), len(actual_data))
+    actual_data = actual_data.tail(length)
+    
+    plt.plot(actual_data.index, actual_data[forecast_columns[0]], label='Actual Data', color='blue', linestyle='--')
+    plt.plot(forecast.index, forecast[forecast_columns[0]], label='Forecast', color='red')
+
+    plt.title('Time Series Forecast')
+    plt.xlabel('Time')
+    plt.ylabel(forecast_columns[0])
+    plt.legend()
+    plt.grid(True)
+    plt.xticks(ticks=range(0, 2*length, 10))
+    plt.xticks(rotation=45)
+
+    buf = io.BytesIO()
+    plt.savefig(buf, bbox_inches='tight')
+    buf.seek(0)
+    plt.close()
+    image = Image.open(buf)
+
+    return image
+
+
 class TSFcResult(BaseTSResult):
     """A class representing the result of a time series forecasting task."""
 
+    def _to_img(self) -> Image.Image:
+        """apply"""
+        forecast = self["forecast"]
+        ts_input = pd.read_csv(self["input_path"])
+        return {"res": visualize(forecast, ts_input)}
+    
     def _to_csv(self) -> Any:
         """
         Converts the forecasting results to a CSV format.