# processors.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union

import chinese_calendar
import joblib
import numpy as np
import pandas as pd
from pandas.tseries import holiday as hd
from pandas.tseries.offsets import DateOffset, Day, Easter
from sklearn.preprocessing import StandardScaler

from ..base import PyOnlyProcessor

__all__ = [
    "CutOff",
    "Normalize",
    "Denormalize",
    "BuildTSDataset",
    "CalcTimeFeatures",
    "BuildPaddedMask",
    "DataFrame2Arrays",
]
_MAX_WINDOW = 183 + 17

_EASTER_SUNDAY = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
_NEW_YEARS_DAY = hd.Holiday("New Years Day", month=1, day=1)
_SUPER_BOWL = hd.Holiday(
    "Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1))
)
_MOTHERS_DAY = hd.Holiday(
    "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
)
_INDEPENDENCE_DAY = hd.Holiday("Independence Day", month=7, day=4)
_CHRISTMAS_EVE = hd.Holiday("Christmas Eve", month=12, day=24)
_CHRISTMAS_DAY = hd.Holiday("Christmas Day", month=12, day=25)
_NEW_YEARS_EVE = hd.Holiday("New Years Eve", month=12, day=31)
_BLACK_FRIDAY = hd.Holiday(
    "Black Friday",
    month=11,
    day=1,
    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
)
_CYBER_MONDAY = hd.Holiday(
    "Cyber Monday",
    month=11,
    day=1,
    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
)

_HOLIDAYS = [
    hd.EasterMonday,
    hd.GoodFriday,
    hd.USColumbusDay,
    hd.USLaborDay,
    hd.USMartinLutherKingJr,
    hd.USMemorialDay,
    hd.USPresidentsDay,
    hd.USThanksgivingDay,
    _EASTER_SUNDAY,
    _NEW_YEARS_DAY,
    _SUPER_BOWL,
    _MOTHERS_DAY,
    _INDEPENDENCE_DAY,
    _CHRISTMAS_EVE,
    _CHRISTMAS_DAY,
    _NEW_YEARS_EVE,
    _BLACK_FRIDAY,
    _CYBER_MONDAY,
]
def _cal_year(x: np.datetime64):
    return x.year


def _cal_month(x: np.datetime64):
    return x.month


def _cal_day(x: np.datetime64):
    return x.day


def _cal_hour(x: np.datetime64):
    return x.hour


def _cal_weekday(x: np.datetime64):
    return x.dayofweek


def _cal_quarter(x: np.datetime64):
    return x.quarter


def _cal_hourofday(x: np.datetime64):
    return x.hour / 23.0 - 0.5


def _cal_dayofweek(x: np.datetime64):
    return x.dayofweek / 6.0 - 0.5


def _cal_dayofmonth(x: np.datetime64):
    return x.day / 30.0 - 0.5


def _cal_dayofyear(x: np.datetime64):
    return x.dayofyear / 364.0 - 0.5


def _cal_weekofyear(x: np.datetime64):
    return x.weekofyear / 51.0 - 0.5


def _cal_holiday(x: np.datetime64):
    return float(chinese_calendar.is_holiday(x))


def _cal_workday(x: np.datetime64):
    return float(chinese_calendar.is_workday(x))


def _cal_minuteofhour(x: np.datetime64):
    return x.minute / 59.0 - 0.5


def _cal_monthofyear(x: np.datetime64):
    return x.month / 11.0 - 0.5


_CAL_DATE_METHOD = {
    "year": _cal_year,
    "month": _cal_month,
    "day": _cal_day,
    "hour": _cal_hour,
    "weekday": _cal_weekday,
    "quarter": _cal_quarter,
    "minuteofhour": _cal_minuteofhour,
    "monthofyear": _cal_monthofyear,
    "hourofday": _cal_hourofday,
    "dayofweek": _cal_dayofweek,
    "dayofmonth": _cal_dayofmonth,
    "dayofyear": _cal_dayofyear,
    "weekofyear": _cal_weekofyear,
    "is_holiday": _cal_holiday,
    "is_workday": _cal_workday,
}
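
# Illustrative sketch: each entry of `_CAL_DATE_METHOD` maps a single
# timestamp to a scalar feature; the "*of*" variants are scaled to roughly
# [-0.5, 0.5]. The values below follow from the formulas above:
#
#   ts = pd.Timestamp("2024-01-01")      # a Monday
#   _CAL_DATE_METHOD["dayofweek"](ts)    # 0 / 6.0 - 0.5 == -0.5
#   _CAL_DATE_METHOD["monthofyear"](ts)  # 1 / 11.0 - 0.5 ~= -0.409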
def _load_from_one_dataframe(
    data: Union[pd.DataFrame, pd.Series],
    time_col: Optional[str] = None,
    value_cols: Optional[Union[List[str], str]] = None,
    freq: Optional[Union[str, int]] = None,
    drop_tail_nan: bool = False,
    dtype: Optional[Union[type, Dict[str, type]]] = None,
):
    series_data = None
    if value_cols is None:
        if isinstance(data, pd.Series):
            series_data = data.copy()
        else:
            series_data = data.loc[:, data.columns != time_col].copy()
    else:
        series_data = data.loc[:, value_cols].copy()
    if time_col:
        if time_col not in data.columns:
            raise ValueError(
                "The time column: {} doesn't exist in the `data`!".format(time_col)
            )
        time_col_vals = data.loc[:, time_col]
    else:
        time_col_vals = data.index
    if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
        time_col_vals = time_col_vals.astype(str)
    if np.issubdtype(time_col_vals.dtype, np.integer):
        if freq:
            if not isinstance(freq, int) or freq < 1:
                raise ValueError(
                    "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
                )
        else:
            freq = 1
        start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
        if (stop_idx - start_idx) / freq != len(data):
            raise ValueError("The number of rows doesn't match with the RangeIndex!")
        time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
    elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
        time_col_vals.dtype, np.datetime64
    ):
        time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
        time_index = pd.DatetimeIndex(time_col_vals)
        if freq:
            if not isinstance(freq, str):
                raise ValueError(
                    "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
                )
        else:
            # If `freq` is not provided and automatic inference fails, raise an exception.
            freq = pd.infer_freq(time_index)
            if freq is None:
                raise ValueError(
                    "Failed to infer the `freq`. A valid `freq` is required."
                )
            if freq[0] == "-":
                freq = freq[1:]
    else:
        raise ValueError("The type of `time_col` is invalid.")
    if isinstance(series_data, pd.Series):
        series_data = series_data.to_frame()
    series_data.set_index(time_index, inplace=True)
    series_data.sort_index(inplace=True)
    return series_data
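
# Illustrative sketch (hypothetical column names): a raw frame with a
# datetime column becomes a frame indexed by `pd.DatetimeIndex`, with
# `freq` inferred when not given.
#
#   df = pd.DataFrame(
#       {"time": pd.date_range("2024-01-01", periods=4, freq="D"), "y": [1, 2, 3, 4]}
#   )
#   out = _load_from_one_dataframe(df, time_col="time", value_cols="y")
#   # out.index -> DatetimeIndex(['2024-01-01', ..., '2024-01-04'])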
def _load_from_dataframe(
    df: pd.DataFrame,
    group_id: Optional[str] = None,
    time_col: Optional[str] = None,
    target_cols: Optional[Union[List[str], str]] = None,
    label_col: Optional[Union[List[str], str]] = None,
    observed_cov_cols: Optional[Union[List[str], str]] = None,
    feature_cols: Optional[Union[List[str], str]] = None,
    known_cov_cols: Optional[Union[List[str], str]] = None,
    static_cov_cols: Optional[Union[List[str], str]] = None,
    freq: Optional[Union[str, int]] = None,
    fill_missing_dates: bool = False,
    fillna_method: str = "pre",
    fillna_window_size: int = 10,
    **kwargs,
):
    dfs = []  # separate multiple groups
    if group_id is not None:
        group_unique = df[group_id].unique()
        for column in group_unique:
            dfs.append(df[df[group_id].isin([column])])
    else:
        dfs = [df]
    res = []
    if label_col:
        # `label_col` must name exactly one column.
        if not isinstance(label_col, str) and len(label_col) > 1:
            raise ValueError("The length of label_col must be 1.")
        target_cols = label_col
    if feature_cols:
        observed_cov_cols = feature_cols
    for df in dfs:
        target = None
        observed_cov = None
        known_cov = None
        static_cov = dict()
        if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
            target = _load_from_one_dataframe(
                df,
                time_col,
                [a for a in df.columns if a != time_col],
                freq,
            )
        else:
            if target_cols:
                target = _load_from_one_dataframe(df, time_col, target_cols, freq)
            if observed_cov_cols:
                observed_cov = _load_from_one_dataframe(
                    df, time_col, observed_cov_cols, freq
                )
            if known_cov_cols:
                known_cov = _load_from_one_dataframe(df, time_col, known_cov_cols, freq)
            if static_cov_cols:
                if isinstance(static_cov_cols, str):
                    static_cov_cols = [static_cov_cols]
                for col in static_cov_cols:
                    if col not in df.columns or len(np.unique(df[col])) != 1:
                        raise ValueError(
                            "The static covariate column {} is missing from the columns, or its values are not constant within the group!".format(
                                col
                            )
                        )
                    static_cov[col] = df[col].iloc[0]
        res.append(
            {
                "past_target": target,
                "observed_cov_numeric": observed_cov,
                "known_cov_numeric": known_cov,
                "static_cov_numeric": static_cov,
            }
        )
    # Only the first group's dataset is returned.
    return res[0]
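
# Illustrative sketch (hypothetical column names): splitting a frame into
# the per-sample dict consumed downstream.
#
#   df = pd.DataFrame(
#       {
#           "time": pd.date_range("2024-01-01", periods=4, freq="D"),
#           "y": [1.0, 2.0, 3.0, 4.0],
#           "promo": [0, 1, 0, 0],
#       }
#   )
#   sample = _load_from_dataframe(
#       df, time_col="time", target_cols="y", known_cov_cols="promo"
#   )
#   # sample keys: past_target, observed_cov_numeric, known_cov_numeric,
#   # static_cov_numeric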
def _distance_to_holiday(holiday):
    def _distance_to_day(index):
        holiday_date = holiday.dates(
            index - pd.Timedelta(days=_MAX_WINDOW),
            index + pd.Timedelta(days=_MAX_WINDOW),
        )
        assert (
            len(holiday_date) != 0
        ), f"No closest holiday for the date index {index} found."
        # It sometimes returns two dates if it is exactly half a year after the
        # holiday. In this case, the smaller distance (182 days) is returned.
        return float((index - holiday_date[0]).days)

    return _distance_to_day
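
# Illustrative sketch: the returned closure gives the signed day distance
# from a timestamp to the nearest occurrence of the holiday, negative when
# the holiday is still ahead. The value below follows from the logic above:
#
#   dist = _distance_to_holiday(_CHRISTMAS_DAY)
#   dist(pd.Timestamp("2024-12-20"))  # -5.0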
def _to_time_features(
    dataset, freq, feature_cols, extend_points, inplace: bool = False
):
    new_ts = dataset
    if not inplace:
        new_ts = dataset.copy()
    # Get known_cov
    kcov = new_ts["known_cov_numeric"]
    if kcov is None:
        tf_kcov = new_ts["past_target"].index.to_frame()
    else:
        tf_kcov = kcov.index.to_frame()
    time_col = tf_kcov.columns[0]
    if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
        raise ValueError(
            "The `time_col` can't be of type numpy.integer; it must be of type numpy.datetime64."
        )
    if kcov is None:
        # No known covariates: extend the time index by `extend_points` steps
        # so the generated features also cover the forecast horizon.
        freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
        extend_time = pd.date_range(
            start=tf_kcov[time_col].iloc[-1],
            freq=freq,
            periods=extend_points + 1,
            closed="right",  # `inclusive="right"` in pandas >= 1.4
            name=time_col,
        ).to_frame()
        tf_kcov = pd.concat([tf_kcov, extend_time])
    for k in feature_cols:
        if k != "holidays":
            v = tf_kcov[time_col].apply(_CAL_DATE_METHOD[k])
            v.index = tf_kcov[time_col]
            if new_ts["known_cov_numeric"] is None:
                new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
            else:
                new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
                    new_ts["known_cov_numeric"].index
                )
        else:
            # One distance-to-holiday column per entry of `_HOLIDAYS`,
            # standardized jointly afterwards.
            holidays_col = []
            for i, H in enumerate(_HOLIDAYS):
                col = k + "_" + str(i)
                v = tf_kcov[time_col].apply(_distance_to_holiday(H))
                v.index = tf_kcov[time_col]
                holidays_col.append(col)
                if new_ts["known_cov_numeric"] is None:
                    new_ts["known_cov_numeric"] = pd.DataFrame(
                        v.rename(col), index=v.index
                    )
                else:
                    new_ts["known_cov_numeric"][col] = v.rename(col).reindex(
                        new_ts["known_cov_numeric"].index
                    )
            scaler = StandardScaler()
            scaler.fit(new_ts["known_cov_numeric"][holidays_col])
            new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
                new_ts["known_cov_numeric"][holidays_col]
            )
    return new_ts
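
# Illustrative sketch: given a sample built by `_load_from_dataframe`, this
# appends the requested calendar columns to `known_cov_numeric`, creating it
# when absent. Hypothetical call, extending 2 steps past the series end:
#
#   sample = _to_time_features(
#       sample, "D", ["dayofweek", "dayofmonth"], extend_points=2
#   )
#   # sample["known_cov_numeric"].columns -> ["dayofweek", "dayofmonth"]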
class CutOff(PyOnlyProcessor):
    def __init__(self, size):
        super().__init__()
        self._size = size

    def __call__(self, data):
        ts = data["ts"]
        skip_len = self._size.get("skip_chunk_len", 0)
        min_len = self._size["in_chunk_len"] + skip_len
        if len(ts) < min_len:
            raise ValueError(
                f"The length of the input data is {len(ts)}, but it should be at least {min_len} for training."
            )
        ts_data = ts[-min_len:]
        return {**data, "ts": ts_data, "ori_ts": ts_data}
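
# Illustrative sketch (hypothetical sizes): keep only the last
# in_chunk_len + skip_chunk_len rows of the input window.
#
#   cutoff = CutOff(size={"in_chunk_len": 96, "skip_chunk_len": 0})
#   data = cutoff({"ts": df, "ori_ts": df})  # data["ts"] has 96 rows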
class Normalize(PyOnlyProcessor):
    def __init__(self, scale_path, params_info):
        super().__init__()
        self._scaler = joblib.load(scale_path)
        self._params_info = params_info

    def __call__(self, data):
        ts = data["ts"]
        if self._params_info.get("target_cols", None) is not None:
            ts[self._params_info["target_cols"]] = self._scaler.transform(
                ts[self._params_info["target_cols"]]
            )
        if self._params_info.get("feature_cols", None) is not None:
            ts[self._params_info["feature_cols"]] = self._scaler.transform(
                ts[self._params_info["feature_cols"]]
            )
        return {**data, "ts": ts}
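
# Illustrative sketch (hypothetical path and column names): `scale_path`
# points at a scaler previously persisted with joblib, e.g. a fitted
# sklearn StandardScaler.
#
#   normalize = Normalize("scaler.pkl", {"target_cols": ["y"]})
#   data = normalize({"ts": df})  # df["y"] replaced by its scaled values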
class Denormalize(PyOnlyProcessor):
    def __init__(self, scale_path, params_info):
        super().__init__()
        self._scaler = joblib.load(scale_path)
        self._params_info = params_info

    def __call__(self, data):
        pred = data["pred"]
        scale_cols = pred.columns.values.tolist()
        pred[scale_cols] = self._scaler.inverse_transform(pred[scale_cols])
        return {**data, "pred": pred}
class BuildTSDataset(PyOnlyProcessor):
    def __init__(self, params_info):
        super().__init__()
        self._params_info = params_info

    def __call__(self, data):
        ts = data["ts"]
        ts_data = _load_from_dataframe(ts, **self._params_info)
        return {**data, "ts": ts_data, "ori_ts": ts_data}
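
# Illustrative sketch (hypothetical params_info): wraps a raw DataFrame into
# the per-sample dict produced by `_load_from_dataframe`.
#
#   build = BuildTSDataset({"time_col": "time", "target_cols": "y", "freq": "D"})
#   data = build({"ts": df, "ori_ts": df})
#   # data["ts"]["past_target"] is the "y" frame indexed by time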
class CalcTimeFeatures(PyOnlyProcessor):
    def __init__(self, params_info, size, holiday=False):
        super().__init__()
        self._freq = params_info["freq"]
        self._size = size
        self._holiday = holiday

    def __call__(self, data):
        ts = data["ts"]
        if not self._holiday:
            feature_cols = ["hourofday", "dayofmonth", "dayofweek", "dayofyear"]
        else:
            feature_cols = [
                "minuteofhour",
                "hourofday",
                "dayofmonth",
                "dayofweek",
                "dayofyear",
                "monthofyear",
                "weekofyear",
                "holidays",
            ]
        ts = _to_time_features(
            ts, self._freq, feature_cols, self._size["out_chunk_len"]
        )
        return {**data, "ts": ts}
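
# Illustrative sketch (hypothetical sizes): with holiday=True, the holiday
# distance columns are appended in addition to the calendar features.
#
#   calc = CalcTimeFeatures({"freq": "D"}, {"out_chunk_len": 24}, holiday=False)
#   data = calc(data)  # data["ts"]["known_cov_numeric"] gains 4 columns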
class BuildPaddedMask(PyOnlyProcessor):
    def __init__(self, input_data):
        super().__init__()
        self._input_data = input_data

    def __call__(self, data):
        ts = data["ts"]
        if "features" in self._input_data:
            ts["features"] = ts["past_target"]
        if "pad_mask" in self._input_data:
            target_len = len(ts["features"])
            max_length = self._input_data["pad_mask"][-1]
            if max_length > 0:
                ones = np.ones(max_length, dtype=np.int32)
                if max_length != target_len:
                    target_ndarray = np.array(ts["features"]).astype(np.float32)
                    # Pad (or truncate) to `max_length` rows, keeping the
                    # float32 feature dtype and the original column count.
                    target_ndarray_final = np.zeros(
                        [max_length, target_ndarray.shape[1]], dtype=np.float32
                    )
                    end = min(target_len, max_length)
                    target_ndarray_final[:end, :] = target_ndarray[:end, :]
                    ts["features"] = target_ndarray_final
                    ones[end:] = 0  # mask out the padded rows
                    ts["pad_mask"] = ones
                else:
                    ts["pad_mask"] = ones
        return {**data, "ts": ts}
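
# Illustrative sketch (hypothetical shapes): a 90-row feature frame padded
# to 96 rows, with a 0/1 mask marking the valid prefix.
#
#   build_mask = BuildPaddedMask({"features": None, "pad_mask": [1, 96]})
#   data = build_mask(data)
#   # data["ts"]["features"].shape -> (96, n_cols); pad_mask[:90] == 1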
class DataFrame2Arrays(PyOnlyProcessor):
    def __init__(self, input_data):
        super().__init__()
        self._input_data = input_data

    def __call__(self, data):
        ts = data["ts"]
        ts_list = [
            np.array(ts[key]).astype("float32")
            for key in sorted(self._input_data)
        ]
        return {**data, "ts": ts_list}
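
# Illustrative sketch (hypothetical keys): the model consumes plain float32
# arrays, ordered by sorted input name.
#
#   to_arrays = DataFrame2Arrays({"past_target": None, "known_cov_numeric": None})
#   data = to_arrays(data)
#   # data["ts"] -> [known_cov_numeric array, past_target array]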