update ts model predict (#2643)

* update ts model predict

* update ts model predict

* add docstring

* update_docstring

* update_ts sampler
Sunflower7788 11 months ago
parent commit 5dd30cf3bd

+ 1 - 0
paddlex/inference/common/batch_sampler/__init__.py

@@ -14,3 +14,4 @@
 
 from .base_batch_sampler import BaseBatchSampler
 from .image_batch_sampler import ImageBatchSampler
+from .ts_batch_sampler import TSBatchSampler

+ 110 - 0
paddlex/inference/common/batch_sampler/ts_batch_sampler.py

@@ -0,0 +1,110 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import ast
+from pathlib import Path
+import numpy as np
+import pandas as pd
+
+from ....utils import logging
+from ....utils.download import download
+from ....utils.cache import CACHE_DIR
+from .base_batch_sampler import BaseBatchSampler
+
+
+class TSBatchSampler(BaseBatchSampler):
+    """Batch sampler for time series data, supporting CSV file inputs."""
+
+    SUFFIX = ["csv"]
+
+    def _download_from_url(self, in_path: str) -> str:
+        """Download a file from a URL to a cache directory.
+
+        Args:
+            in_path (str): URL of the file to be downloaded.
+
+        Returns:
+            str: Path to the downloaded file.
+        """
+        file_name = Path(in_path).name
+        save_path = Path(CACHE_DIR) / "predict_input" / file_name
+        download(in_path, save_path, overwrite=True)
+        return save_path.as_posix()
+
+    def _get_files_list(self, fp: str) -> list:
+        """Get a list of CSV files from a directory or a single file path.
+
+        Args:
+            fp (str): Path to a directory or a single CSV file.
+
+        Returns:
+            list: Sorted list of CSV file paths.
+
+        Raises:
+            Exception: If no CSV file is found in the path.
+        """
+        file_list = []
+        if fp is None or not os.path.exists(fp):
+            raise Exception(f"Not found any csv file in path: {fp}")
+
+        if os.path.isfile(fp) and fp.split(".")[-1] in self.SUFFIX:
+            file_list.append(fp)
+        elif os.path.isdir(fp):
+            for root, dirs, files in os.walk(fp):
+                for single_file in files:
+                    if single_file.split(".")[-1] in self.SUFFIX:
+                        file_list.append(os.path.join(root, single_file))
+        if len(file_list) == 0:
+            raise Exception("Not found any file in {}".format(fp))
+        file_list = sorted(file_list)
+        return file_list
+
+    def sample(self, inputs: list) -> list:
+        """Generate batches of data from inputs, which can be DataFrames or file paths.
+
+        Args:
+            inputs (list): List of DataFrames or file paths.
+
+        Yields:
+            list: A batch of data which is either DataFrames or file paths.
+        """
+        if not isinstance(inputs, list):
+            inputs = [inputs]
+
+        batch = []
+        for input_data in inputs:
+            if isinstance(input_data, pd.DataFrame):
+                batch.append(input_data)
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+            elif isinstance(input_data, str):
+                file_path = (
+                    self._download_from_url(input_data)
+                    if input_data.startswith("http")
+                    else input_data
+                )
+                file_list = self._get_files_list(file_path)
+                for file_path in file_list:
+                    batch.append(file_path)
+                    if len(batch) == self.batch_size:
+                        yield batch
+                        batch = []
+            else:
+                logging.warning(
+                    f"Unsupported input data type! Only `pd.DataFrame` and `str` are supported, so this input has been ignored: {input_data}."
+                )
+        if len(batch) > 0:
+            yield batch
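
A minimal usage sketch of the sampler (file paths are hypothetical; `batch_size` is assumed to be settable through the `BaseBatchSampler` interface, which is not shown in this diff):

import pandas as pd
from paddlex.inference.common.batch_sampler import TSBatchSampler

sampler = TSBatchSampler()
sampler.batch_size = 2  # assumed attribute inherited from BaseBatchSampler
df = pd.DataFrame({"date": pd.date_range("2024-01-01", periods=4), "value": [1.0, 2.0, 3.0, 4.0]})
# Mixed inputs: an in-memory DataFrame plus CSV paths (hypothetical files)
for batch in sampler.sample([df, "data/ts_a.csv", "data/ts_b.csv"]):
    print(len(batch))  # yields batches of at most 2 items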

+ 1 - 0
paddlex/inference/common/reader/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .image_reader import ReadImage
+from .ts_reader import ReadTS

+ 45 - 0
paddlex/inference/common/reader/ts_reader.py

@@ -0,0 +1,45 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas as pd
+
+from ...utils.io import CSVReader
+
+
+class ReadTS:
+
+    def __init__(self):
+        super().__init__()
+        self._reader = CSVReader(backend="pandas")
+
+    def __call__(self, ts_list):
+        """apply"""
+        return [self.read(ts) for ts in ts_list]
+
+    def read(self, ts):
+        if isinstance(ts, pd.DataFrame):
+            return ts
+        elif isinstance(ts, str):
+            ts_data = self._reader.read(ts)
+            if ts_data is None:
+                raise Exception(f"TS read Error: {ts}")
+            return ts_data
+        else:
+            raise TypeError(
+                f"ReadTS only supports the following types:\n"
+                f"1. str, indicating a CSV file path or a directory containing CSV files.\n"
+                f"2. pandas.DataFrame.\n"
+                f"However, got type: {type(ts).__name__}."
+            )
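
A quick sketch of how `ReadTS` normalizes its inputs (the CSV path is hypothetical):

import pandas as pd
from paddlex.inference.common.reader import ReadTS

reader = ReadTS()
df = pd.DataFrame({"date": ["2024-01-01", "2024-01-02"], "value": [1.0, 2.0]})
out = reader([df, "data/ts_a.csv"])  # DataFrames pass through; strings are read via CSVReader
assert out[0] is df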

+ 1 - 0
paddlex/inference/common/result/__init__.py

@@ -14,6 +14,7 @@
 
 from .base_result import BaseResult
 from .base_cv_result import BaseCVResult
+from .base_ts_result import BaseTSResult
 from .mixin import (
     StrMixin,
     JsonMixin,

+ 44 - 0
paddlex/inference/common/result/base_ts_result.py

@@ -0,0 +1,44 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_result import BaseResult
+from .mixin import StrMixin, JsonMixin, CSVMixin
+from ...utils.io import CSVWriter
+
+
+class BaseTSResult(BaseResult, StrMixin, JsonMixin, CSVMixin):
+    """Base class for times series results."""
+
+    INPUT_TS_KEY = "input_ts"
+
+    def __init__(self, data: dict) -> None:
+        """
+        Initialize the BaseTSResult.
+
+        Args:
+            data (dict): The initial data.
+
+        Raises:
+            AssertionError: If the required key (`BaseTSResult.INPUT_TS_KEY`) is not found in the data.
+        """
+        assert (
+            BaseTSResult.INPUT_TS_KEY in data
+        ), f"`{BaseTSResult.INPUT_TS_KEY}` is needed, but not found in `{list(data.keys())}`!"
+        self._input_ts = data.pop(BaseTSResult.INPUT_TS_KEY, None)
+        self._ts_writer = CSVWriter(backend="pandas")
+
+        super().__init__(data)
+        StrMixin.__init__(self)
+        JsonMixin.__init__(self)
+        CSVMixin.__init__(self, "pandas")

+ 4 - 3
paddlex/inference/models_new/__init__.py

@@ -31,9 +31,10 @@ from .text_recognition import TextRecPredictor
 from .semantic_segmentation import SegPredictor
 
 # from .general_recognition import ShiTuRecPredictor
-# from .ts_fc import TSFcPredictor
-# from .ts_ad import TSAdPredictor
-# from .ts_cls import TSClsPredictor
+
+from .ts_forecast import TSFcPredictor
+from .ts_anomaly import TSAdPredictor
+from .ts_classify import TSClsPredictor
 from .image_unwarping import WarpPredictor
 
 # from .multilabel_classification import MLClasPredictor

+ 8 - 0
paddlex/inference/models_new/common/__init__.py

@@ -22,4 +22,12 @@ from .vision import (
     ToBatch,
 )
 
+from .ts import (
+    BuildTSDataset,
+    TSCutOff,
+    TSNormalize,
+    TimeFeature,
+    TStoArray,
+    TStoBatch,
+)
 from .static_infer import StaticInfer

+ 15 - 0
paddlex/inference/models_new/common/ts/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .processors import *

+ 533 - 0
paddlex/inference/models_new/common/ts/funcs.py

@@ -0,0 +1,533 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
+import os
+import numpy as np
+import pandas as pd
+import chinese_calendar
+from pandas.tseries.offsets import DateOffset, Easter, Day
+from pandas.tseries import holiday as hd
+from sklearn.preprocessing import StandardScaler
+
+
+MAX_WINDOW = 183 + 17
+EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
+NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
+SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
+MothersDay = hd.Holiday(
+    "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
+)
+IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
+ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
+ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
+NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
+BlackFriday = hd.Holiday(
+    "Black Friday",
+    month=11,
+    day=1,
+    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
+)
+CyberMonday = hd.Holiday(
+    "Cyber Monday",
+    month=11,
+    day=1,
+    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
+)
+
+HOLIDAYS = [
+    hd.EasterMonday,
+    hd.GoodFriday,
+    hd.USColumbusDay,
+    hd.USLaborDay,
+    hd.USMartinLutherKingJr,
+    hd.USMemorialDay,
+    hd.USPresidentsDay,
+    hd.USThanksgivingDay,
+    EasterSunday,
+    NewYearsDay,
+    SuperBowl,
+    MothersDay,
+    IndependenceDay,
+    ChristmasEve,
+    ChristmasDay,
+    NewYearsEve,
+    BlackFriday,
+    CyberMonday,
+]
+
+
+def _cal_year(
+    x: np.datetime64,
+):
+    return x.year
+
+
+def _cal_month(
+    x: np.datetime64,
+):
+    return x.month
+
+
+def _cal_day(
+    x: np.datetime64,
+):
+    return x.day
+
+
+def _cal_hour(
+    x: np.datetime64,
+):
+    return x.hour
+
+
+def _cal_weekday(
+    x: np.datetime64,
+):
+    return x.dayofweek
+
+
+def _cal_quarter(
+    x: np.datetime64,
+):
+    return x.quarter
+
+
+def _cal_hourofday(
+    x: np.datetime64,
+):
+    return x.hour / 23.0 - 0.5
+
+
+def _cal_dayofweek(
+    x: np.datetime64,
+):
+    return x.dayofweek / 6.0 - 0.5
+
+
+def _cal_dayofmonth(
+    x: np.datetime64,
+):
+    return x.day / 30.0 - 0.5
+
+
+def _cal_dayofyear(
+    x: np.datetime64,
+):
+    return x.dayofyear / 364.0 - 0.5
+
+
+def _cal_weekofyear(
+    x: np.datetime64,
+):
+    return x.weekofyear / 51.0 - 0.5
+
+
+def _cal_holiday(
+    x: np.datetime64,
+):
+    return float(chinese_calendar.is_holiday(x))
+
+
+def _cal_workday(
+    x: np.datetime64,
+):
+    return float(chinese_calendar.is_workday(x))
+
+
+def _cal_minuteofhour(
+    x: np.datetime64,
+):
+    return x.minute / 59 - 0.5
+
+
+def _cal_monthofyear(
+    x: np.datetime64,
+):
+    return x.month / 11.0 - 0.5
+
+
+CAL_DATE_METHOD = {
+    "year": _cal_year,
+    "month": _cal_month,
+    "day": _cal_day,
+    "hour": _cal_hour,
+    "weekday": _cal_weekday,
+    "quarter": _cal_quarter,
+    "minuteofhour": _cal_minuteofhour,
+    "monthofyear": _cal_monthofyear,
+    "hourofday": _cal_hourofday,
+    "dayofweek": _cal_dayofweek,
+    "dayofmonth": _cal_dayofmonth,
+    "dayofyear": _cal_dayofyear,
+    "weekofyear": _cal_weekofyear,
+    "is_holiday": _cal_holiday,
+    "is_workday": _cal_workday,
+}
+
+
+def load_from_one_dataframe(
+    data: Union[pd.DataFrame, pd.Series],
+    time_col: Optional[str] = None,
+    value_cols: Optional[Union[List[str], str]] = None,
+    freq: Optional[Union[str, int]] = None,
+    drop_tail_nan: bool = False,
+    dtype: Optional[Union[type, Dict[str, type]]] = None,
+) -> pd.DataFrame:
+    """Transforms a DataFrame or Series into a time-indexed DataFrame.
+
+    Args:
+        data (Union[pd.DataFrame, pd.Series]): The input data containing time series information.
+        time_col (Optional[str]): The column name representing time information. If None, uses the index.
+        value_cols (Optional[Union[List[str], str]]): Columns to extract as values. If None, uses all except time_col.
+        freq (Optional[Union[str, int]]): The frequency of the time series data.
+        drop_tail_nan (bool): If True, drop trailing NaN values from the data.
+        dtype (Optional[Union[type, Dict[str, type]]]): Enforce a specific data type on the resulting DataFrame.
+
+    Returns:
+        pd.DataFrame: A DataFrame with time as the index and specified value columns.
+
+    Raises:
+        ValueError: If the time column doesn't exist, or if frequency cannot be inferred.
+
+    """
+    # Initialize series_data with specified value columns or all except time_col
+    series_data = None
+    if value_cols is None:
+        if isinstance(data, pd.Series):
+            series_data = data.copy()
+        else:
+            series_data = data.loc[:, data.columns != time_col].copy()
+    else:
+        series_data = data.loc[:, value_cols].copy()
+
+    # Determine the time column values
+    if time_col:
+        if time_col not in data.columns:
+            raise ValueError(
+                "The time column: {} doesn't exist in the `data`!".format(time_col)
+            )
+        time_col_vals = data.loc[:, time_col]
+    else:
+        time_col_vals = data.index
+
+    # Handle integer-based time column values when frequency is a string
+    if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
+        time_col_vals = time_col_vals.astype(str)
+
+    # Process integer-based time column values
+    if np.issubdtype(time_col_vals.dtype, np.integer):
+        if freq:
+            if not isinstance(freq, int) or freq < 1:
+                raise ValueError(
+                    "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
+                )
+        else:
+            freq = 1  # Default frequency for integer index
+        start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
+        if (stop_idx - start_idx) / freq != len(data):
+            raise ValueError("The number of rows doesn't match with the RangeIndex!")
+        time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
+
+    # Process datetime-like time column values
+    elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
+        time_col_vals.dtype, np.datetime64
+    ):
+        time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
+        time_index = pd.DatetimeIndex(time_col_vals)
+        if freq:
+            if not isinstance(freq, str):
+                raise ValueError(
+                    "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
+                )
+        else:
+            # Attempt to infer frequency if not provided
+            freq = pd.infer_freq(time_index)
+            if freq is None:
+                raise ValueError(
+                    "Failed to infer the `freq`. A valid `freq` is required."
+                )
+            if freq[0] == "-":
+                freq = freq[1:]
+
+    # Raise error for unsupported time column types
+    else:
+        raise ValueError("The type of `time_col` is invalid.")
+
+    # Ensure series_data is a DataFrame
+    if isinstance(series_data, pd.Series):
+        series_data = series_data.to_frame()
+
+    # Set time index and sort data
+    series_data.set_index(time_index, inplace=True)
+    series_data.sort_index(inplace=True)
+    return series_data
+
+
+def load_from_dataframe(
+    df: pd.DataFrame,
+    group_id: Optional[str] = None,
+    time_col: Optional[str] = None,
+    target_cols: Optional[Union[List[str], str]] = None,
+    label_col: Optional[Union[List[str], str]] = None,
+    observed_cov_cols: Optional[Union[List[str], str]] = None,
+    feature_cols: Optional[Union[List[str], str]] = None,
+    known_cov_cols: Optional[Union[List[str], str]] = None,
+    static_cov_cols: Optional[Union[List[str], str]] = None,
+    freq: Optional[Union[str, int]] = None,
+    fill_missing_dates: bool = False,
+    fillna_method: str = "pre",
+    fillna_window_size: int = 10,
+    **kwargs,
+) -> Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]:
+    """Loads and processes time series data from a DataFrame.
+
+    This function extracts and organizes time series data from a given DataFrame.
+    It supports optional grouping and extraction of specific columns as features.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame containing time series data.
+        group_id (Optional[str]): Column name used for grouping the data.
+        time_col (Optional[str]): Name of the time column.
+        target_cols (Optional[Union[List[str], str]]): Columns to be used as target.
+        label_col (Optional[Union[List[str], str]]): Columns to be used as label.
+        observed_cov_cols (Optional[Union[List[str], str]]): Columns for observed covariates.
+        feature_cols (Optional[Union[List[str], str]]): Columns to be used as features.
+        known_cov_cols (Optional[Union[List[str], str]]): Columns for known covariates.
+        static_cov_cols (Optional[Union[List[str], str]]): Columns for static covariates.
+        freq (Optional[Union[str, int]]): Frequency of the time series data.
+        fill_missing_dates (bool): Whether to fill missing dates in the time series.
+        fillna_method (str): Method to fill missing values ('pre' or 'post').
+        fillna_window_size (int): Window size for filling missing values.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        Dict[str, Optional[Union[pd.DataFrame, Dict[str, any]]]]: A dictionary containing processed time series data.
+    """
+    # List to store DataFrames if grouping is applied
+    dfs = []
+
+    # Separate the DataFrame into groups if group_id is provided
+    if group_id is not None:
+        group_unique = df[group_id].unique()
+        for column in group_unique:
+            dfs.append(df[df[group_id].isin([column])])
+    else:
+        dfs = [df]
+
+    # Result list to store processed data from each group
+    res = []
+
+    # If label_col is provided, ensure it is a single column
+    if label_col:
+        if not isinstance(label_col, str) and len(label_col) > 1:
+            raise ValueError("The length of label_col must be 1.")
+        target_cols = label_col
+
+    # If feature_cols is provided, treat it as observed_cov_cols
+    if feature_cols:
+        observed_cov_cols = feature_cols
+
+    # Process each DataFrame in the list
+    for df in dfs:
+        target = None
+        observed_cov = None
+        known_cov = None
+        static_cov = dict()
+
+        # If no specific columns are provided, use all columns except time_col
+        if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
+            target = load_from_one_dataframe(
+                df,
+                time_col,
+                [a for a in df.columns if a != time_col],
+                freq,
+            )
+        else:
+            if target_cols:
+                target = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    target_cols,
+                    freq,
+                )
+
+            if observed_cov_cols:
+                observed_cov = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    observed_cov_cols,
+                    freq,
+                )
+
+            if known_cov_cols:
+                known_cov = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    known_cov_cols,
+                    freq,
+                )
+
+            if static_cov_cols:
+                if isinstance(static_cov_cols, str):
+                    static_cov_cols = [static_cov_cols]
+                for col in static_cov_cols:
+                    if col not in df.columns or len(np.unique(df[col])) != 1:
+                        raise ValueError(
+                            "Static covariate columns data is not in columns or schema is not correct!"
+                        )
+                    static_cov[col] = df[col].iloc[0]
+        # Append the processed data into the results list
+        res.append(
+            {
+                "past_target": target,
+                "observed_cov_numeric": observed_cov,
+                "known_cov_numeric": known_cov,
+                "static_cov_numeric": static_cov,
+            }
+        )
+    # Return the first processed result
+    return res[0]
+
+
+def _distance_to_holiday(holiday) -> Callable[[pd.Timestamp], float]:
+    """Creates a function to calculate the distance in days to the nearest holiday.
+
+    This function generates a closure that computes the number of days from
+    a given date index to the nearest holiday within a defined window.
+
+    Args:
+        holiday: An object that provides a `dates` method, which returns the
+            dates of holidays within a specified range.
+
+    Returns:
+        Callable[[pd.Timestamp], float]: A function that takes a date index
+        as input and returns the distance in days to the nearest holiday.
+    """
+
+    def _distance_to_day(index: pd.Timestamp) -> float:
+        """Calculates the distance in days from a given date index to the nearest holiday.
+
+        Args:
+            index (pd.Timestamp): The date index for which the distance to the
+                nearest holiday should be calculated.
+
+        Returns:
+            float: The number of days to the nearest holiday.
+
+        Raises:
+            AssertionError: If no holiday is found within the specified window.
+        """
+        holiday_date = holiday.dates(
+            index - pd.Timedelta(days=MAX_WINDOW),
+            index + pd.Timedelta(days=MAX_WINDOW),
+        )
+        assert (
+            len(holiday_date) != 0
+        ), f"No closest holiday for the date index {index} found."
+        # It sometimes returns two dates if it is exactly half a year after the
+        # holiday. In this case, the smaller distance (182 days) is returned.
+        return float((index - holiday_date[0]).days)
+
+    return _distance_to_day
+
+
+def time_feature(
+    dataset: Dict,
+    freq: Optional[Union[str, int]],
+    feature_cols: List[str],
+    extend_points: int,
+    inplace: bool = False,
+) -> Dict:
+    """Transforms the time column of a dataset into time features.
+
+    This function extracts time-related features from the time column in a
+    dataset, optionally extending the time series for future points and
+    normalizing holiday distances.
+
+    Args:
+        dataset (Dict): Dataset to be transformed.
+        freq (Optional[Union[str, int]]): Frequency of the time series data. If not provided,
+            the frequency will be inferred.
+        feature_cols (List[str]): List of feature columns to be extracted.
+        extend_points (int): Number of future points to extend the time series.
+        inplace (bool): Whether to perform the transformation inplace. Default is False.
+
+    Returns:
+        Dict: The transformed dataset with time features added.
+
+    Raises:
+        ValueError: If the time column is of an integer type instead of datetime.
+    """
+    new_ts = dataset
+    if not inplace:
+        new_ts = dataset.copy()
+    # Get known_cov_numeric or initialize with past target index
+    kcov = new_ts["known_cov_numeric"]
+    if not kcov:
+        tf_kcov = new_ts["past_target"].index.to_frame()
+    else:
+        tf_kcov = kcov.index.to_frame()
+    time_col = tf_kcov.columns[0]
+    # Check if time column is of datetime type
+    if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
+        raise ValueError(
+            "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
+        )
+    # Extend the time series if no known_cov_numeric
+    if not kcov:
+        freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
+        extend_time = pd.date_range(
+            start=tf_kcov[time_col][-1],
+            freq=freq,
+            periods=extend_points + 1,
+            closed="right",
+            name=time_col,
+        ).to_frame()
+        tf_kcov = pd.concat([tf_kcov, extend_time])
+
+    # Extract and add time features to known_cov_numeric
+    for k in feature_cols:
+        if k != "holidays":
+            v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
+            v.index = tf_kcov[time_col]
+
+            if new_ts["known_cov_numeric"] is None:
+                new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
+            else:
+                new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
+                    new_ts["known_cov_numeric"].index
+                )
+        else:
+            holidays_col = []
+            for i, H in enumerate(HOLIDAYS):
+                v = tf_kcov[time_col].apply(_distance_to_holiday(H))
+                v.index = tf_kcov[time_col]
+                holidays_col.append(k + "_" + str(i))
+                if new_ts["known_cov_numeric"] is None:
+                    new_ts["known_cov_numeric"] = pd.DataFrame(
+                        v.rename(k + "_" + str(i)), index=v.index
+                    )
+                else:
+                    new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
+                        new_ts["known_cov_numeric"].index
+                    )
+
+            scaler = StandardScaler()
+            scaler.fit(new_ts["known_cov_numeric"][holidays_col])
+            new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
+                new_ts["known_cov_numeric"][holidays_col]
+            )
+    return new_ts
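
A small sketch of the load-then-featurize flow these helpers implement (synthetic data; since `time_feature` calls `pd.date_range(..., closed="right")`, this assumes a pandas version earlier than 2.0):

import pandas as pd
from paddlex.inference.models_new.common.ts.funcs import load_from_dataframe, time_feature

df = pd.DataFrame({
    "date": pd.date_range("2024-01-01", periods=48, freq="H"),
    "value": range(48),
})
ds = load_from_dataframe(df, time_col="date", target_cols="value", freq="H")
# Extend 24 steps past the observed window and attach calendar features
ds = time_feature(ds, freq="H", feature_cols=["hourofday", "dayofweek"], extend_points=24)
print(ds["known_cov_numeric"].shape)  # (72, 2): 48 observed rows + 24 extended rows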

+ 313 - 0
paddlex/inference/models_new/common/ts/processors.py

@@ -0,0 +1,313 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Dict, Any
+from pathlib import Path
+from copy import deepcopy
+import joblib
+import numpy as np
+import pandas as pd
+
+from .funcs import load_from_dataframe, time_feature
+
+
+__all__ = [
+    "BuildTSDataset",
+    "TSCutOff",
+    "TSNormalize",
+    "TimeFeature",
+    "TStoArray",
+    "TStoBatch",
+]
+
+
+class TSCutOff:
+    """Truncates time series data to a specified length for training.
+
+    This class provides a method to truncate or cut off time series data
+    to a specified input length, optionally skipping some initial data
+    points. This is useful for preparing data for training models that
+    require a fixed input size.
+    """
+
+    def __init__(self, size: Dict[str, int]):
+        """Initializes the TSCutOff with size configurations.
+
+        Args:
+            size (Dict[str, int]): Dictionary containing size configurations,
+                including 'in_chunk_len' for the input chunk length and
+                optionally 'skip_chunk_len' for the number of initial data
+                points to skip.
+        """
+        super().__init__()
+        self.size = size
+
+    def __call__(self, ts_list: List) -> List:
+        """Applies the cut off operation to a list of time series.
+
+        Args:
+            ts_list (List): List of time series data frames to be truncated.
+
+        Returns:
+            List: List of truncated time series data frames.
+        """
+        return [self.cutoff(ts) for ts in ts_list]
+
+    def cutoff(self, ts: Any) -> Any:
+        """Truncates a single time series data frame to the specified length.
+
+        This method truncates the time series data to the specified input
+        chunk length, optionally skipping some initial data points. It raises
+        a ValueError if the time series is too short.
+
+        Args:
+            ts: A single time series data frame to be truncated.
+
+        Returns:
+            Any: The truncated time series data frame.
+
+        Raises:
+            ValueError: If the time series length is less than the required
+            minimum length (input chunk length plus any skip chunk length).
+        """
+        skip_len = self.size.get("skip_chunk_len", 0)
+        if len(ts) < self.size["in_chunk_len"] + skip_len:
+            raise ValueError(
+                f"The length of the input data is {len(ts)}, but it should be at least {self.size['in_chunk_len'] + self.size['skip_chunk_len']} for training."
+            )
+        ts_data = ts[-(self.size["in_chunk_len"] + skip_len) :]
+        return ts_data
+
+
+class TSNormalize:
+    """Normalizes time series data using a pre-fitted scaler.
+
+    This class normalizes specified columns of time series data using a
+    pre-fitted scaler loaded from a specified path. It supports normalization
+    of both target and feature columns as specified in the parameters.
+    """
+
+    def __init__(self, scale_path: str, params_info: Dict[str, Any]):
+        """Initializes the TSNormalize with a scaler and normalization parameters.
+
+        Args:
+            scale_path (str): Path to the pre-fitted scaler object file.
+            params_info (Dict[str, Any]): Dictionary containing information
+                about which columns to normalize, including 'target_cols'
+                and 'feature_cols'.
+        """
+        super().__init__()
+        self.scaler = joblib.load(scale_path)
+        self.params_info = params_info
+
+    def __call__(self, ts_list: List[pd.DataFrame]) -> List[pd.DataFrame]:
+        """Applies normalization to a list of time series data frames.
+
+        Args:
+            ts_list (List[pd.DataFrame]): List of time series data frames to be normalized.
+
+        Returns:
+            List[pd.DataFrame]: List of normalized time series data frames.
+        """
+        return [self.tsnorm(ts) for ts in ts_list]
+
+    def tsnorm(self, ts: pd.DataFrame) -> pd.DataFrame:
+        """Normalizes specified columns of a single time series data frame.
+
+        This method applies the scaler to normalize the specified target
+        and feature columns of the time series.
+
+        Args:
+            ts (pd.DataFrame): A single time series data frame to be normalized.
+
+        Returns:
+            pd.DataFrame: The normalized time series data frame.
+        """
+        if self.params_info.get("target_cols", None) is not None:
+            ts[self.params_info["target_cols"]] = self.scaler.transform(
+                ts[self.params_info["target_cols"]]
+            )
+        if self.params_info.get("feature_cols", None) is not None:
+            ts[self.params_info["feature_cols"]] = self.scaler.transform(
+                ts[self.params_info["feature_cols"]]
+            )
+        return ts
+
+
+class BuildTSDataset:
+    """Constructs a time series dataset from a list of time series data frames."""
+
+    def __init__(self, params_info: Dict[str, Any]):
+        """Initializes the BuildTSDataset with parameters for dataset construction.
+
+        Args:
+            params_info (Dict[str, Any]): Dictionary containing parameters for
+                constructing the time series dataset.
+        """
+        super().__init__()
+        self.params_info = params_info
+
+    def __call__(self, ts_list: List) -> List:
+        """Applies the dataset construction to a list of time series.
+
+        Args:
+            ts_list (List): List of time series data frames.
+
+        Returns:
+            List: List of constructed time series datasets.
+        """
+        return [self.buildtsdata(ts) for ts in ts_list]
+
+    def buildtsdata(self, ts) -> Any:
+        """Builds a time series dataset from a single time series data frame.
+
+        Args:
+            ts: A single time series data frame.
+
+        Returns:
+            Any: A constructed time series dataset.
+        """
+        ts_data = load_from_dataframe(ts, **self.params_info)
+        return ts_data
+
+
+class TimeFeature:
+    """Extracts time features from time series data for forecasting."""
+
+    def __init__(
+        self, params_info: Dict[str, Any], size: Dict[str, int], holiday: bool = False
+    ):
+        """Initializes the TimeFeature extractor.
+
+        Args:
+            params_info (Dict[str, Any]): Dictionary containing frequency information.
+            size (Dict[str, int]): Dictionary containing the output chunk length.
+            holiday (bool, optional): Whether to include holiday features. Defaults to False.
+        """
+        super().__init__()
+        self.freq = params_info["freq"]
+        self.size = size
+        self.holiday = holiday
+
+    def __call__(self, ts_list: List) -> List:
+        """Applies time feature extraction to a list of time series.
+
+        Args:
+            ts_list (List): List of time series data frames.
+
+        Returns:
+            List: List of time series with extracted time features.
+        """
+        return [self.timefeat(ts) for ts in ts_list]
+
+    def timefeat(self, ts: Dict[str, Any]) -> Any:
+        """Extracts time features from a single time series data frame.
+
+        Args:
+            ts: A single time series data frame.
+
+        Returns:
+            Any: The time series with added time features.
+        """
+        if not self.holiday:
+            ts = time_feature(
+                ts,
+                self.freq,
+                ["hourofday", "dayofmonth", "dayofweek", "dayofyear"],
+                self.size["out_chunk_len"],
+            )
+        else:
+            ts = time_feature(
+                ts,
+                self.freq,
+                [
+                    "minuteofhour",
+                    "hourofday",
+                    "dayofmonth",
+                    "dayofweek",
+                    "dayofyear",
+                    "monthofyear",
+                    "weekofyear",
+                    "holidays",
+                ],
+                self.size["out_chunk_len"],
+            )
+        return ts
+
+
+class TStoArray:
+    """Converts time series data into arrays for model input."""
+
+    def __init__(self, input_data: Dict[str, Any]):
+        """Initializes the TStoArray converter.
+
+        Args:
+            input_data (Dict[str, Any]): Dictionary specifying the input data format.
+        """
+        super().__init__()
+        self.input_data = input_data
+
+    def __call__(self, ts_list: List[Dict[str, Any]]) -> List[List[np.ndarray]]:
+        """Converts a list of time series data frames into arrays.
+
+        Args:
+            ts_list (List[Dict[str, Any]]): List of time series data frames.
+
+        Returns:
+            List[List[np.ndarray]]: List of lists of arrays for each time series.
+        """
+        return [self.tstoarray(ts) for ts in ts_list]
+
+    def tstoarray(self, ts: Dict[str, Any]) -> List[np.ndarray]:
+        """Converts a single time series data frame into arrays.
+
+        Args:
+            ts (Dict[str, Any]): A single time series data frame.
+
+        Returns:
+            List[np.ndarray]: List of arrays representing the time series data.
+        """
+        ts_list = []
+        input_name = list(self.input_data.keys())
+        input_name.sort()
+        for key in input_name:
+            ts_list.append(np.array(ts[key]).astype("float32"))
+
+        return ts_list
+
+
+class TStoBatch:
+    """Convert a list of time series into batches for processing.
+
+    This class provides a method to convert a list of time series data into
+    batches. Each time series in the list is assumed to be a sequence of
+    equal-length arrays or DataFrames.
+    """
+
+    def __call__(self, ts_list: List[np.ndarray]) -> List[np.ndarray]:
+        """Convert a list of time series into batches.
+
+        This method stacks time series data along a new axis to create batches.
+        It assumes that each time series in the list has the same length.
+
+        Args:
+            ts_list (List[np.ndarray]): A list of time series, where each
+                time series is represented as a list or array of equal length.
+
+        Returns:
+            List[np.ndarray]: A list of batches, where each batch is a stacked
+            array of time series data at the same index across all series.
+        """
+        n = len(ts_list[0])
+        return [np.stack([ts[i] for ts in ts_list], axis=0) for i in range(n)]
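
For instance, `TStoBatch` stacks per-input arrays position-wise across samples:

import numpy as np

to_batch = TStoBatch()
sample_a = [np.zeros((96, 1), dtype="float32"), np.ones((24, 4), dtype="float32")]
sample_b = [np.zeros((96, 1), dtype="float32"), np.ones((24, 4), dtype="float32")]
batched = to_batch([sample_a, sample_b])
print([b.shape for b in batched])  # [(2, 96, 1), (2, 24, 4)]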

+ 15 - 0
paddlex/inference/models_new/ts_anomaly/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import TSAdPredictor

+ 146 - 0
paddlex/inference/models_new/ts_anomaly/predictor.py

@@ -0,0 +1,146 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+import pandas as pd
+import os
+
+from ....modules.ts_anomaly_detection.model_list import MODELS
+from ...common.batch_sampler import TSBatchSampler
+from ...common.reader import ReadTS
+from ..common import (
+    TSCutOff,
+    BuildTSDataset,
+    TSNormalize,
+    TimeFeature,
+    TStoArray,
+    TStoBatch,
+    StaticInfer,
+)
+from .processors import GetAnomaly
+from ..base import BasicPredictor
+from .result import TSAdResult
+
+
+class TSAdPredictor(BasicPredictor):
+    """TSAdPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        """Initializes TSAdPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> TSBatchSampler:
+        """Builds and returns an ImageBatchSampler instance.
+
+        Returns:
+            ImageBatchSampler: An instance of ImageBatchSampler.
+        """
+        return TSBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, TopkResult.
+
+        Returns:
+            type: The TopkResult class.
+        """
+        return TSAdResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        preprocessors = {
+            "ReadTS": ReadTS(),
+            "TSCutOff": TSCutOff(self.config["size"]),
+        }
+
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            preprocessors["TSNormalize"] = TSNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+
+        preprocessors["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+
+        if self.config.get("time_feat", None):
+            preprocessors["TimeFeature"] = TimeFeature(
+                self.config["info_params"],
+                self.config["size"],
+                self.config["holiday"],
+            )
+        preprocessors["TStoArray"] = TStoArray(self.config["input_data"])
+        preprocessors["TStoBatch"] = TStoBatch()
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        postprocessors = {}
+        postprocessors["GetAnomaly"] = GetAnomaly(
+            self.config["model_threshold"], self.config["info_params"]
+        )
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, pd.DataFrame]]): A batch of input data (e.g., CSV file paths or DataFrames).
+
+        Returns:
+            dict: A dictionary containing the input paths, raw time series, and anomaly detection results for every instance of the batch. Keys include 'input_path', 'input_ts', and 'anomaly'.
+        """
+
+        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_cutoff_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
+
+        if "TSNormalize" in self.preprocessors:
+            batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_cutoff_ts)
+            batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_ts)
+        else:
+            batch_input_ts = self.preprocessors["BuildTSDataset"](
+                ts_list=batch_cutoff_ts
+            )
+
+        if "TimeFeature" in self.preprocessors:
+            batch_ts = self.preprocessors["TimeFeature"](ts_list=batch_input_ts)
+            batch_ts = self.preprocessors["TStoArray"](ts_list=batch_ts)
+        else:
+            batch_ts = self.preprocessors["TStoArray"](ts_list=batch_input_ts)
+
+        x = self.preprocessors["TStoBatch"](ts_list=batch_ts)
+        batch_preds = self.infer(x=x)
+
+        batch_ts_preds = self.postprocessors["GetAnomaly"](
+            ori_ts_list=batch_input_ts, pred_list=batch_preds
+        )
+        return {
+            "input_path": batch_data,
+            "input_ts": batch_raw_ts,
+            "anomaly": batch_ts_preds,
+        }
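
A hypothetical end-to-end invocation (assuming `BasicPredictor` exposes a `predict()` entry point and the result object offers CSV saving via `CSVMixin`; neither signature appears in this diff):

from paddlex.inference.models_new import TSAdPredictor

predictor = TSAdPredictor(model_dir="path/to/ts_ad_model")  # assumed constructor argument
for result in predictor.predict("data/series.csv"):  # assumed entry point
    result.save_to_csv("output/")  # assumed CSVMixin saving API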

+ 94 - 0
paddlex/inference/models_new/ts_anomaly/processors.py

@@ -0,0 +1,94 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Dict, Any
+import numpy as np
+import pandas as pd
+
+
+class GetAnomaly:
+    """A class to detect anomalies in time series data based on a model threshold."""
+
+    def __init__(self, model_threshold: float, info_params: Dict[str, Any]):
+        """
+        Initializes the GetAnomaly class with a model threshold and parameters information.
+
+        Args:
+            model_threshold (float): The threshold for determining anomalies.
+            info_params (Dict[str, Any]): Configuration parameters including target columns and time column name.
+        """
+        super().__init__()
+        self.model_threshold = model_threshold
+        self.info_params = info_params
+
+    def __call__(
+        self, ori_ts_list: List[Dict[str, Any]], pred_list: List[np.ndarray]
+    ) -> List[pd.DataFrame]:
+        """
+        Detects anomalies for a list of time series predictions.
+
+        Args:
+            ori_ts_list (List[Dict[str, Any]]): Original time series data for each prediction, including past and covariate information.
+            pred_list (List[np.ndarray]): List of prediction arrays corresponding to each time series in ori_ts_list.
+
+        Returns:
+            List[pd.DataFrame]: A list of DataFrames, each containing anomaly labels for the time series.
+        """
+        return [
+            self.getanomaly(ori_ts, pred)
+            for ori_ts, pred in zip(ori_ts_list, pred_list)
+        ]
+
+    def getanomaly(self, ori_ts: Dict[str, Any], pred: np.ndarray) -> pd.DataFrame:
+        """
+        Detects anomalies in a single time series prediction.
+
+        Args:
+            ori_ts (Dict[str, Any]): Original time series data for a single time series.
+            pred (np.ndarray): Prediction array for the given time series.
+
+        Returns:
+            pd.DataFrame: A DataFrame containing anomaly labels for the time series.
+
+        Raises:
+            ValueError: If none of the expected keys are found in ori_ts.
+        """
+        pred = pred[0]
+        if ori_ts.get("past_target", None) is not None:
+            ts = ori_ts["past_target"]
+        elif ori_ts.get("observed_cov_numeric", None) is not None:
+            ts = ori_ts["observed_cov_numeric"]
+        elif ori_ts.get("known_cov_numeric", None) is not None:
+            ts = ori_ts["known_cov_numeric"]
+        elif ori_ts.get("static_cov_numeric", None) is not None:
+            ts = ori_ts["static_cov_numeric"]
+        else:
+            raise ValueError("No value in ori_ts")
+        column_name = (
+            self.info_params["target_cols"]
+            if "target_cols" in self.info_params
+            else self.info_params["feature_cols"]
+        )
+
+        anomaly_score = np.mean(np.square(pred - np.array(ts)), axis=-1)
+        anomaly_label = (anomaly_score >= self.model_threshold).astype(int)
+
+        past_target_index = ts.index
+        past_target_index.name = self.info_params["time_col"]
+        anomaly_label_df = pd.DataFrame(
+            np.reshape(anomaly_label, newshape=[pred.shape[0], -1]),
+            index=past_target_index,
+            columns=["label"],
+        )
+        return anomaly_label_df
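
The scoring itself is a per-step mean squared error between the model's reconstruction and the observed window, thresholded into 0/1 labels — a standalone sketch with synthetic values:

import numpy as np
import pandas as pd

pred = np.random.rand(1, 96, 2).astype("float32")             # reconstruction, shape (1, T, D)
ts = pd.DataFrame(np.random.rand(96, 2))                      # observed window, shape (T, D)
score = np.mean(np.square(pred[0] - ts.to_numpy()), axis=-1)  # per-step MSE, shape (T,)
label = (score >= 0.1).astype(int)                            # 0.1 is an illustrative threshold
print(label.sum(), "steps flagged anomalous")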

+ 29 - 0
paddlex/inference/models_new/ts_anomaly/result.py

@@ -0,0 +1,29 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from ...common.result import BaseTSResult
+
+
+class TSAdResult(BaseTSResult):
+    """A class representing the result of a time series anomaly detection task."""
+
+    def _to_csv(self) -> Any:
+        """
+        Converts the anomaly detection results to a CSV format.
+
+        Returns:
+            Any: The anomaly data formatted for CSV output, typically a DataFrame or similar structure.
+        """
+        return self["anomaly"]

+ 15 - 0
paddlex/inference/models_new/ts_classify/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import TSClsPredictor

+ 131 - 0
paddlex/inference/models_new/ts_classify/predictor.py

@@ -0,0 +1,131 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+import pandas as pd
+import os
+
+from ....modules.ts_classification.model_list import MODELS
+from ...common.batch_sampler import TSBatchSampler
+from ...common.reader import ReadTS
+from ..common import (
+    TSCutOff,
+    BuildTSDataset,
+    TSNormalize,
+    TimeFeature,
+    TStoArray,
+    TStoBatch,
+    StaticInfer,
+)
+
+from .processors import GetCls, BuildPadMask
+from ..base import BasicPredictor
+from .result import TSClsResult
+
+
+class TSClsPredictor(BasicPredictor):
+    """TSClsPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        """Initializes TSClsPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> TSBatchSampler:
+        """Builds and returns an TSBatchSampler instance.
+
+        Returns:
+            TSBatchSampler: An instance of TSBatchSampler.
+        """
+        return TSBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the result class.
+
+        Returns:
+            type: The Result class.
+        """
+        return TSClsResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        preprocessors = {
+            "ReadTS": ReadTS(),
+            "TSCutOff": TSCutOff(self.config["size"]),
+        }
+
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            preprocessors["TSNormalize"] = TSNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+
+        preprocessors["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+        preprocessors["BuildPadMask"] = BuildPadMask(self.config["input_data"])
+        preprocessors["TStoArray"] = TStoArray(self.config["input_data"])
+        preprocessors["TStoBatch"] = TStoBatch()
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        postprocessors = {}
+        postprocessors["GetCls"] = GetCls()
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
+        """
+        Processes a batch of time series data through a series of preprocessing, inference, and postprocessing steps.
+
+        Args:
+            batch_data (List[Union[str, pd.DataFrame]]): A list of CSV file paths or DataFrames for the batch of time series data to be processed.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the paths to the input data, the raw input time series, and the classification results.
+        """
+        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+
+        if "TSNormalize" in self.preprocessors:
+            batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_raw_ts)
+            batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_ts)
+        else:
+            batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_raw_ts)
+
+        batch_input_ts = self.preprocessors["BuildPadMask"](ts_list=batch_input_ts)
+        batch_ts = self.preprocessors["TStoArray"](ts_list=batch_input_ts)
+
+        x = self.preprocessors["TStoBatch"](ts_list=batch_ts)
+        batch_preds = self.infer(x=x)
+
+        batch_ts_preds = self.postprocessors["GetCls"](pred_list=batch_preds)
+
+        return {
+            "input_path": batch_data,
+            "input_ts": batch_raw_ts,
+            "classification": batch_ts_preds,
+        }

+ 117 - 0
paddlex/inference/models_new/ts_classify/processors.py

@@ -0,0 +1,117 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas as pd
+from typing import List, Any, Dict
+
+
+class GetCls:
+    """A class to process prediction outputs and return class IDs and scores."""
+
+    def __init__(self):
+        """Initializes the GetCls instance."""
+        super().__init__()
+
+    def __call__(self, pred_list: List[Any]) -> List[pd.DataFrame]:
+        """
+        Processes a list of predictions and returns a list of DataFrames with class IDs and scores.
+
+        Args:
+            pred_list (List[Any]): A list of predictions, where each prediction is expected to be an iterable of arrays.
+
+        Returns:
+            List[pd.DataFrame]: A list of DataFrames, each containing the class ID and score for the corresponding prediction.
+        """
+        return [self.getcls(pred) for pred in pred_list]
+
+    def getcls(self, pred: Any) -> pd.DataFrame:
+        """
+        Computes the class ID and score from a single prediction.
+
+        Args:
+            pred (Any): A prediction, expected to be an iterable where the first element is an array representing logits or probabilities.
+
+        Returns:
+            pd.DataFrame: A DataFrame containing the class ID and score for the prediction.
+        """
+        pred_ts = pred[0]
+        pred_ts -= np.max(pred_ts, axis=-1, keepdims=True)
+        pred_ts = np.exp(pred_ts) / np.sum(np.exp(pred_ts), axis=-1, keepdims=True)
+        classid = np.argmax(pred_ts, axis=-1)
+        pred_score = pred_ts[classid]
+        result = pd.DataFrame.from_dict({"classid": [classid], "score": [pred_score]})
+        result.index.name = "sample"
+        return result
+
+
+class BuildPadMask:
+    """A class to build padding masks for time series data."""
+
+    def __init__(self, input_data: Dict[str, Any]):
+        """
+        Initializes the BuildPadMask instance.
+
+        Args:
+            input_data (Dict[str, Any]): A dictionary containing configuration data, including 'features'
+                                         and 'pad_mask' keys that influence how padding is applied.
+        """
+        super().__init__()
+        self.input_data = input_data
+
+    def __call__(self, ts_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Applies padding mask to a list of time series data.
+
+        Args:
+            ts_list (List[Dict[str, Any]]): A list of dictionaries, each representing a time series instance
+                                            with keys like 'features' and 'past_target'.
+
+        Returns:
+            List[Dict[str, Any]]: A list of dictionaries with updated 'features' and 'pad_mask' keys.
+        """
+        return [self.padmask(ts) for ts in ts_list]
+
+    def padmask(self, ts: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Builds a padding mask for a single time series instance.
+
+        Args:
+            ts (Dict[str, Any]): A dictionary representing a time series instance, expected to have keys
+                                 like 'features' and 'past_target'.
+
+        Returns:
+            Dict[str, Any]: The input dictionary with potentially updated 'features' and 'pad_mask' keys.
+        """
+        if "features" in self.input_data:
+            ts["features"] = ts["past_target"]
+
+        if "pad_mask" in self.input_data:
+            target_dim = len(ts["features"])
+            max_length = self.input_data["pad_mask"][-1]
+            if max_length > 0:
+                ones = np.ones(max_length, dtype=np.int32)
+                if max_length != target_dim:
+                    target_ndarray = np.array(ts["features"]).astype(np.float32)
+                    target_ndarray_final = np.zeros(
+                        [max_length, target_dim], dtype=np.float32
+                    )
+                    end = min(target_dim, max_length)
+                    target_ndarray_final[:end, :] = target_ndarray
+                    ts["features"] = target_ndarray_final
+                    ones[end:] = 0
+                    ts["pad_mask"] = ones
+                else:
+                    ts["pad_mask"] = ones
+        return ts
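Both classes above boil down to a few lines of plain numpy: GetCls applies a numerically stable softmax and takes the argmax, and BuildPadMask zero-pads the feature block to a fixed length while recording which steps are real. A toy reproduction, with illustrative shapes and values:

import numpy as np

# GetCls: stable softmax (subtract the max before exp) + argmax.
logits = np.array([2.0, 0.5, -1.0])
logits -= logits.max(axis=-1, keepdims=True)
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
classid, score = int(probs.argmax(-1)), float(probs.max(-1))
print(classid, round(score, 4))               # 0 0.7856

# BuildPadMask: pad 5 real steps out to max_length=8; the mask marks real steps.
features = np.random.rand(5, 3).astype(np.float32)   # 5 steps, 3 channels
max_length = 8
padded = np.zeros((max_length, features.shape[1]), dtype=np.float32)
padded[: len(features)] = features
pad_mask = np.zeros(max_length, dtype=np.int32)
pad_mask[: len(features)] = 1                 # 1 = real step, 0 = padding
print(padded.shape, pad_mask)                 # (8, 3) [1 1 1 1 1 0 0 0]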

+ 29 - 0
paddlex/inference/models_new/ts_classify/result.py

@@ -0,0 +1,29 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from ...common.result import BaseTSResult
+
+
+class TSClsResult(BaseTSResult):
+    """A class representing the result of a time series classification task."""
+
+    def _to_csv(self) -> Any:
+        """
+        Converts the classification results to a CSV format.
+
+        Returns:
+            Any: The classification data formatted for CSV output, typically a DataFrame or similar structure.
+        """
+        return self["classification"]

+ 15 - 0
paddlex/inference/models_new/ts_forecast/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import TSFcPredictor

+ 156 - 0
paddlex/inference/models_new/ts_forecast/predictor.py

@@ -0,0 +1,156 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+import pandas as pd
+import os
+
+from ....modules.ts_forecast.model_list import MODELS
+from ...common.batch_sampler import TSBatchSampler
+from ...common.reader import ReadTS
+from ..common import (
+    TSCutOff,
+    BuildTSDataset,
+    TSNormalize,
+    TimeFeature,
+    TStoArray,
+    TStoBatch,
+    StaticInfer,
+)
+from .processors import ArraytoTS, TSDeNormalize
+from ..base import BasicPredictor
+from .result import TSFcResult
+
+
+class TSFcPredictor(BasicPredictor):
+    """TSFcPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        """Initializes TSFcPredictor.
+
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> TSBatchSampler:
+        """Builds and returns an ImageBatchSampler instance.
+
+        Returns:
+            ImageBatchSampler: An instance of ImageBatchSampler.
+        """
+        return TSBatchSampler()
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, TopkResult.
+
+        Returns:
+            type: The TopkResult class.
+        """
+        return TSFcResult
+
+    def _build(self) -> Tuple:
+        """Build the preprocessors, inference engine, and postprocessors based on the configuration.
+
+        Returns:
+            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
+        """
+        preprocessors = {
+            "ReadTS": ReadTS(),
+            "TSCutOff": TSCutOff(self.config["size"]),
+        }
+
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            preprocessors["TSNormalize"] = TSNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+
+        preprocessors["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+
+        if self.config.get("time_feat", None):
+            preprocessors["TimeFeature"] = TimeFeature(
+                self.config["info_params"],
+                self.config["size"],
+                self.config["holiday"],
+            )
+        preprocessors["TStoArray"] = TStoArray(self.config["input_data"])
+        preprocessors["TStoBatch"] = TStoBatch()
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        postprocessors = {}
+        postprocessors["ArraytoTS"] = ArraytoTS(self.config["info_params"])
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            postprocessors["TSDeNormalize"] = TSDeNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, pd.DataFrame]]): A batch of input data (e.g., CSV file paths or DataFrames).
+
+        Returns:
+            dict: A dictionary containing the input paths, the raw input time series, and the forecast results. Keys include 'input_path', 'input_ts', and 'forecast'.
+        """
+
+        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_cutoff_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
+
+        if "TSNormalize" in self.preprocessors:
+            batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_cutoff_ts)
+            batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_ts)
+        else:
+            batch_input_ts = self.preprocessors["BuildTSDataset"](
+                ts_list=batch_cutoff_ts
+            )
+
+        if "TimeFeature" in self.preprocessors:
+            batch_ts = self.preprocessors["TimeFeature"](ts_list=batch_input_ts)
+            batch_ts = self.preprocessors["TStoArray"](ts_list=batch_ts)
+        else:
+            batch_ts = self.preprocessors["TStoArray"](ts_list=batch_input_ts)
+
+        x = self.preprocessors["TStoBatch"](ts_list=batch_ts)
+        batch_preds = self.infer(x=x)
+
+        batch_ts_preds = self.postprocessors["ArraytoTS"](
+            ori_ts_list=batch_input_ts, pred_list=batch_preds
+        )
+        if "TSDeNormalize" in self.postprocessors:
+            batch_ts_preds = self.postprocessors["TSDeNormalize"](
+                preds_list=batch_ts_preds
+            )
+
+        return {
+            "input_path": batch_data,
+            "input_ts": batch_raw_ts,
+            "forecast": batch_ts_preds,
+        }
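The normalize/de-normalize pairing that _build wires up assumes a scikit-learn-style scaler persisted with joblib; the round-trip below is a sketch under that assumption (the "scaler.pkl" name mirrors the convention in _build, and the data is toy):

import joblib
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

train = pd.DataFrame({"OT": [10.0, 12.0, 11.0, 13.0]})
scaler = StandardScaler().fit(train[["OT"]])
joblib.dump(scaler, "scaler.pkl")             # what the training side exports

loaded = joblib.load("scaler.pkl")            # what TSNormalize/TSDeNormalize load
normed = loaded.transform(train[["OT"]])      # model-space values (TSNormalize)
restored = loaded.inverse_transform(normed)   # back to original scale (TSDeNormalize)
assert np.allclose(restored, train[["OT"]].to_numpy())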

+ 149 - 0
paddlex/inference/models_new/ts_forecast/processors.py

@@ -0,0 +1,149 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Dict, Any, Union
+import joblib
+import numpy as np
+import pandas as pd
+
+
+class TSDeNormalize:
+    """A class to de-normalize time series prediction data using a pre-fitted scaler."""
+
+    def __init__(self, scale_path: str, params_info: dict):
+        """
+        Initializes the TSDeNormalize class with a scaler and parameters information.
+
+        Args:
+            scale_path (str): The file path to the serialized scaler object.
+            params_info (dict): Additional parameters information.
+        """
+        super().__init__()
+        self.scaler = joblib.load(scale_path)
+        self.params_info = params_info
+
+    def __call__(self, preds_list: List[pd.DataFrame]) -> List[pd.DataFrame]:
+        """
+        Applies de-normalization to a list of prediction DataFrames.
+
+        Args:
+            preds_list (List[pd.DataFrame]): A list of DataFrames containing normalized prediction data.
+
+        Returns:
+            List[pd.DataFrame]: A list of DataFrames with de-normalized prediction data.
+        """
+        return [self.tsdenorm(pred) for pred in preds_list]
+
+    def tsdenorm(self, pred: pd.DataFrame) -> pd.DataFrame:
+        """
+        De-normalizes a single prediction DataFrame.
+
+        Args:
+            pred (pd.DataFrame): A DataFrame containing normalized prediction data.
+
+        Returns:
+            pd.DataFrame: A DataFrame with de-normalized prediction data.
+        """
+        scale_cols = pred.columns.values.tolist()
+        pred[scale_cols] = self.scaler.inverse_transform(pred[scale_cols])
+        return pred
+
+
+class ArraytoTS:
+    """A class to convert arrays of predictions into time series format."""
+
+    def __init__(self, info_params: Dict[str, Any]):
+        """
+        Initializes the ArraytoTS class with the given parameters.
+
+        Args:
+            info_params (Dict[str, Any]): Configuration parameters including target columns, frequency, and time column name.
+        """
+        super().__init__()
+        self.info_params = info_params
+
+    def __call__(
+        self, ori_ts_list: List[Dict[str, Any]], pred_list: List[np.ndarray]
+    ) -> List[pd.DataFrame]:
+        """
+        Converts a list of arrays to a list of time series DataFrames.
+
+        Args:
+            ori_ts_list (List[Dict[str, Any]]): Original time series data for each prediction, including past and covariate information.
+            pred_list (List[np.ndarray]): List of prediction arrays corresponding to each time series in ori_ts_list.
+
+        Returns:
+            List[pd.DataFrame]: A list of DataFrames, each representing the forecasted time series.
+        """
+        return [
+            self.arraytots(ori_ts, pred) for ori_ts, pred in zip(ori_ts_list, pred_list)
+        ]
+
+    def arraytots(self, ori_ts: Dict[str, Any], pred: np.ndarray) -> pd.DataFrame:
+        """
+        Converts a single array prediction to a time series DataFrame.
+
+        Args:
+            ori_ts (Dict[str, Any]): Original time series data for a single time series.
+            pred (np.ndarray): Prediction array for the given time series.
+
+        Returns:
+            pd.DataFrame: A DataFrame representing the forecasted time series.
+
+        Raises:
+            ValueError: If none of the expected keys are found in ori_ts.
+        """
+        pred = pred[0]
+        if ori_ts.get("past_target", None) is not None:
+            ts = ori_ts["past_target"]
+        elif ori_ts.get("observed_cov_numeric", None) is not None:
+            ts = ori_ts["observed_cov_numeric"]
+        elif ori_ts.get("known_cov_numeric", None) is not None:
+            ts = ori_ts["known_cov_numeric"]
+        elif ori_ts.get("static_cov_numeric", None) is not None:
+            ts = ori_ts["static_cov_numeric"]
+        else:
+            raise ValueError("No value in ori_ts")
+
+        column_name = (
+            self.info_params["target_cols"]
+            if "target_cols" in self.info_params
+            else self.info_params["feature_cols"]
+        )
+        if isinstance(self.info_params["freq"], str):
+            past_target_index = ts.index
+            if past_target_index.freq is None:
+                past_target_index.freq = pd.infer_freq(ts.index)
+            future_target_index = pd.date_range(
+                past_target_index[-1] + past_target_index.freq,
+                periods=pred.shape[0],
+                freq=self.info_params["freq"],
+                name=self.info_params["time_col"],
+            )
+        elif isinstance(self.info_params["freq"], int):
+            start_idx = max(ts.index) + 1
+            stop_idx = start_idx + pred.shape[0]
+            future_target_index = pd.RangeIndex(
+                start=start_idx,
+                stop=stop_idx,
+                step=self.info_params["freq"],
+                name=self.info_params["time_col"],
+            )
+
+        future_target = pd.DataFrame(
+            np.reshape(pred, newshape=[pred.shape[0], -1]),
+            index=future_target_index,
+            columns=column_name,
+        )
+        return future_target
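ArraytoTS's index handling can be exercised in isolation: a string freq extends a DatetimeIndex from the last observed stamp, while an integer freq continues a RangeIndex. A self-contained illustration with made-up values:

import numpy as np
import pandas as pd

past = pd.date_range("2024-01-01", periods=4, freq="h")   # last stamp: 03:00
pred = np.random.rand(3, 1)                               # 3 forecast steps, 1 target

# String freq: continue the datetime axis one step past the history.
future_index = pd.date_range(
    past[-1] + past.freq, periods=pred.shape[0], freq="h", name="date"
)
forecast = pd.DataFrame(pred, index=future_index, columns=["OT"])
print(forecast.index[0])                                  # 2024-01-01 04:00:00

# Integer freq: continue a positional axis instead.
int_index = pd.RangeIndex(start=4, stop=4 + pred.shape[0], step=1, name="idx")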

+ 29 - 0
paddlex/inference/models_new/ts_forecast/result.py

@@ -0,0 +1,29 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from ...common.result import BaseTSResult
+
+
+class TSFcResult(BaseTSResult):
+    """A class representing the result of a time series forecasting task."""
+
+    def _to_csv(self) -> Any:
+        """
+        Converts the forecasting results to a CSV format.
+
+        Returns:
+            Any: The forecast data formatted for CSV output, typically a DataFrame or similar structure.
+        """
+        return self["forecast"]