result.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import io
  15. from typing import Any
  16. from PIL import Image
  17. from ....utils.deps import function_requires_deps, is_dep_available
  18. from ...common.result import BaseTSResult
  19. if is_dep_available("matplotlib"):
  20. import matplotlib.pyplot as plt
  21. import pandas as pd
  22. @function_requires_deps("matplotlib")
  23. def visualize(forecast: pd.DataFrame) -> Image.Image:
  24. """
  25. Visualizes both the time series forecast and actual results, returning them as a Pillow image.
  26. Args:
  27. forecast (pd.DataFrame): The DataFrame containing the forecast data.
  28. Returns:
  29. Image.Image: The visualized result as a Pillow image.
  30. """
  31. plt.figure(figsize=(12, 6))
  32. forecast_columns = forecast.columns
  33. forecast.index.name
  34. forecast.index = forecast.index.astype(str)
  35. plt.step(
  36. forecast.index,
  37. forecast[forecast_columns[0]],
  38. where="post",
  39. label="Anomaly",
  40. color="red",
  41. )
  42. plt.title("Time Series Anomaly Detection")
  43. plt.xlabel("Time")
  44. plt.ylabel(forecast_columns[0])
  45. plt.legend()
  46. plt.grid(True)
  47. plt.xticks(ticks=range(0, len(forecast), 10))
  48. plt.xticks(rotation=45)
  49. buf = io.BytesIO()
  50. plt.savefig(buf, bbox_inches="tight")
  51. buf.seek(0)
  52. plt.close()
  53. image = Image.open(buf)
  54. return image
  55. class TSAdResult(BaseTSResult):
  56. """A class representing the result of a time series anomaly detection task."""
  57. def _to_img(self) -> Image.Image:
  58. """apply"""
  59. anomaly = self["anomaly"]
  60. return {"res": visualize(anomaly)}
  61. def _to_csv(self) -> Any:
  62. """
  63. Converts the anomaly detection results to a CSV format.
  64. Returns:
  65. Any: The anomaly data formatted for CSV output, typically a DataFrame or similar structure.
  66. """
  67. return {"res": self["anomaly"]}