processors.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Dict, Any, Union
  15. import joblib
  16. import numpy as np
  17. import pandas as pd
  18. class TSDeNormalize:
  19. """A class to de-normalize time series prediction data using a pre-fitted scaler."""
  20. def __init__(self, scale_path: str, params_info: dict):
  21. """
  22. Initializes the TSDeNormalize class with a scaler and parameters information.
  23. Args:
  24. scale_path (str): The file path to the serialized scaler object.
  25. params_info (dict): Additional parameters information.
  26. """
  27. super().__init__()
  28. self.scaler = joblib.load(scale_path)
  29. self.params_info = params_info
  30. def __call__(self, preds_list: List[pd.DataFrame]) -> List[pd.DataFrame]:
  31. """
  32. Applies de-normalization to a list of prediction DataFrames.
  33. Args:
  34. preds_list (List[pd.DataFrame]): A list of DataFrames containing normalized prediction data.
  35. Returns:
  36. List[pd.DataFrame]: A list of DataFrames with de-normalized prediction data.
  37. """
  38. return [self.tsdenorm(pred) for pred in preds_list]
  39. def tsdenorm(self, pred: pd.DataFrame) -> pd.DataFrame:
  40. """
  41. De-normalizes a single prediction DataFrame.
  42. Args:
  43. pred (pd.DataFrame): A DataFrame containing normalized prediction data.
  44. Returns:
  45. pd.DataFrame: A DataFrame with de-normalized prediction data.
  46. """
  47. scale_cols = pred.columns.values.tolist()
  48. pred[scale_cols] = self.scaler.inverse_transform(pred[scale_cols])
  49. return pred
  50. class ArraytoTS:
  51. """A class to convert arrays of predictions into time series format."""
  52. def __init__(self, info_params: Dict[str, Any]):
  53. """
  54. Initializes the ArraytoTS class with the given parameters.
  55. Args:
  56. info_params (Dict[str, Any]): Configuration parameters including target columns, frequency, and time column name.
  57. """
  58. super().__init__()
  59. self.info_params = info_params
  60. def __call__(
  61. self, ori_ts_list: List[Dict[str, Any]], pred_list: List[np.ndarray]
  62. ) -> List[pd.DataFrame]:
  63. """
  64. Converts a list of arrays to a list of time series DataFrames.
  65. Args:
  66. ori_ts_list (List[Dict[str, Any]]): Original time series data for each prediction, including past and covariate information.
  67. pred_list (List[np.ndarray]): List of prediction arrays corresponding to each time series in ori_ts_list.
  68. Returns:
  69. List[pd.DataFrame]: A list of DataFrames, each representing the forecasted time series.
  70. """
  71. return [
  72. self.arraytots(ori_ts, pred) for ori_ts, pred in zip(ori_ts_list, pred_list)
  73. ]
  74. def arraytots(self, ori_ts: Dict[str, Any], pred: np.ndarray) -> pd.DataFrame:
  75. """
  76. Converts a single array prediction to a time series DataFrame.
  77. Args:
  78. ori_ts (Dict[str, Any]): Original time series data for a single time series.
  79. pred (np.ndarray): Prediction array for the given time series.
  80. Returns:
  81. pd.DataFrame: A DataFrame representing the forecasted time series.
  82. Raises:
  83. ValueError: If none of the expected keys are found in ori_ts.
  84. """
  85. pred = pred[0]
  86. if ori_ts.get("past_target", None) is not None:
  87. ts = ori_ts["past_target"]
  88. elif ori_ts.get("observed_cov_numeric", None) is not None:
  89. ts = ori_ts["observed_cov_numeric"]
  90. elif ori_ts.get("known_cov_numeric", None) is not None:
  91. ts = ori_ts["known_cov_numeric"]
  92. elif ori_ts.get("static_cov_numeric", None) is not None:
  93. ts = ori_ts["static_cov_numeric"]
  94. else:
  95. raise ValueError("No value in ori_ts")
  96. column_name = (
  97. self.info_params["target_cols"]
  98. if "target_cols" in self.info_params
  99. else self.info_params["feature_cols"]
  100. )
  101. if isinstance(self.info_params["freq"], str):
  102. past_target_index = ts.index
  103. if past_target_index.freq is None:
  104. past_target_index.freq = pd.infer_freq(ts.index)
  105. future_target_index = pd.date_range(
  106. past_target_index[-1] + past_target_index.freq,
  107. periods=pred.shape[0],
  108. freq=self.info_params["freq"],
  109. name=self.info_params["time_col"],
  110. )
  111. elif isinstance(self.info_params["freq"], int):
  112. start_idx = max(ts.index) + 1
  113. stop_idx = start_idx + pred.shape[0]
  114. future_target_index = pd.RangeIndex(
  115. start=start_idx,
  116. stop=stop_idx,
  117. step=self.info_params["freq"],
  118. name=self.info_params["time_col"],
  119. )
  120. future_target = pd.DataFrame(
  121. np.reshape(pred, newshape=[pred.shape[0], -1]),
  122. index=future_target_index,
  123. columns=column_name,
  124. )
  125. return future_target