# funcs.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
  15. import os
  16. import numpy as np
  17. import pandas as pd
  18. import chinese_calendar
  19. from pandas.tseries.offsets import DateOffset, Easter, Day
  20. from pandas.tseries import holiday as hd
  21. from sklearn.preprocessing import StandardScaler
  22. MAX_WINDOW = 183 + 17
  23. EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
  24. NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
  25. SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
  26. MothersDay = hd.Holiday(
  27. "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
  28. )
  29. IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
  30. ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
  31. ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
  32. NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
  33. BlackFriday = hd.Holiday(
  34. "Black Friday",
  35. month=11,
  36. day=1,
  37. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
  38. )
  39. CyberMonday = hd.Holiday(
  40. "Cyber Monday",
  41. month=11,
  42. day=1,
  43. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
  44. )
  45. HOLIDAYS = [
  46. hd.EasterMonday,
  47. hd.GoodFriday,
  48. hd.USColumbusDay,
  49. hd.USLaborDay,
  50. hd.USMartinLutherKingJr,
  51. hd.USMemorialDay,
  52. hd.USPresidentsDay,
  53. hd.USThanksgivingDay,
  54. EasterSunday,
  55. NewYearsDay,
  56. SuperBowl,
  57. MothersDay,
  58. IndependenceDay,
  59. ChristmasEve,
  60. ChristmasDay,
  61. NewYearsEve,
  62. BlackFriday,
  63. CyberMonday,
  64. ]
  65. def _cal_year(
  66. x: np.datetime64,
  67. ):
  68. return x.year
  69. def _cal_month(
  70. x: np.datetime64,
  71. ):
  72. return x.month
  73. def _cal_day(
  74. x: np.datetime64,
  75. ):
  76. return x.day
  77. def _cal_hour(
  78. x: np.datetime64,
  79. ):
  80. return x.hour
  81. def _cal_weekday(
  82. x: np.datetime64,
  83. ):
  84. return x.dayofweek
  85. def _cal_quarter(
  86. x: np.datetime64,
  87. ):
  88. return x.quarter
  89. def _cal_hourofday(
  90. x: np.datetime64,
  91. ):
  92. return x.hour / 23.0 - 0.5
  93. def _cal_dayofweek(
  94. x: np.datetime64,
  95. ):
  96. return x.dayofweek / 6.0 - 0.5
  97. def _cal_dayofmonth(
  98. x: np.datetime64,
  99. ):
  100. return x.day / 30.0 - 0.5
  101. def _cal_dayofyear(
  102. x: np.datetime64,
  103. ):
  104. return x.dayofyear / 364.0 - 0.5
  105. def _cal_weekofyear(
  106. x: np.datetime64,
  107. ):
  108. return x.weekofyear / 51.0 - 0.5
  109. def _cal_holiday(
  110. x: np.datetime64,
  111. ):
  112. return float(chinese_calendar.is_holiday(x))
  113. def _cal_workday(
  114. x: np.datetime64,
  115. ):
  116. return float(chinese_calendar.is_workday(x))
  117. def _cal_minuteofhour(
  118. x: np.datetime64,
  119. ):
  120. return x.minute / 59 - 0.5
  121. def _cal_monthofyear(
  122. x: np.datetime64,
  123. ):
  124. return x.month / 11.0 - 0.5
  125. CAL_DATE_METHOD = {
  126. "year": _cal_year,
  127. "month": _cal_month,
  128. "day": _cal_day,
  129. "hour": _cal_hour,
  130. "weekday": _cal_weekday,
  131. "quarter": _cal_quarter,
  132. "minuteofhour": _cal_minuteofhour,
  133. "monthofyear": _cal_monthofyear,
  134. "hourofday": _cal_hourofday,
  135. "dayofweek": _cal_dayofweek,
  136. "dayofmonth": _cal_dayofmonth,
  137. "dayofyear": _cal_dayofyear,
  138. "weekofyear": _cal_weekofyear,
  139. "is_holiday": _cal_holiday,
  140. "is_workday": _cal_workday,
  141. }
  142. def load_from_one_dataframe(
  143. data: Union[pd.DataFrame, pd.Series],
  144. time_col: Optional[str] = None,
  145. value_cols: Optional[Union[List[str], str]] = None,
  146. freq: Optional[Union[str, int]] = None,
  147. drop_tail_nan: bool = False,
  148. dtype: Optional[Union[type, Dict[str, type]]] = None,
  149. ) -> pd.DataFrame:
  150. """Transforms a DataFrame or Series into a time-indexed DataFrame.
  151. Args:
  152. data (Union[pd.DataFrame, pd.Series]): The input data containing time series information.
  153. time_col (Optional[str]): The column name representing time information. If None, uses the index.
  154. value_cols (Optional[Union[List[str], str]]): Columns to extract as values. If None, uses all except time_col.
  155. freq (Optional[Union[str, int]]): The frequency of the time series data.
  156. drop_tail_nan (bool): If True, drop trailing NaN values from the data.
  157. dtype (Optional[Union[type, Dict[str, type]]]): Enforce a specific data type on the resulting DataFrame.
  158. Returns:
  159. pd.DataFrame: A DataFrame with time as the index and specified value columns.
  160. Raises:
  161. ValueError: If the time column doesn't exist, or if frequency cannot be inferred.
  162. """
  163. # Initialize series_data with specified value columns or all except time_col
  164. series_data = None
  165. if value_cols is None:
  166. if isinstance(data, pd.Series):
  167. series_data = data.copy()
  168. else:
  169. series_data = data.loc[:, data.columns != time_col].copy()
  170. else:
  171. series_data = data.loc[:, value_cols].copy()
  172. # Determine the time column values
  173. if time_col:
  174. if time_col not in data.columns:
  175. raise ValueError(
  176. "The time column: {} doesn't exist in the `data`!".format(time_col)
  177. )
  178. time_col_vals = data.loc[:, time_col]
  179. else:
  180. time_col_vals = data.index
  181. # Handle integer-based time column values when frequency is a string
  182. if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
  183. time_col_vals = time_col_vals.astype(str)
  184. # Process integer-based time column values
  185. if np.issubdtype(time_col_vals.dtype, np.integer):
  186. if freq:
  187. if not isinstance(freq, int) or freq < 1:
  188. raise ValueError(
  189. "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
  190. )
  191. else:
  192. freq = 1 # Default frequency for integer index
  193. start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
  194. if (stop_idx - start_idx) / freq != len(data):
  195. raise ValueError("The number of rows doesn't match with the RangeIndex!")
  196. time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
  197. # Process datetime-like time column values
  198. elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
  199. time_col_vals.dtype, np.datetime64
  200. ):
  201. time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
  202. time_index = pd.DatetimeIndex(time_col_vals)
  203. if freq:
  204. if not isinstance(freq, str):
  205. raise ValueError(
  206. "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
  207. )
  208. else:
  209. # Attempt to infer frequency if not provided
  210. freq = pd.infer_freq(time_index)
  211. if freq is None:
  212. raise ValueError(
  213. "Failed to infer the `freq`. A valid `freq` is required."
  214. )
  215. if freq[0] == "-":
  216. freq = freq[1:]
  217. # Raise error for unsupported time column types
  218. else:
  219. raise ValueError("The type of `time_col` is invalid.")
  220. # Ensure series_data is a DataFrame
  221. if isinstance(series_data, pd.Series):
  222. series_data = series_data.to_frame()
  223. # Set time index and sort data
  224. series_data.set_index(time_index, inplace=True)
  225. series_data.sort_index(inplace=True)
  226. return series_data
  227. def load_from_dataframe(
  228. df: pd.DataFrame,
  229. group_id: Optional[str] = None,
  230. time_col: Optional[str] = None,
  231. target_cols: Optional[Union[List[str], str]] = None,
  232. label_col: Optional[Union[List[str], str]] = None,
  233. observed_cov_cols: Optional[Union[List[str], str]] = None,
  234. feature_cols: Optional[Union[List[str], str]] = None,
  235. known_cov_cols: Optional[Union[List[str], str]] = None,
  236. static_cov_cols: Optional[Union[List[str], str]] = None,
  237. freq: Optional[Union[str, int]] = None,
  238. fill_missing_dates: bool = False,
  239. fillna_method: str = "pre",
  240. fillna_window_size: int = 10,
  241. **kwargs,
  242. ) -> Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]:
  243. """Loads and processes time series data from a DataFrame.
  244. This function extracts and organizes time series data from a given DataFrame.
  245. It supports optional grouping and extraction of specific columns as features.
  246. Args:
  247. df (pd.DataFrame): The input DataFrame containing time series data.
  248. group_id (Optional[str]): Column name used for grouping the data.
  249. time_col (Optional[str]): Name of the time column.
  250. target_cols (Optional[Union[List[str], str]]): Columns to be used as target.
  251. label_col (Optional[Union[List[str], str]]): Columns to be used as label.
  252. observed_cov_cols (Optional[Union[List[str], str]]): Columns for observed covariates.
  253. feature_cols (Optional[Union[List[str], str]]): Columns to be used as features.
  254. known_cov_cols (Optional[Union[List[str], str]]): Columns for known covariates.
  255. static_cov_cols (Optional[Union[List[str], str]]): Columns for static covariates.
  256. freq (Optional[Union[str, int]]): Frequency of the time series data.
  257. fill_missing_dates (bool): Whether to fill missing dates in the time series.
  258. fillna_method (str): Method to fill missing values ('pre' or 'post').
  259. fillna_window_size (int): Window size for filling missing values.
  260. **kwargs: Additional keyword arguments.
  261. Returns:
  262. Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]: A dictionary containing processed time series data.
  263. """
  264. # List to store DataFrames if grouping is applied
  265. dfs = []
  266. # Separate the DataFrame into groups if group_id is provided
  267. if group_id is not None:
  268. group_unique = df[group_id].unique()
  269. for column in group_unique:
  270. dfs.append(df[df[group_id].isin([column])])
  271. else:
  272. dfs = [df]
  273. # Result list to store processed data from each group
  274. res = []
  275. # If label_col is provided, ensure it is a single column
  276. if label_col:
  277. if isinstance(label_col, str) and len(label_col) > 1:
  278. raise ValueError("The length of label_col must be 1.")
  279. target_cols = label_col
  280. # If feature_cols is provided, treat it as observed_cov_cols
  281. if feature_cols:
  282. observed_cov_cols = feature_cols
  283. # Process each DataFrame in the list
  284. for df in dfs:
  285. target = None
  286. observed_cov = None
  287. known_cov = None
  288. static_cov = dict()
  289. # If no specific columns are provided, use all columns except time_col
  290. if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
  291. target = load_from_one_dataframe(
  292. df,
  293. time_col,
  294. [a for a in df.columns if a != time_col],
  295. freq,
  296. )
  297. else:
  298. if target_cols:
  299. target = load_from_one_dataframe(
  300. df,
  301. time_col,
  302. target_cols,
  303. freq,
  304. )
  305. if observed_cov_cols:
  306. observed_cov = load_from_one_dataframe(
  307. df,
  308. time_col,
  309. observed_cov_cols,
  310. freq,
  311. )
  312. if known_cov_cols:
  313. known_cov = load_from_one_dataframe(
  314. df,
  315. time_col,
  316. known_cov_cols,
  317. freq,
  318. )
  319. if static_cov_cols:
  320. if isinstance(static_cov_cols, str):
  321. static_cov_cols = [static_cov_cols]
  322. for col in static_cov_cols:
  323. if col not in df.columns or len(np.unique(df[col])) != 1:
  324. raise ValueError(
  325. "Static covariate columns data is not in columns or schema is not correct!"
  326. )
  327. static_cov[col] = df[col].iloc[0]
  328. # Append the processed data into the results list
  329. res.append(
  330. {
  331. "past_target": target,
  332. "observed_cov_numeric": observed_cov,
  333. "known_cov_numeric": known_cov,
  334. "static_cov_numeric": static_cov,
  335. }
  336. )
  337. # Return the first processed result
  338. return res[0]
  339. def _distance_to_holiday(holiday) -> Callable[[pd.Timestamp], float]:
  340. """Creates a function to calculate the distance in days to the nearest holiday.
  341. This function generates a closure that computes the number of days from
  342. a given date index to the nearest holiday within a defined window.
  343. Args:
  344. holiday: An object that provides a `dates` method, which returns the
  345. dates of holidays within a specified range.
  346. Returns:
  347. Callable[[pd.Timestamp], float]: A function that takes a date index
  348. as input and returns the distance in days to the nearest holiday.
  349. """
  350. def _distance_to_day(index: pd.Timestamp) -> float:
  351. """Calculates the distance in days from a given date index to the nearest holiday.
  352. Args:
  353. index (pd.Timestamp): The date index for which the distance to the
  354. nearest holiday should be calculated.
  355. Returns:
  356. float: The number of days to the nearest holiday.
  357. Raises:
  358. AssertionError: If no holiday is found within the specified window.
  359. """
  360. holiday_date = holiday.dates(
  361. index - pd.Timedelta(days=MAX_WINDOW),
  362. index + pd.Timedelta(days=MAX_WINDOW),
  363. )
  364. assert (
  365. len(holiday_date) != 0
  366. ), f"No closest holiday for the date index {index} found."
  367. # It sometimes returns two dates if it is exactly half a year after the
  368. # holiday. In this case, the smaller distance (182 days) is returned.
  369. return float((index - holiday_date[0]).days)
  370. return _distance_to_day
  371. def time_feature(
  372. dataset: Dict,
  373. freq: Optional[Union[str, int]],
  374. feature_cols: List[str],
  375. extend_points: int,
  376. inplace: bool = False,
  377. ) -> Dict:
  378. """Transforms the time column of a dataset into time features.
  379. This function extracts time-related features from the time column in a
  380. dataset, optionally extending the time series for future points and
  381. normalizing holiday distances.
  382. Args:
  383. dataset (Dict): Dataset to be transformed.
  384. freq: Optional[Union[str, int]]: Frequency of the time series data. If not provided,
  385. the frequency will be inferred.
  386. feature_cols (List[str]): List of feature columns to be extracted.
  387. extend_points (int): Number of future points to extend the time series.
  388. inplace (bool): Whether to perform the transformation inplace. Default is False.
  389. Returns:
  390. Dict: The transformed dataset with time features added.
  391. Raises:
  392. ValueError: If the time column is of an integer type instead of datetime.
  393. """
  394. new_ts = dataset
  395. if not inplace:
  396. new_ts = dataset.copy()
  397. # Get known_cov_numeric or initialize with past target index
  398. kcov = new_ts["known_cov_numeric"]
  399. if not kcov:
  400. tf_kcov = new_ts["past_target"].index.to_frame()
  401. else:
  402. tf_kcov = kcov.index.to_frame()
  403. time_col = tf_kcov.columns[0]
  404. # Check if time column is of datetime type
  405. if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
  406. raise ValueError(
  407. "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
  408. )
  409. # Extend the time series if no known_cov_numeric
  410. if not kcov:
  411. freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
  412. extend_time = pd.date_range(
  413. start=tf_kcov[time_col][-1],
  414. freq=freq,
  415. periods=extend_points + 1,
  416. closed="right",
  417. name=time_col,
  418. ).to_frame()
  419. tf_kcov = pd.concat([tf_kcov, extend_time])
  420. # Extract and add time features to known_cov_numeric
  421. for k in feature_cols:
  422. if k != "holidays":
  423. v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
  424. v.index = tf_kcov[time_col]
  425. if new_ts["known_cov_numeric"] is None:
  426. new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
  427. else:
  428. new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
  429. new_ts["known_cov_numeric"].index
  430. )
  431. else:
  432. holidays_col = []
  433. for i, H in enumerate(HOLIDAYS):
  434. v = tf_kcov[time_col].apply(_distance_to_holiday(H))
  435. v.index = tf_kcov[time_col]
  436. holidays_col.append(k + "_" + str(i))
  437. if new_ts["known_cov_numeric"] is None:
  438. new_ts["known_cov_numeric"] = pd.DataFrame(
  439. v.rename(k + "_" + str(i)), index=v.index
  440. )
  441. else:
  442. new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
  443. new_ts["known_cov_numeric"].index
  444. )
  445. scaler = StandardScaler()
  446. scaler.fit(new_ts["known_cov_numeric"][holidays_col])
  447. new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
  448. new_ts["known_cov_numeric"][holidays_col]
  449. )
  450. return new_ts