processors.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List
  15. import numpy as np
  16. import pandas as pd
  17. from ....utils.deps import class_requires_deps, is_dep_available
  18. from ...utils.benchmark import benchmark
  19. if is_dep_available("joblib"):
  20. import joblib
  21. @benchmark.timeit
  22. @class_requires_deps("joblib")
  23. class TSDeNormalize:
  24. """A class to de-normalize time series prediction data using a pre-fitted scaler."""
  25. def __init__(self, scale_path: str, params_info: dict):
  26. """
  27. Initializes the TSDeNormalize class with a scaler and parameters information.
  28. Args:
  29. scale_path (str): The file path to the serialized scaler object.
  30. params_info (dict): Additional parameters information.
  31. """
  32. super().__init__()
  33. self.scaler = joblib.load(scale_path)
  34. self.params_info = params_info
  35. def __call__(self, preds_list: List[pd.DataFrame]) -> List[pd.DataFrame]:
  36. """
  37. Applies de-normalization to a list of prediction DataFrames.
  38. Args:
  39. preds_list (List[pd.DataFrame]): A list of DataFrames containing normalized prediction data.
  40. Returns:
  41. List[pd.DataFrame]: A list of DataFrames with de-normalized prediction data.
  42. """
  43. return [self.tsdenorm(pred) for pred in preds_list]
  44. def tsdenorm(self, pred: pd.DataFrame) -> pd.DataFrame:
  45. """
  46. De-normalizes a single prediction DataFrame.
  47. Args:
  48. pred (pd.DataFrame): A DataFrame containing normalized prediction data.
  49. Returns:
  50. pd.DataFrame: A DataFrame with de-normalized prediction data.
  51. """
  52. scale_cols = pred.columns.values.tolist()
  53. pred[scale_cols] = self.scaler.inverse_transform(pred[scale_cols])
  54. return pred
  55. @benchmark.timeit
  56. class ArraytoTS:
  57. """A class to convert arrays of predictions into time series format."""
  58. def __init__(self, info_params: Dict[str, Any]):
  59. """
  60. Initializes the ArraytoTS class with the given parameters.
  61. Args:
  62. info_params (Dict[str, Any]): Configuration parameters including target columns, frequency, and time column name.
  63. """
  64. super().__init__()
  65. self.info_params = info_params
  66. def __call__(
  67. self, ori_ts_list: List[Dict[str, Any]], pred_list: List[np.ndarray]
  68. ) -> List[pd.DataFrame]:
  69. """
  70. Converts a list of arrays to a list of time series DataFrames.
  71. Args:
  72. ori_ts_list (List[Dict[str, Any]]): Original time series data for each prediction, including past and covariate information.
  73. pred_list (List[np.ndarray]): List of prediction arrays corresponding to each time series in ori_ts_list.
  74. Returns:
  75. List[pd.DataFrame]: A list of DataFrames, each representing the forecasted time series.
  76. """
  77. return [
  78. self.arraytots(ori_ts, pred) for ori_ts, pred in zip(ori_ts_list, pred_list)
  79. ]
  80. def arraytots(self, ori_ts: Dict[str, Any], pred: np.ndarray) -> pd.DataFrame:
  81. """
  82. Converts a single array prediction to a time series DataFrame.
  83. Args:
  84. ori_ts (Dict[str, Any]): Original time series data for a single time series.
  85. pred (np.ndarray): Prediction array for the given time series.
  86. Returns:
  87. pd.DataFrame: A DataFrame representing the forecasted time series.
  88. Raises:
  89. ValueError: If none of the expected keys are found in ori_ts.
  90. """
  91. pred = pred[0]
  92. if ori_ts.get("past_target", None) is not None:
  93. ts = ori_ts["past_target"]
  94. elif ori_ts.get("observed_cov_numeric", None) is not None:
  95. ts = ori_ts["observed_cov_numeric"]
  96. elif ori_ts.get("known_cov_numeric", None) is not None:
  97. ts = ori_ts["known_cov_numeric"]
  98. elif ori_ts.get("static_cov_numeric", None) is not None:
  99. ts = ori_ts["static_cov_numeric"]
  100. else:
  101. raise ValueError("No value in ori_ts")
  102. column_name = (
  103. self.info_params["target_cols"]
  104. if "target_cols" in self.info_params
  105. else self.info_params["feature_cols"]
  106. )
  107. if isinstance(self.info_params["freq"], str):
  108. past_target_index = ts.index
  109. if past_target_index.freq is None:
  110. past_target_index.freq = pd.infer_freq(ts.index)
  111. future_target_index = pd.date_range(
  112. past_target_index[-1] + past_target_index.freq,
  113. periods=pred.shape[0],
  114. freq=self.info_params["freq"],
  115. name=self.info_params["time_col"],
  116. )
  117. elif isinstance(self.info_params["freq"], int):
  118. start_idx = max(ts.index) + 1
  119. stop_idx = start_idx + pred.shape[0]
  120. future_target_index = pd.RangeIndex(
  121. start=start_idx,
  122. stop=stop_idx,
  123. step=self.info_params["freq"],
  124. name=self.info_params["time_col"],
  125. )
  126. future_target = pd.DataFrame(
  127. np.reshape(pred, newshape=[pred.shape[0], -1]),
  128. index=future_target_index,
  129. columns=column_name,
  130. )
  131. return future_target