funcs.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, Dict, List, Optional, Union
  15. import numpy as np
  16. import pandas as pd
  17. from pandas.tseries import holiday as hd
  18. from pandas.tseries.offsets import DateOffset, Day, Easter
  19. from .....utils.deps import function_requires_deps, is_dep_available
  20. if is_dep_available("chinese-calendar"):
  21. import chinese_calendar
  22. if is_dep_available("scikit-learn"):
  23. from sklearn.preprocessing import StandardScaler
  24. MAX_WINDOW = 183 + 17
  25. EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
  26. NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
  27. SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
  28. MothersDay = hd.Holiday(
  29. "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
  30. )
  31. IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
  32. ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
  33. ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
  34. NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
  35. BlackFriday = hd.Holiday(
  36. "Black Friday",
  37. month=11,
  38. day=1,
  39. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
  40. )
  41. CyberMonday = hd.Holiday(
  42. "Cyber Monday",
  43. month=11,
  44. day=1,
  45. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
  46. )
  47. HOLIDAYS = [
  48. hd.EasterMonday,
  49. hd.GoodFriday,
  50. hd.USColumbusDay,
  51. hd.USLaborDay,
  52. hd.USMartinLutherKingJr,
  53. hd.USMemorialDay,
  54. hd.USPresidentsDay,
  55. hd.USThanksgivingDay,
  56. EasterSunday,
  57. NewYearsDay,
  58. SuperBowl,
  59. MothersDay,
  60. IndependenceDay,
  61. ChristmasEve,
  62. ChristmasDay,
  63. NewYearsEve,
  64. BlackFriday,
  65. CyberMonday,
  66. ]
  67. def _cal_year(
  68. x: np.datetime64,
  69. ):
  70. return x.year
  71. def _cal_month(
  72. x: np.datetime64,
  73. ):
  74. return x.month
  75. def _cal_day(
  76. x: np.datetime64,
  77. ):
  78. return x.day
  79. def _cal_hour(
  80. x: np.datetime64,
  81. ):
  82. return x.hour
  83. def _cal_weekday(
  84. x: np.datetime64,
  85. ):
  86. return x.dayofweek
  87. def _cal_quarter(
  88. x: np.datetime64,
  89. ):
  90. return x.quarter
  91. def _cal_hourofday(
  92. x: np.datetime64,
  93. ):
  94. return x.hour / 23.0 - 0.5
  95. def _cal_dayofweek(
  96. x: np.datetime64,
  97. ):
  98. return x.dayofweek / 6.0 - 0.5
  99. def _cal_dayofmonth(
  100. x: np.datetime64,
  101. ):
  102. return x.day / 30.0 - 0.5
  103. def _cal_dayofyear(
  104. x: np.datetime64,
  105. ):
  106. return x.dayofyear / 364.0 - 0.5
  107. def _cal_weekofyear(
  108. x: np.datetime64,
  109. ):
  110. return x.weekofyear / 51.0 - 0.5
  111. @function_requires_deps("chinese-calendar")
  112. def _cal_holiday(
  113. x: np.datetime64,
  114. ):
  115. return float(chinese_calendar.is_holiday(x))
  116. @function_requires_deps("chinese-calendar")
  117. def _cal_workday(
  118. x: np.datetime64,
  119. ):
  120. return float(chinese_calendar.is_workday(x))
  121. def _cal_minuteofhour(
  122. x: np.datetime64,
  123. ):
  124. return x.minute / 59 - 0.5
  125. def _cal_monthofyear(
  126. x: np.datetime64,
  127. ):
  128. return x.month / 11.0 - 0.5
