funcs.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, Dict, List, Optional, Union
  15. import numpy as np
  16. import pandas as pd
  17. from packaging.version import Version
  18. from pandas.tseries import holiday as hd
  19. from pandas.tseries.offsets import DateOffset, Day, Easter
  20. from .....utils.deps import function_requires_deps, get_dep_version, is_dep_available
  21. if is_dep_available("chinese-calendar"):
  22. import chinese_calendar
  23. if is_dep_available("scikit-learn"):
  24. from sklearn.preprocessing import StandardScaler
  25. MAX_WINDOW = 183 + 17
  26. EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
  27. NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
  28. SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
  29. MothersDay = hd.Holiday(
  30. "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
  31. )
  32. IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
  33. ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
  34. ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
  35. NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
  36. BlackFriday = hd.Holiday(
  37. "Black Friday",
  38. month=11,
  39. day=1,
  40. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
  41. )
  42. CyberMonday = hd.Holiday(
  43. "Cyber Monday",
  44. month=11,
  45. day=1,
  46. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
  47. )
  48. HOLIDAYS = [
  49. hd.EasterMonday,
  50. hd.GoodFriday,
  51. hd.USColumbusDay,
  52. hd.USLaborDay,
  53. hd.USMartinLutherKingJr,
  54. hd.USMemorialDay,
  55. hd.USPresidentsDay,
  56. hd.USThanksgivingDay,
  57. EasterSunday,
  58. NewYearsDay,
  59. SuperBowl,
  60. MothersDay,
  61. IndependenceDay,
  62. ChristmasEve,
  63. ChristmasDay,
  64. NewYearsEve,
  65. BlackFriday,
  66. CyberMonday,
  67. ]
  68. def _cal_year(
  69. x: np.datetime64,
  70. ):
  71. return x.year
  72. def _cal_month(
  73. x: np.datetime64,
  74. ):
  75. return x.month
  76. def _cal_day(
  77. x: np.datetime64,
  78. ):
  79. return x.day
  80. def _cal_hour(
  81. x: np.datetime64,
  82. ):
  83. return x.hour
  84. def _cal_weekday(
  85. x: np.datetime64,
  86. ):
  87. return x.dayofweek
  88. def _cal_quarter(
  89. x: np.datetime64,
  90. ):
  91. return x.quarter
  92. def _cal_hourofday(
  93. x: np.datetime64,
  94. ):
  95. return x.hour / 23.0 - 0.5
  96. def _cal_dayofweek(
  97. x: np.datetime64,
  98. ):
  99. return x.dayofweek / 6.0 - 0.5
  100. def _cal_dayofmonth(
  101. x: np.datetime64,
  102. ):
  103. return x.day / 30.0 - 0.5
  104. def _cal_dayofyear(
  105. x: np.datetime64,
  106. ):
  107. return x.dayofyear / 364.0 - 0.5
  108. def _cal_weekofyear(
  109. x: np.datetime64,
  110. ):
  111. return x.weekofyear / 51.0 - 0.5
  112. @function_requires_deps("chinese-calendar")
  113. def _cal_holiday(
  114. x: np.datetime64,
  115. ):
  116. return float(chinese_calendar.is_holiday(x))
  117. @function_requires_deps("chinese-calendar")
  118. def _cal_workday(
  119. x: np.datetime64,
  120. ):
  121. return float(chinese_calendar.is_workday(x))
  122. def _cal_minuteofhour(
  123. x: np.datetime64,
  124. ):
  125. return x.minute / 59 - 0.5
  126. def _cal_monthofyear(
  127. x: np.datetime64,
  128. ):
  129. return x.month / 11.0 - 0.5
  130. CAL_DATE_METHOD = {
  131. "year": _cal_year,
  132. "month": _cal_month,
  133. "day": _cal_day,
  134. "hour": _cal_hour,
  135. "weekday": _cal_weekday,
  136. "quarter": _cal_quarter,
  137. "minuteofhour": _cal_minuteofhour,
  138. "monthofyear": _cal_monthofyear,
  139. "hourofday": _cal_hourofday,
  140. "dayofweek": _cal_dayofweek,
  141. "dayofmonth": _cal_dayofmonth,
  142. "dayofyear": _cal_dayofyear,
  143. "weekofyear": _cal_weekofyear,
  144. "is_holiday": _cal_holiday,
  145. "is_workday": _cal_workday,
  146. }
  147. def load_from_one_dataframe(
  148. data: Union[pd.DataFrame, pd.Series],
  149. time_col: Optional[str] = None,
  150. value_cols: Optional[Union[List[str], str]] = None,
  151. freq: Optional[Union[str, int]] = None,
  152. drop_tail_nan: bool = False,
  153. dtype: Optional[Union[type, Dict[str, type]]] = None,
  154. ) -> pd.DataFrame:
  155. """Transforms a DataFrame or Series into a time-indexed DataFrame.
  156. Args:
  157. data (Union[pd.DataFrame, pd.Series]): The input data containing time series information.
  158. time_col (Optional[str]): The column name representing time information. If None, uses the index.
  159. value_cols (Optional[Union[List[str], str]]): Columns to extract as values. If None, uses all except time_col.
  160. freq (Optional[Union[str, int]]): The frequency of the time series data.
  161. drop_tail_nan (bool): If True, drop trailing NaN values from the data.
  162. dtype (Optional[Union[type, Dict[str, type]]]): Enforce a specific data type on the resulting DataFrame.
  163. Returns:
  164. pd.DataFrame: A DataFrame with time as the index and specified value columns.
  165. Raises:
  166. ValueError: If the time column doesn't exist, or if frequency cannot be inferred.
  167. """
  168. # Initialize series_data with specified value columns or all except time_col
  169. series_data = None
  170. if value_cols is None:
  171. if isinstance(data, pd.Series):
  172. series_data = data.copy()
  173. else:
  174. series_data = data.loc[:, data.columns != time_col].copy()
  175. else:
  176. series_data = data.loc[:, value_cols].copy()
  177. # Determine the time column values
  178. if time_col:
  179. if time_col not in data.columns:
  180. raise ValueError(
  181. "The time column: {} doesn't exist in the `data`!".format(time_col)
  182. )
  183. time_col_vals = data.loc[:, time_col]
  184. else:
  185. time_col_vals = data.index
  186. # Handle integer-based time column values when frequency is a string
  187. if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
  188. time_col_vals = time_col_vals.astype(str)
  189. # Process integer-based time column values
  190. if np.issubdtype(time_col_vals.dtype, np.integer):
  191. if freq:
  192. if not isinstance(freq, int) or freq < 1:
  193. raise ValueError(
  194. "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
  195. )
  196. else:
  197. freq = 1 # Default frequency for integer index
  198. start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
  199. if (stop_idx - start_idx) / freq != len(data):
  200. raise ValueError("The number of rows doesn't match with the RangeIndex!")
  201. time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
  202. # Process datetime-like time column values
  203. elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
  204. time_col_vals.dtype, np.datetime64
  205. ):
  206. time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
  207. time_index = pd.DatetimeIndex(time_col_vals)
  208. if freq:
  209. if not isinstance(freq, str):
  210. raise ValueError(
  211. "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
  212. )
  213. else:
  214. # Attempt to infer frequency if not provided
  215. freq = pd.infer_freq(time_index)
  216. if freq is None:
  217. raise ValueError(
  218. "Failed to infer the `freq`. A valid `freq` is required."
  219. )
  220. if freq[0] == "-":
  221. freq = freq[1:]
  222. # Raise error for unsupported time column types
  223. else:
  224. raise ValueError("The type of `time_col` is invalid.")
  225. # Ensure series_data is a DataFrame
  226. if isinstance(series_data, pd.Series):
  227. series_data = series_data.to_frame()
  228. # Set time index and sort data
  229. series_data.set_index(time_index, inplace=True)
  230. series_data.sort_index(inplace=True)
  231. return series_data
  232. def load_from_dataframe(
  233. df: pd.DataFrame,
  234. group_id: Optional[str] = None,
  235. time_col: Optional[str] = None,
  236. target_cols: Optional[Union[List[str], str]] = None,
  237. label_col: Optional[Union[List[str], str]] = None,
  238. observed_cov_cols: Optional[Union[List[str], str]] = None,
  239. feature_cols: Optional[Union[List[str], str]] = None,
  240. known_cov_cols: Optional[Union[List[str], str]] = None,
  241. static_cov_cols: Optional[Union[List[str], str]] = None,
  242. freq: Optional[Union[str, int]] = None,
  243. fill_missing_dates: bool = False,
  244. fillna_method: str = "pre",
  245. fillna_window_size: int = 10,
  246. **kwargs,
  247. ) -> Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]:
  248. """Loads and processes time series data from a DataFrame.
  249. This function extracts and organizes time series data from a given DataFrame.
  250. It supports optional grouping and extraction of specific columns as features.
  251. Args:
  252. df (pd.DataFrame): The input DataFrame containing time series data.
  253. group_id (Optional[str]): Column name used for grouping the data.
  254. time_col (Optional[str]): Name of the time column.
  255. target_cols (Optional[Union[List[str], str]]): Columns to be used as target.
  256. label_col (Optional[Union[List[str], str]]): Columns to be used as label.
  257. observed_cov_cols (Optional[Union[List[str], str]]): Columns for observed covariates.
  258. feature_cols (Optional[Union[List[str], str]]): Columns to be used as features.
  259. known_cov_cols (Optional[Union[List[str], str]]): Columns for known covariates.
  260. static_cov_cols (Optional[Union[List[str], str]]): Columns for static covariates.
  261. freq (Optional[Union[str, int]]): Frequency of the time series data.
  262. fill_missing_dates (bool): Whether to fill missing dates in the time series.
  263. fillna_method (str): Method to fill missing values ('pre' or 'post').
  264. fillna_window_size (int): Window size for filling missing values.
  265. **kwargs: Additional keyword arguments.
  266. Returns:
  267. Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]: A dictionary containing processed time series data.
  268. """
  269. # List to store DataFrames if grouping is applied
  270. dfs = []
  271. # Separate the DataFrame into groups if group_id is provided
  272. if group_id is not None:
  273. group_unique = df[group_id].unique()
  274. for column in group_unique:
  275. dfs.append(df[df[group_id].isin([column])])
  276. else:
  277. dfs = [df]
  278. # Result list to store processed data from each group
  279. res = []
  280. # If label_col is provided, ensure it is a single column
  281. if label_col:
  282. if isinstance(label_col, str) and len(label_col) > 1:
  283. raise ValueError("The length of label_col must be 1.")
  284. target_cols = label_col
  285. # If feature_cols is provided, treat it as observed_cov_cols
  286. if feature_cols:
  287. observed_cov_cols = feature_cols
  288. # Process each DataFrame in the list
  289. for df in dfs:
  290. target = None
  291. observed_cov = None
  292. known_cov = None
  293. static_cov = dict()
  294. # If no specific columns are provided, use all columns except time_col
  295. if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
  296. target = load_from_one_dataframe(
  297. df,
  298. time_col,
  299. [a for a in df.columns if a != time_col],
  300. freq,
  301. )
  302. else:
  303. if target_cols:
  304. target = load_from_one_dataframe(
  305. df,
  306. time_col,
  307. target_cols,
  308. freq,
  309. )
  310. if observed_cov_cols:
  311. observed_cov = load_from_one_dataframe(
  312. df,
  313. time_col,
  314. observed_cov_cols,
  315. freq,
  316. )
  317. if known_cov_cols:
  318. known_cov = load_from_one_dataframe(
  319. df,
  320. time_col,
  321. known_cov_cols,
  322. freq,
  323. )
  324. if static_cov_cols:
  325. if isinstance(static_cov_cols, str):
  326. static_cov_cols = [static_cov_cols]
  327. for col in static_cov_cols:
  328. if col not in df.columns or len(np.unique(df[col])) != 1:
  329. raise ValueError(
  330. "Static covariate columns data is not in columns or schema is not correct!"
  331. )
  332. static_cov[col] = df[col].iloc[0]
  333. # Append the processed data into the results list
  334. res.append(
  335. {
  336. "past_target": target,
  337. "observed_cov_numeric": observed_cov,
  338. "known_cov_numeric": known_cov,
  339. "static_cov_numeric": static_cov,
  340. }
  341. )
  342. # Return the first processed result
  343. return res[0]
  344. def _distance_to_holiday(holiday) -> Callable[[pd.Timestamp], float]:
  345. """Creates a function to calculate the distance in days to the nearest holiday.
  346. This function generates a closure that computes the number of days from
  347. a given date index to the nearest holiday within a defined window.
  348. Args:
  349. holiday: An object that provides a `dates` method, which returns the
  350. dates of holidays within a specified range.
  351. Returns:
  352. Callable[[pd.Timestamp], float]: A function that takes a date index
  353. as input and returns the distance in days to the nearest holiday.
  354. """
  355. def _distance_to_day(index: pd.Timestamp) -> float:
  356. """Calculates the distance in days from a given date index to the nearest holiday.
  357. Args:
  358. index (pd.Timestamp): The date index for which the distance to the
  359. nearest holiday should be calculated.
  360. Returns:
  361. float: The number of days to the nearest holiday.
  362. Raises:
  363. AssertionError: If no holiday is found within the specified window.
  364. """
  365. holiday_date = holiday.dates(
  366. index - pd.Timedelta(days=MAX_WINDOW),
  367. index + pd.Timedelta(days=MAX_WINDOW),
  368. )
  369. assert (
  370. len(holiday_date) != 0
  371. ), f"No closest holiday for the date index {index} found."
  372. # It sometimes returns two dates if it is exactly half a year after the
  373. # holiday. In this case, the smaller distance (182 days) is returned.
  374. return float((index - holiday_date[0]).days)
  375. return _distance_to_day
  376. @function_requires_deps("scikit-learn")
  377. def time_feature(
  378. dataset: Dict,
  379. freq: Optional[Union[str, int]],
  380. feature_cols: List[str],
  381. extend_points: int,
  382. inplace: bool = False,
  383. ) -> Dict:
  384. """Transforms the time column of a dataset into time features.
  385. This function extracts time-related features from the time column in a
  386. dataset, optionally extending the time series for future points and
  387. normalizing holiday distances.
  388. Args:
  389. dataset (Dict): Dataset to be transformed.
  390. freq: Optional[Union[str, int]]: Frequency of the time series data. If not provided,
  391. the frequency will be inferred.
  392. feature_cols (List[str]): List of feature columns to be extracted.
  393. extend_points (int): Number of future points to extend the time series.
  394. inplace (bool): Whether to perform the transformation inplace. Default is False.
  395. Returns:
  396. Dict: The transformed dataset with time features added.
  397. Raises:
  398. ValueError: If the time column is of an integer type instead of datetime.
  399. """
  400. new_ts = dataset
  401. if not inplace:
  402. new_ts = dataset.copy()
  403. # Get known_cov_numeric or initialize with past target index
  404. kcov = new_ts["known_cov_numeric"]
  405. if not kcov:
  406. tf_kcov = new_ts["past_target"].index.to_frame()
  407. else:
  408. tf_kcov = kcov.index.to_frame()
  409. time_col = tf_kcov.columns[0]
  410. # Check if time column is of datetime type
  411. if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
  412. raise ValueError(
  413. "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
  414. )
  415. # Extend the time series if no known_cov_numeric
  416. if not kcov:
  417. freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
  418. pd_version = get_dep_version("pandas")
  419. if Version(pd_version) >= Version("1.4"):
  420. extend_time = pd.date_range(
  421. start=tf_kcov[time_col][-1],
  422. freq=freq,
  423. periods=extend_points + 1,
  424. inclusive="right",
  425. name=time_col,
  426. ).to_frame()
  427. else:
  428. extend_time = pd.date_range(
  429. start=tf_kcov[time_col][-1],
  430. freq=freq,
  431. periods=extend_points + 1,
  432. closed="right",
  433. name=time_col,
  434. ).to_frame()
  435. tf_kcov = pd.concat([tf_kcov, extend_time])
  436. # Extract and add time features to known_cov_numeric
  437. for k in feature_cols:
  438. if k != "holidays":
  439. v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
  440. v.index = tf_kcov[time_col]
  441. if new_ts["known_cov_numeric"] is None:
  442. new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
  443. else:
  444. new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
  445. new_ts["known_cov_numeric"].index
  446. )
  447. else:
  448. holidays_col = []
  449. for i, H in enumerate(HOLIDAYS):
  450. v = tf_kcov[time_col].apply(_distance_to_holiday(H))
  451. v.index = tf_kcov[time_col]
  452. holidays_col.append(k + "_" + str(i))
  453. if new_ts["known_cov_numeric"] is None:
  454. new_ts["known_cov_numeric"] = pd.DataFrame(
  455. v.rename(k + "_" + str(i)), index=v.index
  456. )
  457. else:
  458. new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
  459. new_ts["known_cov_numeric"].index
  460. )
  461. scaler = StandardScaler()
  462. scaler.fit(new_ts["known_cov_numeric"][holidays_col])
  463. new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
  464. new_ts["known_cov_numeric"][holidays_col]
  465. )
  466. return new_ts