processors.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Union, Dict
  15. import numpy as np
  16. from ..base import PyOnlyProcessor
# Names exported by `from <this module> import *`: the processor classes
# defined at the bottom of the file. The underscore-prefixed helpers are
# intentionally left out.
__all__ = [
    "CutOff",
    "Normalize",
    "Denormalize",
    "BuildTSDataset",
    "CalcTimeFeatures",
    "BuildPaddedMask",
    "DataFrame2Arrays",
]
# Half-width, in days, of the window that `_distance_to_holiday` searches
# on each side of a timestamp when looking for occurrences of a holiday
# (200 days, i.e. a bit more than half a year).
_MAX_WINDOW = 183 + 17
  27. def _cal_year(
  28. x: np.datetime64,
  29. ):
  30. return x.year
  31. def _cal_month(
  32. x: np.datetime64,
  33. ):
  34. return x.month
  35. def _cal_day(
  36. x: np.datetime64,
  37. ):
  38. return x.day
  39. def _cal_hour(
  40. x: np.datetime64,
  41. ):
  42. return x.hour
  43. def _cal_weekday(
  44. x: np.datetime64,
  45. ):
  46. return x.dayofweek
  47. def _cal_quarter(
  48. x: np.datetime64,
  49. ):
  50. return x.quarter
  51. def _cal_hourofday(
  52. x: np.datetime64,
  53. ):
  54. return x.hour / 23.0 - 0.5
  55. def _cal_dayofweek(
  56. x: np.datetime64,
  57. ):
  58. return x.dayofweek / 6.0 - 0.5
  59. def _cal_dayofmonth(
  60. x: np.datetime64,
  61. ):
  62. return x.day / 30.0 - 0.5
  63. def _cal_dayofyear(
  64. x: np.datetime64,
  65. ):
  66. return x.dayofyear / 364.0 - 0.5
  67. def _cal_weekofyear(
  68. x: np.datetime64,
  69. ):
  70. return x.weekofyear / 51.0 - 0.5
  71. def _cal_holiday(
  72. x: np.datetime64,
  73. ):
  74. import chinese_calendar
  75. return float(chinese_calendar.is_holiday(x))
  76. def _cal_workday(
  77. x: np.datetime64,
  78. ):
  79. import chinese_calendar
  80. return float(chinese_calendar.is_workday(x))
  81. def _cal_minuteofhour(
  82. x: np.datetime64,
  83. ):
  84. return x.minute / 59 - 0.5
  85. def _cal_monthofyear(
  86. x: np.datetime64,
  87. ):
  88. return x.month / 11.0 - 0.5
