|
|
@@ -13,12 +13,66 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
from typing import Any
|
|
|
+import io
|
|
|
+import pandas as pd
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
from ...common.result import BaseTSResult
|
|
|
|
|
|
|
|
|
+def visualize(forecast: pd.DataFrame, actual_data: pd.DataFrame) -> Image.Image:
|
|
|
+ """
|
|
|
+ Visualizes both the time series forecast and actual results, returning them as a Pillow image.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ forecast (pd.DataFrame): The DataFrame containing the forecast data.
|
|
|
+ actual_data (pd.Series): The actual observed data for comparison.
|
|
|
+ title (str): The title of the plot.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Image.Image: The visualized result as a Pillow image.
|
|
|
+ """
|
|
|
+ plt.figure(figsize=(12, 6))
|
|
|
+ forecast_columns = forecast.columns
|
|
|
+ index_name = forecast.index.name
|
|
|
+ actual_data = actual_data.set_index(index_name)
|
|
|
+
|
|
|
+ actual_data.index = actual_data.index.astype(str)
|
|
|
+ forecast.index = forecast.index.astype(str)
|
|
|
+
|
|
|
+ length = min(len(forecast), len(actual_data))
|
|
|
+ actual_data = actual_data.tail(length)
|
|
|
+
|
|
|
+ plt.plot(actual_data.index, actual_data[forecast_columns[0]], label='Actual Data', color='blue', linestyle='--')
|
|
|
+ plt.plot(forecast.index, forecast[forecast_columns[0]], label='Forecast', color='red')
|
|
|
+
|
|
|
+ plt.title('Time Series Forecast')
|
|
|
+ plt.xlabel('Time')
|
|
|
+ plt.ylabel(forecast_columns[0])
|
|
|
+ plt.legend()
|
|
|
+ plt.grid(True)
|
|
|
+ plt.xticks(ticks=range(0, 2*length, 10))
|
|
|
+ plt.xticks(rotation=45)
|
|
|
+
|
|
|
+ buf = io.BytesIO()
|
|
|
+ plt.savefig(buf, bbox_inches='tight')
|
|
|
+ buf.seek(0)
|
|
|
+ plt.close()
|
|
|
+ image = Image.open(buf)
|
|
|
+
|
|
|
+ return image
|
|
|
+
|
|
|
+
|
|
|
class TSFcResult(BaseTSResult):
|
|
|
"""A class representing the result of a time series forecasting task."""
|
|
|
|
|
|
+ def _to_img(self) -> Image.Image:
|
|
|
+ """apply"""
|
|
|
+ forecast = self["forecast"]
|
|
|
+ ts_input = pd.read_csv(self["input_path"])
|
|
|
+ return {"res": visualize(forecast, ts_input)}
|
|
|
+
|
|
|
def _to_csv(self) -> Any:
|
|
|
"""
|
|
|
Converts the forecasting results to a CSV format.
|