result.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import io
  15. from typing import Any
  16. import pandas as pd
  17. from PIL import Image
  18. from ....utils.deps import function_requires_deps, is_dep_available
  19. from ...common.result import BaseTSResult
  20. if is_dep_available("matplotlib"):
  21. import matplotlib.pyplot as plt
  22. @function_requires_deps("matplotlib")
  23. def visualize(forecast: pd.DataFrame, actual_data: pd.DataFrame) -> Image.Image:
  24. """
  25. Visualizes both the time series forecast and actual results, returning them as a Pillow image.
  26. Args:
  27. forecast (pd.DataFrame): The DataFrame containing the forecast data.
  28. actual_data (pd.Series): The actual observed data for comparison.
  29. title (str): The title of the plot.
  30. Returns:
  31. Image.Image: The visualized result as a Pillow image.
  32. """
  33. plt.figure(figsize=(12, 6))
  34. forecast_columns = forecast.columns
  35. index_name = forecast.index.name
  36. actual_data = actual_data.set_index(index_name)
  37. actual_data.index = actual_data.index.astype(str)
  38. forecast.index = forecast.index.astype(str)
  39. length = min(len(forecast), len(actual_data))
  40. actual_data = actual_data.tail(length)
  41. plt.plot(
  42. actual_data.index,
  43. actual_data[forecast_columns[0]],
  44. label="Actual Data",
  45. color="blue",
  46. linestyle="--",
  47. )
  48. plt.plot(
  49. forecast.index, forecast[forecast_columns[0]], label="Forecast", color="red"
  50. )
  51. plt.title("Time Series Forecast")
  52. plt.xlabel("Time")
  53. plt.ylabel(forecast_columns[0])
  54. plt.legend()
  55. plt.grid(True)
  56. plt.xticks(ticks=range(0, 2 * length, 10))
  57. plt.xticks(rotation=45)
  58. buf = io.BytesIO()
  59. plt.savefig(buf, bbox_inches="tight")
  60. buf.seek(0)
  61. plt.close()
  62. image = Image.open(buf)
  63. return image
  64. class TSFcResult(BaseTSResult):
  65. """A class representing the result of a time series forecasting task."""
  66. def _to_img(self) -> Image.Image:
  67. """apply"""
  68. forecast = self["forecast"]
  69. ts_input = self["cutoff_ts"]
  70. return {"res": visualize(forecast, ts_input)}
  71. def _to_csv(self) -> Any:
  72. """
  73. Converts the forecasting results to a CSV format.
  74. Returns:
  75. Any: The forecast data formatted for CSV output, typically a DataFrame or similar structure.
  76. """
  77. return {"res": self["forecast"]}