ts_functions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
  16. import numpy as np
  17. import pandas as pd
  18. import joblib
  19. import chinese_calendar
  20. from pandas.tseries.offsets import DateOffset, Easter, Day
  21. from pandas.tseries import holiday as hd
  22. from sklearn.preprocessing import StandardScaler
# Search window (in days) used by `_distance_to_holiday`: roughly half a year
# plus slack, so at least one occurrence of every yearly holiday is found.
MAX_WINDOW = 183 + 17

# --- Holiday rules (pandas.tseries.holiday) --------------------------------
# Easter Sunday: anchored at Jan 1 and moved to Easter by the offset chain.
EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
# First Sunday of February.
SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
# Second Sunday of May.
MothersDay = hd.Holiday(
    "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
)
IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
# NOTE(review): both Dec 24 and Dec 25 rules share the label "Christmas";
# only the dates are used downstream, so the duplicate name is harmless.
ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
# Day after the fourth Thursday of November (US Thanksgiving + 1).
BlackFriday = hd.Holiday(
    "Black Friday",
    month=11,
    day=1,
    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
)
# Monday after the fourth Thursday of November (US Thanksgiving + 4).
CyberMonday = hd.Holiday(
    "Cyber Monday",
    month=11,
    day=1,
    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
)

# Holiday rules used by `time_feature` to build one "holidays_<i>" distance
# column per entry.
HOLIDAYS = [
    hd.EasterMonday,
    hd.GoodFriday,
    hd.USColumbusDay,
    hd.USLaborDay,
    hd.USMartinLutherKingJr,
    hd.USMemorialDay,
    hd.USPresidentsDay,
    hd.USThanksgivingDay,
    EasterSunday,
    NewYearsDay,
    SuperBowl,
    MothersDay,
    IndependenceDay,
    ChristmasEve,
    ChristmasDay,
    NewYearsEve,
    BlackFriday,
    CyberMonday,
]
  66. def _cal_year(
  67. x: np.datetime64,
  68. ):
  69. return x.year
  70. def _cal_month(
  71. x: np.datetime64,
  72. ):
  73. return x.month
  74. def _cal_day(
  75. x: np.datetime64,
  76. ):
  77. return x.day
  78. def _cal_hour(
  79. x: np.datetime64,
  80. ):
  81. return x.hour
  82. def _cal_weekday(
  83. x: np.datetime64,
  84. ):
  85. return x.dayofweek
  86. def _cal_quarter(
  87. x: np.datetime64,
  88. ):
  89. return x.quarter
  90. def _cal_hourofday(
  91. x: np.datetime64,
  92. ):
  93. return x.hour / 23.0 - 0.5
  94. def _cal_dayofweek(
  95. x: np.datetime64,
  96. ):
  97. return x.dayofweek / 6.0 - 0.5
  98. def _cal_dayofmonth(
  99. x: np.datetime64,
  100. ):
  101. return x.day / 30.0 - 0.5
  102. def _cal_dayofyear(
  103. x: np.datetime64,
  104. ):
  105. return x.dayofyear / 364.0 - 0.5
  106. def _cal_weekofyear(
  107. x: np.datetime64,
  108. ):
  109. return x.weekofyear / 51.0 - 0.5
  110. def _cal_holiday(
  111. x: np.datetime64,
  112. ):
  113. return float(chinese_calendar.is_holiday(x))
  114. def _cal_workday(
  115. x: np.datetime64,
  116. ):
  117. return float(chinese_calendar.is_workday(x))
  118. def _cal_minuteofhour(
  119. x: np.datetime64,
  120. ):
  121. return x.minute / 59 - 0.5
  122. def _cal_monthofyear(
  123. x: np.datetime64,
  124. ):
  125. return x.month / 11.0 - 0.5
