processors.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any, Dict, List
  15. import numpy as np
  16. import pandas as pd
  17. from ...utils.benchmark import benchmark
  18. @benchmark.timeit
  19. class GetAnomaly:
  20. """A class to detect anomalies in time series data based on a model threshold."""
  21. def __init__(self, model_threshold: float, info_params: Dict[str, Any]):
  22. """
  23. Initializes the GetAnomaly class with a model threshold and parameters information.
  24. Args:
  25. model_threshold (float): The threshold for determining anomalies.
  26. info_params (Dict[str, Any]): Configuration parameters including target columns and time column name.
  27. """
  28. super().__init__()
  29. self.model_threshold = model_threshold
  30. self.info_params = info_params
  31. def __call__(
  32. self, ori_ts_list: List[Dict[str, Any]], pred_list: List[np.ndarray]
  33. ) -> List[pd.DataFrame]:
  34. """
  35. Detects anomalies for a list of time series predictions.
  36. Args:
  37. ori_ts_list (List[Dict[str, Any]]): Original time series data for each prediction, including past and covariate information.
  38. pred_list (List[np.ndarray]): List of prediction arrays corresponding to each time series in ori_ts_list.
  39. Returns:
  40. List[pd.DataFrame]: A list of DataFrames, each containing anomaly labels for the time series.
  41. """
  42. return [
  43. self.getanomaly(ori_ts, pred)
  44. for ori_ts, pred in zip(ori_ts_list, pred_list)
  45. ]
  46. def getanomaly(self, ori_ts: Dict[str, Any], pred: np.ndarray) -> pd.DataFrame:
  47. """
  48. Detects anomalies in a single time series prediction.
  49. Args:
  50. ori_ts (Dict[str, Any]): Original time series data for a single time series.
  51. pred (np.ndarray): Prediction array for the given time series.
  52. Returns:
  53. pd.DataFrame: A DataFrame containing anomaly labels for the time series.
  54. Raises:
  55. ValueError: If none of the expected keys are found in ori_ts.
  56. """
  57. pred = pred[0]
  58. if ori_ts.get("past_target", None) is not None:
  59. ts = ori_ts["past_target"]
  60. elif ori_ts.get("observed_cov_numeric", None) is not None:
  61. ts = ori_ts["observed_cov_numeric"]
  62. elif ori_ts.get("known_cov_numeric", None) is not None:
  63. ts = ori_ts["known_cov_numeric"]
  64. elif ori_ts.get("static_cov_numeric", None) is not None:
  65. ts = ori_ts["static_cov_numeric"]
  66. else:
  67. raise ValueError("No value in ori_ts")
  68. column_name = (
  69. self.info_params["target_cols"]
  70. if "target_cols" in self.info_params
  71. else self.info_params["feature_cols"]
  72. )
  73. anomaly_score = np.mean(np.square(pred - np.array(ts)), axis=-1)
  74. anomaly_label = (anomaly_score >= self.model_threshold) + 0
  75. past_target_index = ts.index
  76. past_target_index.name = self.info_params["time_col"]
  77. anomaly_label_df = pd.DataFrame(
  78. np.reshape(anomaly_label, newshape=[pred.shape[0], -1]),
  79. index=past_target_index,
  80. columns=["label"],
  81. )
  82. return anomaly_label_df