processors.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Dict, Any, Union
  15. import joblib
  16. import numpy as np
  17. import pandas as pd
  18. from ...utils.benchmark import benchmark
  19. class TSDeNormalize:
  20. """A class to de-normalize time series prediction data using a pre-fitted scaler."""
  21. def __init__(self, scale_path: str, params_info: dict):
  22. """
  23. Initializes the TSDeNormalize class with a scaler and parameters information.
  24. Args:
  25. scale_path (str): The file path to the serialized scaler object.
  26. params_info (dict): Additional parameters information.
  27. """
  28. super().__init__()
  29. self.scaler = joblib.load(scale_path)
  30. self.params_info = params_info
  31. @benchmark.timeit
  32. def __call__(self, preds_list: List[pd.DataFrame]) -> List[pd.DataFrame]:
  33. """
  34. Applies de-normalization to a list of prediction DataFrames.
  35. Args:
  36. preds_list (List[pd.DataFrame]): A list of DataFrames containing normalized prediction data.
  37. Returns:
  38. List[pd.DataFrame]: A list of DataFrames with de-normalized prediction data.
  39. """
  40. return [self.tsdenorm(pred) for pred in preds_list]
  41. def tsdenorm(self, pred: pd.DataFrame) -> pd.DataFrame:
  42. """
  43. De-normalizes a single prediction DataFrame.
  44. Args:
  45. pred (pd.DataFrame): A DataFrame containing normalized prediction data.
  46. Returns:
  47. pd.DataFrame: A DataFrame with de-normalized prediction data.
  48. """
  49. scale_cols = pred.columns.values.tolist()
  50. pred[scale_cols] = self.scaler.inverse_transform(pred[scale_cols])
  51. return pred
  52. class ArraytoTS:
  53. """A class to convert arrays of predictions into time series format."""
  54. def __init__(self, info_params: Dict[str, Any]):
  55. """
  56. Initializes the ArraytoTS class with the given parameters.
  57. Args:
  58. info_params (Dict[str, Any]): Configuration parameters including target columns, frequency, and time column name.
  59. """
  60. super().__init__()
  61. self.info_params = info_params
  62. @benchmark.timeit
  63. def __call__(
  64. self, ori_ts_list: List[Dict[str, Any]], pred_list: List[np.ndarray]
  65. ) -> List[pd.DataFrame]:
  66. """
  67. Converts a list of arrays to a list of time series DataFrames.
  68. Args:
  69. ori_ts_list (List[Dict[str, Any]]): Original time series data for each prediction, including past and covariate information.
  70. pred_list (List[np.ndarray]): List of prediction arrays corresponding to each time series in ori_ts_list.
  71. Returns:
  72. List[pd.DataFrame]: A list of DataFrames, each representing the forecasted time series.
  73. """
  74. return [
  75. self.arraytots(ori_ts, pred) for ori_ts, pred in zip(ori_ts_list, pred_list)
  76. ]
  77. def arraytots(self, ori_ts: Dict[str, Any], pred: np.ndarray) -> pd.DataFrame:
  78. """
  79. Converts a single array prediction to a time series DataFrame.
  80. Args:
  81. ori_ts (Dict[str, Any]): Original time series data for a single time series.
  82. pred (np.ndarray): Prediction array for the given time series.
  83. Returns:
  84. pd.DataFrame: A DataFrame representing the forecasted time series.
  85. Raises:
  86. ValueError: If none of the expected keys are found in ori_ts.
  87. """
  88. pred = pred[0]
  89. if ori_ts.get("past_target", None) is not None:
  90. ts = ori_ts["past_target"]
  91. elif ori_ts.get("observed_cov_numeric", None) is not None:
  92. ts = ori_ts["observed_cov_numeric"]
  93. elif ori_ts.get("known_cov_numeric", None) is not None:
  94. ts = ori_ts["known_cov_numeric"]
  95. elif ori_ts.get("static_cov_numeric", None) is not None:
  96. ts = ori_ts["static_cov_numeric"]
  97. else:
  98. raise ValueError("No value in ori_ts")
  99. column_name = (
  100. self.info_params["target_cols"]
  101. if "target_cols" in self.info_params
  102. else self.info_params["feature_cols"]
  103. )
  104. if isinstance(self.info_params["freq"], str):
  105. past_target_index = ts.index
  106. if past_target_index.freq is None:
  107. past_target_index.freq = pd.infer_freq(ts.index)
  108. future_target_index = pd.date_range(
  109. past_target_index[-1] + past_target_index.freq,
  110. periods=pred.shape[0],
  111. freq=self.info_params["freq"],
  112. name=self.info_params["time_col"],
  113. )
  114. elif isinstance(self.info_params["freq"], int):
  115. start_idx = max(ts.index) + 1
  116. stop_idx = start_idx + pred.shape[0]
  117. future_target_index = pd.RangeIndex(
  118. start=start_idx,
  119. stop=stop_idx,
  120. step=self.info_params["freq"],
  121. name=self.info_params["time_col"],
  122. )
  123. future_target = pd.DataFrame(
  124. np.reshape(pred, newshape=[pred.shape[0], -1]),
  125. index=future_target_index,
  126. columns=column_name,
  127. )
  128. return future_target