# Dispatch table: feature name -> callable that computes that feature from a
# single Timestamp-like value. Consumed by `time_feature` via
# `CAL_DATE_METHOD[k](x)` for every non-"holidays" feature column.
CAL_DATE_METHOD = {
    "year": _cal_year,
    "month": _cal_month,
    "day": _cal_day,
    "hour": _cal_hour,
    "weekday": _cal_weekday,
    "quarter": _cal_quarter,
    "minuteofhour": _cal_minuteofhour,
    "monthofyear": _cal_monthofyear,
    "hourofday": _cal_hourofday,
    "dayofweek": _cal_dayofweek,
    "dayofmonth": _cal_dayofmonth,
    "dayofyear": _cal_dayofyear,
    "is_holiday": _cal_holiday,
    "is_workday": _cal_workday,
}
  143. def load_from_one_dataframe(
  144. data: Union[pd.DataFrame, pd.Series],
  145. time_col: Optional[str] = None,
  146. value_cols: Optional[Union[List[str], str]] = None,
  147. freq: Optional[Union[str, int]] = None,
  148. drop_tail_nan: bool = False,
  149. dtype: Optional[Union[type, Dict[str, type]]] = None,
  150. ):
  151. series_data = None
  152. if value_cols is None:
  153. if isinstance(data, pd.Series):
  154. series_data = data.copy()
  155. else:
  156. series_data = data.loc[:, data.columns != time_col].copy()
  157. else:
  158. series_data = data.loc[:, value_cols].copy()
  159. if time_col:
  160. if time_col not in data.columns:
  161. raise ValueError(
  162. "The time column: {} doesn't exist in the `data`!".format(time_col)
  163. )
  164. time_col_vals = data.loc[:, time_col]
  165. else:
  166. time_col_vals = data.index
  167. if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
  168. time_col_vals = time_col_vals.astype(str)
  169. if np.issubdtype(time_col_vals.dtype, np.integer):
  170. if freq:
  171. if not isinstance(freq, int) or freq < 1:
  172. raise ValueError(
  173. "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
  174. )
  175. else:
  176. freq = 1
  177. start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
  178. if (stop_idx - start_idx) / freq != len(data):
  179. raise ValueError("The number of rows doesn't match with the RangeIndex!")
  180. time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
  181. elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
  182. time_col_vals.dtype, np.datetime64
  183. ):
  184. time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
  185. time_index = pd.DatetimeIndex(time_col_vals)
  186. if freq:
  187. if not isinstance(freq, str):
  188. raise ValueError(
  189. "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
  190. )
  191. else:
  192. # If freq is not provided and automatic inference fail, throw exception
  193. freq = pd.infer_freq(time_index)
  194. if freq is None:
  195. raise ValueError(
  196. "Failed to infer the `freq`. A valid `freq` is required."
  197. )
  198. if freq[0] == "-":
  199. freq = freq[1:]
  200. else:
  201. raise ValueError("The type of `time_col` is invalid.")
  202. if isinstance(series_data, pd.Series):
  203. series_data = series_data.to_frame()
  204. series_data.set_index(time_index, inplace=True)
  205. series_data.sort_index(inplace=True)
  206. return series_data
  207. def load_from_dataframe(
  208. df: pd.DataFrame,
  209. group_id: str = None,
  210. time_col: Optional[str] = None,
  211. target_cols: Optional[Union[List[str], str]] = None,
  212. label_col: Optional[Union[List[str], str]] = None,
  213. observed_cov_cols: Optional[Union[List[str], str]] = None,
  214. feature_cols: Optional[Union[List[str], str]] = None,
  215. known_cov_cols: Optional[Union[List[str], str]] = None,
  216. static_cov_cols: Optional[Union[List[str], str]] = None,
  217. freq: Optional[Union[str, int]] = None,
  218. fill_missing_dates: bool = False,
  219. fillna_method: str = "pre",
  220. fillna_window_size: int = 10,
  221. **kwargs,
  222. ):
  223. dfs = [] # seperate multiple group
  224. if group_id is not None:
  225. group_unique = df[group_id].unique()
  226. for column in group_unique:
  227. dfs.append(df[df[group_id].isin([column])])
  228. else:
  229. dfs = [df]
  230. res = []
  231. if label_col:
  232. if isinstance(label_col, str) and len(label_col) > 1:
  233. raise ValueError("The length of label_col must be 1.")
  234. target_cols = label_col
  235. if feature_cols:
  236. observed_cov_cols = feature_cols
  237. for df in dfs:
  238. target = None
  239. observed_cov = None
  240. known_cov = None
  241. static_cov = dict()
  242. if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
  243. target = load_from_one_dataframe(
  244. df,
  245. time_col,
  246. [a for a in df.columns if a != time_col],
  247. freq,
  248. )
  249. else:
  250. if target_cols:
  251. target = load_from_one_dataframe(
  252. df,
  253. time_col,
  254. target_cols,
  255. freq,
  256. )
  257. if observed_cov_cols:
  258. observed_cov = load_from_one_dataframe(
  259. df,
  260. time_col,
  261. observed_cov_cols,
  262. freq,
  263. )
  264. if known_cov_cols:
  265. known_cov = load_from_one_dataframe(
  266. df,
  267. time_col,
  268. known_cov_cols,
  269. freq,
  270. )
  271. if static_cov_cols:
  272. if isinstance(static_cov_cols, str):
  273. static_cov_cols = [static_cov_cols]
  274. for col in static_cov_cols:
  275. if col not in df.columns or len(np.unique(df[col])) != 1:
  276. raise ValueError(
  277. "static cov cals data is not in columns or schema is not right!"
  278. )
  279. static_cov[col] = df[col].iloc[0]
  280. res.append(
  281. {
  282. "past_target": target,
  283. "observed_cov_numeric": observed_cov,
  284. "known_cov_numeric": known_cov,
  285. "static_cov_numeric": static_cov,
  286. }
  287. )
  288. return res[0]
  289. def _distance_to_holiday(holiday):
  290. def _distance_to_day(index):
  291. holiday_date = holiday.dates(
  292. index - pd.Timedelta(days=MAX_WINDOW),
  293. index + pd.Timedelta(days=MAX_WINDOW),
  294. )
  295. assert (
  296. len(holiday_date) != 0
  297. ), f"No closest holiday for the date index {index} found."
  298. # It sometimes returns two dates if it is exactly half a year after the
  299. # holiday. In this case, the smaller distance (182 days) is returned.
  300. return float((index - holiday_date[0]).days)
  301. return _distance_to_day
  302. def time_feature(dataset, freq, feature_cols, extend_points, inplace: bool = False):
  303. """
  304. Transform time column to time features.
  305. Args:
  306. dataset(TSDataset): Dataset to be transformed.
  307. inplace(bool): Whether to perform the transformation inplace. default=False
  308. Returns:
  309. TSDataset
  310. """
  311. new_ts = dataset
  312. if not inplace:
  313. new_ts = dataset.copy()
  314. # Get known_cov
  315. kcov = new_ts["known_cov_numeric"]
  316. if not kcov:
  317. tf_kcov = new_ts["past_target"].index.to_frame()
  318. else:
  319. tf_kcov = kcov.index.to_frame()
  320. time_col = tf_kcov.columns[0]
  321. if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
  322. raise ValueError(
  323. "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
  324. )
  325. if not kcov:
  326. freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
  327. extend_time = pd.date_range(
  328. start=tf_kcov[time_col][-1],
  329. freq=freq,
  330. periods=extend_points + 1,
  331. closed="right",
  332. name=time_col,
  333. ).to_frame()
  334. tf_kcov = pd.concat([tf_kcov, extend_time])
  335. for k in feature_cols:
  336. if k != "holidays":
  337. v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
  338. v.index = tf_kcov[time_col]
  339. if new_ts["known_cov_numeric"] is None:
  340. new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
  341. else:
  342. new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
  343. new_ts["known_cov_numeric"].index
  344. )
  345. else:
  346. holidays_col = []
  347. for i, H in enumerate(HOLIDAYS):
  348. v = tf_kcov[time_col].apply(_distance_to_holiday(H))
  349. v.index = tf_kcov[time_col]
  350. holidays_col.append(k + "_" + str(i))
  351. if new_ts["known_cov_numeric"] is None:
  352. new_ts["known_cov_numeric"] = pd.DataFrame(
  353. v.rename(k + "_" + str(i)), index=v.index
  354. )
  355. else:
  356. new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
  357. new_ts["known_cov_numeric"].index
  358. )
  359. scaler = StandardScaler()
  360. scaler.fit(new_ts["known_cov_numeric"][holidays_col])
  361. new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
  362. new_ts["known_cov_numeric"][holidays_col]
  363. )
  364. return new_ts