# Maps a time-feature name to the function that computes it for a single
# timestamp. `time_feature` looks names up here for every requested feature
# other than "holidays".
CAL_DATE_METHOD = {
    # Raw calendar fields.
    "year": _cal_year,
    "month": _cal_month,
    "day": _cal_day,
    "hour": _cal_hour,
    "weekday": _cal_weekday,
    "quarter": _cal_quarter,
    # Calendar fields rescaled and shifted by -0.5.
    "minuteofhour": _cal_minuteofhour,
    "monthofyear": _cal_monthofyear,
    "hourofday": _cal_hourofday,
    "dayofweek": _cal_dayofweek,
    "dayofmonth": _cal_dayofmonth,
    "dayofyear": _cal_dayofyear,
    "weekofyear": _cal_weekofyear,
    # Flags backed by the optional `chinese-calendar` dependency.
    "is_holiday": _cal_holiday,
    "is_workday": _cal_workday,
}
  146. def load_from_one_dataframe(
  147. data: Union[pd.DataFrame, pd.Series],
  148. time_col: Optional[str] = None,
  149. value_cols: Optional[Union[List[str], str]] = None,
  150. freq: Optional[Union[str, int]] = None,
  151. drop_tail_nan: bool = False,
  152. dtype: Optional[Union[type, Dict[str, type]]] = None,
  153. ) -> pd.DataFrame:
  154. """Transforms a DataFrame or Series into a time-indexed DataFrame.
  155. Args:
  156. data (Union[pd.DataFrame, pd.Series]): The input data containing time series information.
  157. time_col (Optional[str]): The column name representing time information. If None, uses the index.
  158. value_cols (Optional[Union[List[str], str]]): Columns to extract as values. If None, uses all except time_col.
  159. freq (Optional[Union[str, int]]): The frequency of the time series data.
  160. drop_tail_nan (bool): If True, drop trailing NaN values from the data.
  161. dtype (Optional[Union[type, Dict[str, type]]]): Enforce a specific data type on the resulting DataFrame.
  162. Returns:
  163. pd.DataFrame: A DataFrame with time as the index and specified value columns.
  164. Raises:
  165. ValueError: If the time column doesn't exist, or if frequency cannot be inferred.
  166. """
  167. # Initialize series_data with specified value columns or all except time_col
  168. series_data = None
  169. if value_cols is None:
  170. if isinstance(data, pd.Series):
  171. series_data = data.copy()
  172. else:
  173. series_data = data.loc[:, data.columns != time_col].copy()
  174. else:
  175. series_data = data.loc[:, value_cols].copy()
  176. # Determine the time column values
  177. if time_col:
  178. if time_col not in data.columns:
  179. raise ValueError(
  180. "The time column: {} doesn't exist in the `data`!".format(time_col)
  181. )
  182. time_col_vals = data.loc[:, time_col]
  183. else:
  184. time_col_vals = data.index
  185. # Handle integer-based time column values when frequency is a string
  186. if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
  187. time_col_vals = time_col_vals.astype(str)
  188. # Process integer-based time column values
  189. if np.issubdtype(time_col_vals.dtype, np.integer):
  190. if freq:
  191. if not isinstance(freq, int) or freq < 1:
  192. raise ValueError(
  193. "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
  194. )
  195. else:
  196. freq = 1 # Default frequency for integer index
  197. start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
  198. if (stop_idx - start_idx) / freq != len(data):
  199. raise ValueError("The number of rows doesn't match with the RangeIndex!")
  200. time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
  201. # Process datetime-like time column values
  202. elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
  203. time_col_vals.dtype, np.datetime64
  204. ):
  205. time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
  206. time_index = pd.DatetimeIndex(time_col_vals)
  207. if freq:
  208. if not isinstance(freq, str):
  209. raise ValueError(
  210. "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
  211. )
  212. else:
  213. # Attempt to infer frequency if not provided
  214. freq = pd.infer_freq(time_index)
  215. if freq is None:
  216. raise ValueError(
  217. "Failed to infer the `freq`. A valid `freq` is required."
  218. )
  219. if freq[0] == "-":
  220. freq = freq[1:]
  221. # Raise error for unsupported time column types
  222. else:
  223. raise ValueError("The type of `time_col` is invalid.")
  224. # Ensure series_data is a DataFrame
  225. if isinstance(series_data, pd.Series):
  226. series_data = series_data.to_frame()
  227. # Set time index and sort data
  228. series_data.set_index(time_index, inplace=True)
  229. series_data.sort_index(inplace=True)
  230. return series_data
  231. def load_from_dataframe(
  232. df: pd.DataFrame,
  233. group_id: Optional[str] = None,
  234. time_col: Optional[str] = None,
  235. target_cols: Optional[Union[List[str], str]] = None,
  236. label_col: Optional[Union[List[str], str]] = None,
  237. observed_cov_cols: Optional[Union[List[str], str]] = None,
  238. feature_cols: Optional[Union[List[str], str]] = None,
  239. known_cov_cols: Optional[Union[List[str], str]] = None,
  240. static_cov_cols: Optional[Union[List[str], str]] = None,
  241. freq: Optional[Union[str, int]] = None,
  242. fill_missing_dates: bool = False,
  243. fillna_method: str = "pre",
  244. fillna_window_size: int = 10,
  245. **kwargs,
  246. ) -> Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]:
  247. """Loads and processes time series data from a DataFrame.
  248. This function extracts and organizes time series data from a given DataFrame.
  249. It supports optional grouping and extraction of specific columns as features.
  250. Args:
  251. df (pd.DataFrame): The input DataFrame containing time series data.
  252. group_id (Optional[str]): Column name used for grouping the data.
  253. time_col (Optional[str]): Name of the time column.
  254. target_cols (Optional[Union[List[str], str]]): Columns to be used as target.
  255. label_col (Optional[Union[List[str], str]]): Columns to be used as label.
  256. observed_cov_cols (Optional[Union[List[str], str]]): Columns for observed covariates.
  257. feature_cols (Optional[Union[List[str], str]]): Columns to be used as features.
  258. known_cov_cols (Optional[Union[List[str], str]]): Columns for known covariates.
  259. static_cov_cols (Optional[Union[List[str], str]]): Columns for static covariates.
  260. freq (Optional[Union[str, int]]): Frequency of the time series data.
  261. fill_missing_dates (bool): Whether to fill missing dates in the time series.
  262. fillna_method (str): Method to fill missing values ('pre' or 'post').
  263. fillna_window_size (int): Window size for filling missing values.
  264. **kwargs: Additional keyword arguments.
  265. Returns:
  266. Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]: A dictionary containing processed time series data.
  267. """
  268. # List to store DataFrames if grouping is applied
  269. dfs = []
  270. # Separate the DataFrame into groups if group_id is provided
  271. if group_id is not None:
  272. group_unique = df[group_id].unique()
  273. for column in group_unique:
  274. dfs.append(df[df[group_id].isin([column])])
  275. else:
  276. dfs = [df]
  277. # Result list to store processed data from each group
  278. res = []
  279. # If label_col is provided, ensure it is a single column
  280. if label_col:
  281. if isinstance(label_col, str) and len(label_col) > 1:
  282. raise ValueError("The length of label_col must be 1.")
  283. target_cols = label_col
  284. # If feature_cols is provided, treat it as observed_cov_cols
  285. if feature_cols:
  286. observed_cov_cols = feature_cols
  287. # Process each DataFrame in the list
  288. for df in dfs:
  289. target = None
  290. observed_cov = None
  291. known_cov = None
  292. static_cov = dict()
  293. # If no specific columns are provided, use all columns except time_col
  294. if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
  295. target = load_from_one_dataframe(
  296. df,
  297. time_col,
  298. [a for a in df.columns if a != time_col],
  299. freq,
  300. )
  301. else:
  302. if target_cols:
  303. target = load_from_one_dataframe(
  304. df,
  305. time_col,
  306. target_cols,
  307. freq,
  308. )
  309. if observed_cov_cols:
  310. observed_cov = load_from_one_dataframe(
  311. df,
  312. time_col,
  313. observed_cov_cols,
  314. freq,
  315. )
  316. if known_cov_cols:
  317. known_cov = load_from_one_dataframe(
  318. df,
  319. time_col,
  320. known_cov_cols,
  321. freq,
  322. )
  323. if static_cov_cols:
  324. if isinstance(static_cov_cols, str):
  325. static_cov_cols = [static_cov_cols]
  326. for col in static_cov_cols:
  327. if col not in df.columns or len(np.unique(df[col])) != 1:
  328. raise ValueError(
  329. "Static covariate columns data is not in columns or schema is not correct!"
  330. )
  331. static_cov[col] = df[col].iloc[0]
  332. # Append the processed data into the results list
  333. res.append(
  334. {
  335. "past_target": target,
  336. "observed_cov_numeric": observed_cov,
  337. "known_cov_numeric": known_cov,
  338. "static_cov_numeric": static_cov,
  339. }
  340. )
  341. # Return the first processed result
  342. return res[0]
  343. def _distance_to_holiday(holiday) -> Callable[[pd.Timestamp], float]:
  344. """Creates a function to calculate the distance in days to the nearest holiday.
  345. This function generates a closure that computes the number of days from
  346. a given date index to the nearest holiday within a defined window.
  347. Args:
  348. holiday: An object that provides a `dates` method, which returns the
  349. dates of holidays within a specified range.
  350. Returns:
  351. Callable[[pd.Timestamp], float]: A function that takes a date index
  352. as input and returns the distance in days to the nearest holiday.
  353. """
  354. def _distance_to_day(index: pd.Timestamp) -> float:
  355. """Calculates the distance in days from a given date index to the nearest holiday.
  356. Args:
  357. index (pd.Timestamp): The date index for which the distance to the
  358. nearest holiday should be calculated.
  359. Returns:
  360. float: The number of days to the nearest holiday.
  361. Raises:
  362. AssertionError: If no holiday is found within the specified window.
  363. """
  364. holiday_date = holiday.dates(
  365. index - pd.Timedelta(days=MAX_WINDOW),
  366. index + pd.Timedelta(days=MAX_WINDOW),
  367. )
  368. assert (
  369. len(holiday_date) != 0
  370. ), f"No closest holiday for the date index {index} found."
  371. # It sometimes returns two dates if it is exactly half a year after the
  372. # holiday. In this case, the smaller distance (182 days) is returned.
  373. return float((index - holiday_date[0]).days)
  374. return _distance_to_day
  375. @function_requires_deps("scikit-learn")
  376. def time_feature(
  377. dataset: Dict,
  378. freq: Optional[Union[str, int]],
  379. feature_cols: List[str],
  380. extend_points: int,
  381. inplace: bool = False,
  382. ) -> Dict:
  383. """Transforms the time column of a dataset into time features.
  384. This function extracts time-related features from the time column in a
  385. dataset, optionally extending the time series for future points and
  386. normalizing holiday distances.
  387. Args:
  388. dataset (Dict): Dataset to be transformed.
  389. freq: Optional[Union[str, int]]: Frequency of the time series data. If not provided,
  390. the frequency will be inferred.
  391. feature_cols (List[str]): List of feature columns to be extracted.
  392. extend_points (int): Number of future points to extend the time series.
  393. inplace (bool): Whether to perform the transformation inplace. Default is False.
  394. Returns:
  395. Dict: The transformed dataset with time features added.
  396. Raises:
  397. ValueError: If the time column is of an integer type instead of datetime.
  398. """
  399. new_ts = dataset
  400. if not inplace:
  401. new_ts = dataset.copy()
  402. # Get known_cov_numeric or initialize with past target index
  403. kcov = new_ts["known_cov_numeric"]
  404. if not kcov:
  405. tf_kcov = new_ts["past_target"].index.to_frame()
  406. else:
  407. tf_kcov = kcov.index.to_frame()
  408. time_col = tf_kcov.columns[0]
  409. # Check if time column is of datetime type
  410. if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
  411. raise ValueError(
  412. "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
  413. )
  414. # Extend the time series if no known_cov_numeric
  415. if not kcov:
  416. freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
  417. extend_time = pd.date_range(
  418. start=tf_kcov[time_col][-1],
  419. freq=freq,
  420. periods=extend_points + 1,
  421. closed="right",
  422. name=time_col,
  423. ).to_frame()
  424. tf_kcov = pd.concat([tf_kcov, extend_time])
  425. # Extract and add time features to known_cov_numeric
  426. for k in feature_cols:
  427. if k != "holidays":
  428. v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
  429. v.index = tf_kcov[time_col]
  430. if new_ts["known_cov_numeric"] is None:
  431. new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
  432. else:
  433. new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
  434. new_ts["known_cov_numeric"].index
  435. )
  436. else:
  437. holidays_col = []
  438. for i, H in enumerate(HOLIDAYS):
  439. v = tf_kcov[time_col].apply(_distance_to_holiday(H))
  440. v.index = tf_kcov[time_col]
  441. holidays_col.append(k + "_" + str(i))
  442. if new_ts["known_cov_numeric"] is None:
  443. new_ts["known_cov_numeric"] = pd.DataFrame(
  444. v.rename(k + "_" + str(i)), index=v.index
  445. )
  446. else:
  447. new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
  448. new_ts["known_cov_numeric"].index
  449. )
  450. scaler = StandardScaler()
  451. scaler.fit(new_ts["known_cov_numeric"][holidays_col])
  452. new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
  453. new_ts["known_cov_numeric"][holidays_col]
  454. )
  455. return new_ts