result.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import io
from typing import Any, Dict

import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

from ...common.result import BaseTSResult
  20. def visualize(forecast: pd.DataFrame) -> Image.Image:
  21. """
  22. Visualizes both the time series forecast and actual results, returning them as a Pillow image.
  23. Args:
  24. forecast (pd.DataFrame): The DataFrame containing the forecast data.
  25. Returns:
  26. Image.Image: The visualized result as a Pillow image.
  27. """
  28. plt.figure(figsize=(12, 6))
  29. forecast_columns = forecast.columns
  30. index_name = forecast.index.name
  31. forecast.index = forecast.index.astype(str)
  32. plt.step(forecast.index, forecast[forecast_columns[0]], where='post', label='Anomaly', color='red')
  33. plt.title('Time Series Anomaly Detection')
  34. plt.xlabel('Time')
  35. plt.ylabel(forecast_columns[0])
  36. plt.legend()
  37. plt.grid(True)
  38. plt.xticks(ticks=range(0, len(forecast), 10))
  39. plt.xticks(rotation=45)
  40. buf = io.BytesIO()
  41. plt.savefig(buf, bbox_inches='tight')
  42. buf.seek(0)
  43. plt.close()
  44. image = Image.open(buf)
  45. return image
  46. class TSAdResult(BaseTSResult):
  47. """A class representing the result of a time series anomaly detection task."""
  48. def _to_img(self) -> Image.Image:
  49. """apply"""
  50. anomaly = self["anomaly"]
  51. return {"res": visualize(anomaly)}
  52. def _to_csv(self) -> Any:
  53. """
  54. Converts the anomaly detection results to a CSV format.
  55. Returns:
  56. Any: The anomaly data formatted for CSV output, typically a DataFrame or similar structure.
  57. """
  58. return {"res": self["anomaly"]}