# Dispatch table from a time-feature name to the function that computes it
# for a single timestamp. `_to_time_features` looks every requested feature
# up here, except "holidays", which it handles separately.
_CAL_DATE_METHOD = {
    # Raw calendar fields.
    "year": _cal_year,
    "month": _cal_month,
    "day": _cal_day,
    "hour": _cal_hour,
    "weekday": _cal_weekday,
    "quarter": _cal_quarter,
    # Calendar fields divided by a constant and shifted by -0.5.
    "minuteofhour": _cal_minuteofhour,
    "monthofyear": _cal_monthofyear,
    "hourofday": _cal_hourofday,
    "dayofweek": _cal_dayofweek,
    "dayofmonth": _cal_dayofmonth,
    "dayofyear": _cal_dayofyear,
    "weekofyear": _cal_weekofyear,
    # 0.0 / 1.0 flags backed by the `chinese_calendar` package.
    "is_holiday": _cal_holiday,
    "is_workday": _cal_workday,
}
  106. def _load_from_one_dataframe(
  107. data: Union["pd.DataFrame", "pd.Series"], # noqa: F821
  108. time_col: Optional[str] = None,
  109. value_cols: Optional[Union[List[str], str]] = None,
  110. freq: Optional[Union[str, int]] = None,
  111. drop_tail_nan: bool = False,
  112. dtype: Optional[Union[type, Dict[str, type]]] = None,
  113. ):
  114. import pandas as pd
  115. series_data = None
  116. if value_cols is None:
  117. if isinstance(data, pd.Series):
  118. series_data = data.copy()
  119. else:
  120. series_data = data.loc[:, data.columns != time_col].copy()
  121. else:
  122. series_data = data.loc[:, value_cols].copy()
  123. if time_col:
  124. if time_col not in data.columns:
  125. raise ValueError(
  126. "The time column: {} doesn't exist in the `data`!".format(time_col)
  127. )
  128. time_col_vals = data.loc[:, time_col]
  129. else:
  130. time_col_vals = data.index
  131. if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
  132. time_col_vals = time_col_vals.astype(str)
  133. if np.issubdtype(time_col_vals.dtype, np.integer):
  134. if freq:
  135. if not isinstance(freq, int) or freq < 1:
  136. raise ValueError(
  137. "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
  138. )
  139. else:
  140. freq = 1
  141. start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
  142. if (stop_idx - start_idx) / freq != len(data):
  143. raise ValueError("The number of rows doesn't match with the RangeIndex!")
  144. time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
  145. elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
  146. time_col_vals.dtype, np.datetime64
  147. ):
  148. time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
  149. time_index = pd.DatetimeIndex(time_col_vals)
  150. if freq:
  151. if not isinstance(freq, str):
  152. raise ValueError(
  153. "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
  154. )
  155. else:
  156. # If freq is not provided and automatic inference fail, throw exception
  157. freq = pd.infer_freq(time_index)
  158. if freq is None:
  159. raise ValueError(
  160. "Failed to infer the `freq`. A valid `freq` is required."
  161. )
  162. if freq[0] == "-":
  163. freq = freq[1:]
  164. else:
  165. raise ValueError("The type of `time_col` is invalid.")
  166. if isinstance(series_data, pd.Series):
  167. series_data = series_data.to_frame()
  168. series_data.set_index(time_index, inplace=True)
  169. series_data.sort_index(inplace=True)
  170. return series_data
  171. def _load_from_dataframe(
  172. df: "pd.DataFrame", # noqa: F821
  173. group_id: str = None,
  174. time_col: Optional[str] = None,
  175. target_cols: Optional[Union[List[str], str]] = None,
  176. label_col: Optional[Union[List[str], str]] = None,
  177. observed_cov_cols: Optional[Union[List[str], str]] = None,
  178. feature_cols: Optional[Union[List[str], str]] = None,
  179. known_cov_cols: Optional[Union[List[str], str]] = None,
  180. static_cov_cols: Optional[Union[List[str], str]] = None,
  181. freq: Optional[Union[str, int]] = None,
  182. fill_missing_dates: bool = False,
  183. fillna_method: str = "pre",
  184. fillna_window_size: int = 10,
  185. **kwargs,
  186. ):
  187. dfs = [] # separate multiple group
  188. if group_id is not None:
  189. group_unique = df[group_id].unique()
  190. for column in group_unique:
  191. dfs.append(df[df[group_id].isin([column])])
  192. else:
  193. dfs = [df]
  194. res = []
  195. if label_col:
  196. if isinstance(label_col, str) and len(label_col) > 1:
  197. raise ValueError("The length of label_col must be 1.")
  198. target_cols = label_col
  199. if feature_cols:
  200. observed_cov_cols = feature_cols
  201. for df in dfs:
  202. target = None
  203. observed_cov = None
  204. known_cov = None
  205. static_cov = dict()
  206. if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
  207. target = _load_from_one_dataframe(
  208. df,
  209. time_col,
  210. [a for a in df.columns if a != time_col],
  211. freq,
  212. )
  213. else:
  214. if target_cols:
  215. target = _load_from_one_dataframe(
  216. df,
  217. time_col,
  218. target_cols,
  219. freq,
  220. )
  221. if observed_cov_cols:
  222. observed_cov = _load_from_one_dataframe(
  223. df,
  224. time_col,
  225. observed_cov_cols,
  226. freq,
  227. )
  228. if known_cov_cols:
  229. known_cov = _load_from_one_dataframe(
  230. df,
  231. time_col,
  232. known_cov_cols,
  233. freq,
  234. )
  235. if static_cov_cols:
  236. if isinstance(static_cov_cols, str):
  237. static_cov_cols = [static_cov_cols]
  238. for col in static_cov_cols:
  239. if col not in df.columns or len(np.unique(df[col])) != 1:
  240. raise ValueError(
  241. "static cov cals data is not in columns or schema is not right!"
  242. )
  243. static_cov[col] = df[col].iloc[0]
  244. res.append(
  245. {
  246. "past_target": target,
  247. "observed_cov_numeric": observed_cov,
  248. "known_cov_numeric": known_cov,
  249. "static_cov_numeric": static_cov,
  250. }
  251. )
  252. return res[0]
  253. def _distance_to_holiday(holiday):
  254. def _distance_to_day(index):
  255. import pandas as pd
  256. holiday_date = holiday.dates(
  257. index - pd.Timedelta(days=_MAX_WINDOW),
  258. index + pd.Timedelta(days=_MAX_WINDOW),
  259. )
  260. assert (
  261. len(holiday_date) != 0
  262. ), f"No closest holiday for the date index {index} found."
  263. # It sometimes returns two dates if it is exactly half a year after the
  264. # holiday. In this case, the smaller distance (182 days) is returned.
  265. return float((index - holiday_date[0]).days)
  266. return _distance_to_day
  267. def _to_time_features(
  268. dataset, freq, feature_cols, extend_points, inplace: bool = False
  269. ):
  270. import pandas as pd
  271. new_ts = dataset
  272. if not inplace:
  273. new_ts = dataset.copy()
  274. # Get known_cov
  275. kcov = new_ts["known_cov_numeric"]
  276. if not kcov:
  277. tf_kcov = new_ts["past_target"].index.to_frame()
  278. else:
  279. tf_kcov = kcov.index.to_frame()
  280. time_col = tf_kcov.columns[0]
  281. if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
  282. raise ValueError(
  283. "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
  284. )
  285. if not kcov:
  286. freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
  287. extend_time = pd.date_range(
  288. start=tf_kcov[time_col][-1],
  289. freq=freq,
  290. periods=extend_points + 1,
  291. closed="right",
  292. name=time_col,
  293. ).to_frame()
  294. tf_kcov = pd.concat([tf_kcov, extend_time])
  295. for k in feature_cols:
  296. if k != "holidays":
  297. v = tf_kcov[time_col].apply(lambda x: _CAL_DATE_METHOD[k](x))
  298. v.index = tf_kcov[time_col]
  299. if new_ts["known_cov_numeric"] is None:
  300. new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
  301. else:
  302. new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
  303. new_ts["known_cov_numeric"].index
  304. )
  305. else:
  306. from pandas.tseries.offsets import DateOffset, Easter, Day
  307. from pandas.tseries import holiday as hd
  308. from sklearn.preprocessing import StandardScaler
  309. _EASTER_SUNDAY = hd.Holiday(
  310. "Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)]
  311. )
  312. _NEW_YEARS_DAY = hd.Holiday("New Years Day", month=1, day=1)
  313. _SUPER_BOWL = hd.Holiday(
  314. "Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1))
  315. )
  316. _MOTHERS_DAY = hd.Holiday(
  317. "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
  318. )
  319. _INDEPENDENCE_DAY = hd.Holiday("Independence Day", month=7, day=4)
  320. _CHRISTMAS_EVE = hd.Holiday("Christmas", month=12, day=24)
  321. _CHRISTMAS_DAY = hd.Holiday("Christmas", month=12, day=25)
  322. _NEW_YEARS_EVE = hd.Holiday("New Years Eve", month=12, day=31)
  323. _BLACK_FRIDAY = hd.Holiday(
  324. "Black Friday",
  325. month=11,
  326. day=1,
  327. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
  328. )
  329. _CYBER_MONDAY = hd.Holiday(
  330. "Cyber Monday",
  331. month=11,
  332. day=1,
  333. offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
  334. )
  335. _HOLYDAYS = [
  336. hd.EasterMonday,
  337. hd.GoodFriday,
  338. hd.USColumbusDay,
  339. hd.USLaborDay,
  340. hd.USMartinLutherKingJr,
  341. hd.USMemorialDay,
  342. hd.USPresidentsDay,
  343. hd.USThanksgivingDay,
  344. _EASTER_SUNDAY,
  345. _NEW_YEARS_DAY,
  346. _SUPER_BOWL,
  347. _MOTHERS_DAY,
  348. _INDEPENDENCE_DAY,
  349. _CHRISTMAS_EVE,
  350. _CHRISTMAS_DAY,
  351. _NEW_YEARS_EVE,
  352. _BLACK_FRIDAY,
  353. _CYBER_MONDAY,
  354. ]
  355. holidays_col = []
  356. for i, H in enumerate(_HOLYDAYS):
  357. v = tf_kcov[time_col].apply(_distance_to_holiday(H))
  358. v.index = tf_kcov[time_col]
  359. holidays_col.append(k + "_" + str(i))
  360. if new_ts["known_cov_numeric"] is None:
  361. new_ts["known_cov_numeric"] = pd.DataFrame(
  362. v.rename(k + "_" + str(i)), index=v.index
  363. )
  364. else:
  365. new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
  366. new_ts["known_cov_numeric"].index
  367. )
  368. scaler = StandardScaler()
  369. scaler.fit(new_ts["known_cov_numeric"][holidays_col])
  370. new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
  371. new_ts["known_cov_numeric"][holidays_col]
  372. )
  373. return new_ts
  374. class CutOff(PyOnlyProcessor):
  375. def __init__(self, size):
  376. super().__init__()
  377. self._size = size
  378. def __call__(self, data):
  379. ts = data["ts"]
  380. ori_ts = data["ori_ts"]
  381. skip_len = self._size.get("skip_chunk_len", 0)
  382. if len(ts) < self._size["in_chunk_len"] + skip_len:
  383. raise ValueError(
  384. f"The length of the input data is {len(ts)}, but it should be at least {self._size['in_chunk_len'] + self._size['skip_chunk_len']} for training."
  385. )
  386. ts_data = ts[-(self._size["in_chunk_len"] + skip_len) :]
  387. return {**data, "ts": ts_data, "ori_ts": ts_data}
  388. class Normalize(PyOnlyProcessor):
  389. def __init__(self, scale_path, params_info):
  390. import joblib
  391. super().__init__()
  392. self._scaler = joblib.load(scale_path)
  393. self._params_info = params_info
  394. def __call__(self, data):
  395. ts = data["ts"]
  396. if self._params_info.get("target_cols", None) is not None:
  397. ts[self._params_info["target_cols"]] = self._scaler.transform(
  398. ts[self._params_info["target_cols"]]
  399. )
  400. if self._params_info.get("feature_cols", None) is not None:
  401. ts[self._params_info["feature_cols"]] = self._scaler.transform(
  402. ts[self._params_info["feature_cols"]]
  403. )
  404. return {**data, "ts": ts}
  405. class Denormalize(PyOnlyProcessor):
  406. def __init__(self, scale_path, params_info):
  407. import joblib
  408. super().__init__()
  409. self._scaler = joblib.load(scale_path)
  410. self._params_info = params_info
  411. def __call__(self, data):
  412. pred = data["pred"]
  413. scale_cols = pred.columns.values.tolist()
  414. pred[scale_cols] = self._scaler.inverse_transform(pred[scale_cols])
  415. return {**data, "pred": pred}
  416. class BuildTSDataset(PyOnlyProcessor):
  417. def __init__(self, params_info):
  418. super().__init__()
  419. self._params_info = params_info
  420. def __call__(self, data):
  421. ts = data["ts"]
  422. ori_ts = data["ori_ts"]
  423. ts_data = _load_from_dataframe(ts, **self._params_info)
  424. return {**data, "ts": ts_data, "ori_ts": ts_data}
  425. class CalcTimeFeatures(PyOnlyProcessor):
  426. def __init__(self, params_info, size, holiday=False):
  427. super().__init__()
  428. self._freq = params_info["freq"]
  429. self._size = size
  430. self._holiday = holiday
  431. def __call__(self, data):
  432. ts = data["ts"]
  433. if not self._holiday:
  434. ts = _to_time_features(
  435. ts,
  436. self._freq,
  437. ["hourofday", "dayofmonth", "dayofweek", "dayofyear"],
  438. self._size["out_chunk_len"],
  439. )
  440. else:
  441. ts = _to_time_features(
  442. ts,
  443. self._freq,
  444. [
  445. "minuteofhour",
  446. "hourofday",
  447. "dayofmonth",
  448. "dayofweek",
  449. "dayofyear",
  450. "monthofyear",
  451. "weekofyear",
  452. "holidays",
  453. ],
  454. self._size["out_chunk_len"],
  455. )
  456. return {**data, "ts": ts}
  457. class BuildPaddedMask(PyOnlyProcessor):
  458. def __init__(self, input_data):
  459. super().__init__()
  460. self._input_data = input_data
  461. def __call__(self, data):
  462. ts = data["ts"]
  463. if "features" in self._input_data:
  464. ts["features"] = ts["past_target"]
  465. if "pad_mask" in self._input_data:
  466. target_dim = len(ts["features"])
  467. max_length = self._input_data["pad_mask"][-1]
  468. if max_length > 0:
  469. ones = np.ones(max_length, dtype=np.int32)
  470. if max_length != target_dim:
  471. target_ndarray = np.array(ts["features"]).astype(np.float32)
  472. target_ndarray_final = np.zeros(
  473. [max_length, target_dim], dtype=np.int32
  474. )
  475. end = min(target_dim, max_length)
  476. target_ndarray_final[:end, :] = target_ndarray
  477. ts["features"] = target_ndarray_final
  478. ones[end:] = 0.0
  479. ts["pad_mask"] = ones
  480. else:
  481. ts["pad_mask"] = ones
  482. return {**data, "ts": ts}
  483. class DataFrame2Arrays(PyOnlyProcessor):
  484. def __init__(self, input_data):
  485. super().__init__()
  486. self._input_data = input_data
  487. def __call__(self, data):
  488. ts = data["ts"]
  489. ts_list = []
  490. input_name = list(self._input_data.keys())
  491. input_name.sort()
  492. for key in input_name:
  493. ts_list.append(np.array(ts[key]).astype("float32"))
  494. return {**data, "ts": ts_list}