result.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import io
from typing import Any, Dict

from PIL import Image

from ....utils.deps import function_requires_deps, is_dep_available
from ...common.result import BaseTSResult

if is_dep_available("matplotlib"):
    import matplotlib.pyplot as plt
  21. @function_requires_deps("matplotlib")
  22. def visualize(predicted_label, input_ts, target_cols):
  23. """
  24. Visualize time series data and its prediction results.
  25. Parameters:
  26. - input_ts: A DataFrame containing the input_ts.
  27. - predicted_label: A list of predicted class labels.
  28. Returns:
  29. - image: An image object containing the visualization result.
  30. """
  31. # 设置图形大小
  32. plt.figure(figsize=(12, 6))
  33. input_ts.columns
  34. input_ts.index = input_ts.index.astype(str)
  35. length = len(input_ts)
  36. value = predicted_label.loc[0, "classid"]
  37. plt.plot(
  38. input_ts.index,
  39. input_ts[target_cols[0]],
  40. label=f"Predicted classid: {value}",
  41. color="blue",
  42. )
  43. # 设置图形标题和标签
  44. plt.title("Time Series input_ts with Predicted Labels")
  45. plt.xlabel("Time")
  46. plt.ylabel("Value")
  47. plt.legend()
  48. plt.grid(True)
  49. plt.xticks(ticks=range(0, length, 10))
  50. plt.xticks(rotation=45)
  51. # 保存图像到内存
  52. buf = io.BytesIO()
  53. plt.savefig(buf, bbox_inches="tight")
  54. buf.seek(0)
  55. plt.close()
  56. image = Image.open(buf)
  57. return image
  58. class TSClsResult(BaseTSResult):
  59. """A class representing the result of a time series classification task."""
  60. def _to_img(self) -> Image.Image:
  61. """apply"""
  62. classification = self["classification"]
  63. ts_input = self["input_ts_data"]
  64. return {"res": visualize(classification, ts_input, self["target_cols"])}
  65. def _to_csv(self) -> Any:
  66. """
  67. Converts the classification results to a CSV format.
  68. Returns:
  69. Any: The classification data formatted for CSV output, typically a DataFrame or similar structure.
  70. """
  71. return {"res": self["classification"]}