
support ts_fc & ts_cls

gaotingquan · 1 year ago
parent commit fe7f912d44
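
This commit adds time series forecasting (ts_fc) and classification (ts_cls) support to the inference stack: a TSPPPredictor inference component, TS pre/postprocessing transforms, TSFcPredictor and TSClsPredictor entry points, TSFcResult/TSClsResult result types, and pandas-backed TSReader/TSWriter IO helpers. A minimal usage sketch, assuming predictor instances are callable on an input path as with the existing image predictors (the model name and CSV path here are hypothetical):

```python
from paddlex.inference.predictors import create_predictor

# "DLinear" is one of the forecasting models registered in OFFICIAL_MODELS below.
predictor = create_predictor("DLinear")
for data in predictor("path/to/series.csv"):  # assumed calling convention
    data["result"].save_to_csv("./output/")   # TSFcResult.save_to_csv, added below
```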

+ 6 - 1
paddlex/inference/components/paddle_predictor/__init__.py

@@ -12,4 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .predictor import ImagePredictor, ImageDetPredictor, ImageInstanceSegPredictor
+from .predictor import (
+    ImagePredictor,
+    ImageDetPredictor,
+    ImageInstanceSegPredictor,
+    TSPPPredictor,
+)

+ 13 - 2
paddlex/inference/components/paddle_predictor/predictor.py

@@ -24,9 +24,7 @@ from ....utils import logging
 class BasePaddlePredictor(BaseComponent):
     """Predictor based on Paddle Inference"""
 
-    INPUT_KEYS = "batch_data"
     OUTPUT_KEYS = "pred"
-    DEAULT_INPUTS = {"batch_data": "batch_data"}
     DEAULT_OUTPUTS = {"pred": "pred"}
     ENABLE_BATCH = True
 
@@ -172,6 +170,8 @@ No need to generate again."
 
 
 class ImagePredictor(BasePaddlePredictor):
+
+    INPUT_KEYS = "img"
     DEAULT_INPUTS = {"img": "img"}
 
     def to_batch(self, img):
@@ -228,3 +228,14 @@ class ImageInstanceSegPredictor(ImageDetPredictor):
         "img_size": "img_size",
     }
     DEAULT_OUTPUTS = {"boxes": "boxes", "masks": "masks"}
+
+
+class TSPPPredictor(BasePaddlePredictor):
+
+    INPUT_KEYS = "ts"
+    DEAULT_INPUTS = {"ts": "ts"}
+
+    def to_batch(self, ts):
+        n = len(ts[0])
+        x = [np.stack([lst[i] for lst in ts], axis=0) for i in range(n)]
+        return x
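
`to_batch` transposes a batch of per-sample input lists (one array per model input) into a list of batch-stacked arrays, one per input tensor. A standalone sketch of that reshaping, with made-up shapes:

```python
import numpy as np

# Two samples; each sample is a list of two model inputs
# (e.g. past values and time features).
ts = [
    [np.zeros((96, 7)), np.zeros((96, 4))],
    [np.ones((96, 7)), np.ones((96, 4))],
]
n = len(ts[0])
batch = [np.stack([sample[i] for sample in ts], axis=0) for i in range(n)]
print([x.shape for x in batch])  # [(2, 96, 7), (2, 96, 4)]
```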

+ 1 - 0
paddlex/inference/components/transforms/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .image import *
+from .ts import *

+ 15 - 0
paddlex/inference/components/transforms/ts/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ts_common import *

+ 351 - 0
paddlex/inference/components/transforms/ts/ts_common.py

@@ -0,0 +1,351 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from copy import deepcopy
+import joblib
+import numpy as np
+import pandas as pd
+
+from .....utils.download import download
+from .....utils.cache import CACHE_DIR
+from ....utils.io.readers import TSReader
+from ....utils.io.writers import TSWriter
+from ...base import BaseComponent
+from .ts_functions import load_from_dataframe, time_feature
+
+
+__all__ = [
+    "ReadTS",
+    "BuildTSDataset",
+    "TSCutOff",
+    "TSNormalize",
+    "TimeFeature",
+    "TStoArray",
+    "BuildPadMask",
+    "ArraytoTS",
+    "TSDeNormalize",
+]
+
+
+class ReadTS(BaseComponent):
+
+    INPUT_KEYS = ["ts"]
+    OUTPUT_KEYS = ["ts_path", "ts", "ori_ts"]
+    DEAULT_INPUTS = {"ts": "ts"}
+    DEAULT_OUTPUTS = {"ts_path": "ts_path", "ts": "ts", "ori_ts": "ori_ts"}
+
+    def __init__(self):
+        super().__init__()
+        self._reader = TSReader(backend="pandas")
+        self._writer = TSWriter(backend="pandas")
+
+    def apply(self, ts):
+        if not isinstance(ts, str):
+            ts_path = (Path(CACHE_DIR) / "predict_input" / "tmp_ts.csv").as_posix()
+            self._writer.write(ts_path, ts)
+            return {"ts_path": ts_path, "ts": ts, "ori_ts": deepcopy(ts)}
+
+        ts_path = ts
+        # XXX: auto download for url
+        ts_path = self._download_from_url(ts_path)
+        ts = self._reader.read(ts_path)
+        return {"ts_path": ts_path, "ts": ts, "ori_ts": deepcopy(ts)}
+
+    def _download_from_url(self, in_path):
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+
+
+class TSCutOff(BaseComponent):
+
+    INPUT_KEYS = ["ts", "ori_ts"]
+    OUTPUT_KEYS = ["ts", "ori_ts"]
+    DEAULT_INPUTS = {"ts": "ts", "ori_ts": "ori_ts"}
+    DEAULT_OUTPUTS = {"ts": "ts", "ori_ts": "ori_ts"}
+
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def apply(self, ts, ori_ts):
+        skip_len = self.size.get("skip_chunk_len", 0)
+        if len(ts) < self.size["in_chunk_len"] + skip_len:
+            raise ValueError(
+                f"The length of the input data is {len(ts)}, but it should be at least {self.size['in_chunk_len'] + self.size['skip_chunk_len']} for training."
+            )
+        ts_data = ts[-(self.size["in_chunk_len"] + skip_len) :]
+        return {"ts": ts_data, "ori_ts": ts_data}
+
+
+class TSNormalize(BaseComponent):
+
+    INPUT_KEYS = ["ts"]
+    OUTPUT_KEYS = ["ts"]
+    DEAULT_INPUTS = {"ts": "ts"}
+    DEAULT_OUTPUTS = {"ts": "ts"}
+
+    def __init__(self, scale_path, params_info):
+        super().__init__()
+        self.scaler = joblib.load(scale_path)
+        self.params_info = params_info
+
+    def apply(self, ts):
+        """apply"""
+        if self.params_info.get("target_cols", None) is not None:
+            ts[self.params_info["target_cols"]] = self.scaler.transform(
+                ts[self.params_info["target_cols"]]
+            )
+        if self.params_info.get("feature_cols", None) is not None:
+            ts[self.params_info["feature_cols"]] = self.scaler.transform(
+                ts[self.params_info["feature_cols"]]
+            )
+
+        return {"ts": ts}
+
+
+class TSDeNormalize(BaseComponent):
+
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = ["pred"]
+    DEAULT_INPUTS = {"pred": "pred"}
+    DEAULT_OUTPUTS = {"pred": "pred"}
+
+    def __init__(self, scale_path, params_info):
+        super().__init__()
+        self.scaler = joblib.load(scale_path)
+        self.params_info = params_info
+
+    def apply(self, pred):
+        """apply"""
+        scale_cols = pred.columns.values.tolist()
+        pred[scale_cols] = self.scaler.inverse_transform(pred[scale_cols])
+        return {"pred": pred}
+
+
+class BuildTSDataset(BaseComponent):
+
+    INPUT_KEYS = ["ts", "ori_ts"]
+    OUTPUT_KEYS = ["ts", "ori_ts"]
+    DEAULT_INPUTS = {"ts": "ts", "ori_ts": "ori_ts"}
+    DEAULT_OUTPUTS = {"ts": "ts", "ori_ts": "ori_ts"}
+
+    def __init__(self, params_info):
+        super().__init__()
+        self.params_info = params_info
+
+    def apply(self, ts, ori_ts):
+        """apply"""
+        ts_data = load_from_dataframe(ts, **self.params_info)
+        return {"ts": ts_data, "ori_ts": ts_data}
+
+
+class TimeFeature(BaseComponent):
+
+    INPUT_KEYS = ["ts"]
+    OUTPUT_KEYS = ["ts"]
+    DEAULT_INPUTS = {"ts": "ts"}
+    DEAULT_OUTPUTS = {"ts": "ts"}
+
+    def __init__(self, params_info, size, holiday=False):
+        super().__init__()
+        self.freq = params_info["freq"]
+        self.size = size
+        self.holiday = holiday
+
+    def apply(self, ts):
+        """apply"""
+        if not self.holiday:
+            ts = time_feature(
+                ts,
+                self.freq,
+                ["hourofday", "dayofmonth", "dayofweek", "dayofyear"],
+                self.size["out_chunk_len"],
+            )
+        else:
+            ts = time_feature(
+                ts,
+                self.freq,
+                [
+                    "minuteofhour",
+                    "hourofday",
+                    "dayofmonth",
+                    "dayofweek",
+                    "dayofyear",
+                    "monthofyear",
+                    "weekofyear",
+                    "holidays",
+                ],
+                self.size["out_chunk_len"],
+            )
+        return {"ts": ts}
+
+
+class BuildPadMask(BaseComponent):
+
+    INPUT_KEYS = ["ts"]
+    OUTPUT_KEYS = ["ts"]
+    DEAULT_INPUTS = {"ts": "ts"}
+    DEAULT_OUTPUTS = {"ts": "ts"}
+
+    def __init__(self, input_data):
+        super().__init__()
+        self.input_data = input_data
+
+    def apply(self, ts):
+        if "features" in self.input_data:
+            ts["features"] = ts["past_target"]
+
+        if "pad_mask" in self.input_data:
+            target_dim = len(ts["features"])
+            max_length = self.input_data["pad_mask"][-1]
+            if max_length > 0:
+                ones = np.ones(max_length, dtype=np.int32)
+                if max_length != target_dim:
+                    target_ndarray = np.array(ts["features"]).astype(np.float32)
+                    # pad the float features to max_length and zero the mask tail
+                    target_ndarray_final = np.zeros(
+                        [max_length, target_dim], dtype=np.float32
+                    )
+                    end = min(target_dim, max_length)
+                    target_ndarray_final[:end, :] = target_ndarray
+                    ts["features"] = target_ndarray_final
+                    ones[end:] = 0
+                    ts["pad_mask"] = ones
+                else:
+                    ts["pad_mask"] = ones
+        return {"ts": ts}
+
+
+class TStoArray(BaseComponent):
+
+    INPUT_KEYS = ["ts"]
+    OUTPUT_KEYS = ["ts"]
+    DEAULT_INPUTS = {"ts": "ts"}
+    DEAULT_OUTPUTS = {"ts": "ts"}
+
+    def __init__(self, input_data):
+        super().__init__()
+        self.input_data = input_data
+
+    def apply(self, ts):
+        ts_list = []
+        input_name = list(self.input_data.keys())
+        input_name.sort()
+        for key in input_name:
+            ts_list.append(np.array(ts[key]).astype("float32"))
+
+        return {"ts": ts_list}
+
+
+class ArraytoTS(BaseComponent):
+
+    INPUT_KEYS = ["ori_ts", "pred"]
+    OUTPUT_KEYS = ["pred"]
+    DEAULT_INPUTS = {"ori_ts": "ori_ts", "pred": "pred"}
+    DEAULT_OUTPUTS = {"pred": "pred"}
+
+    def __init__(self, info_params):
+        super().__init__()
+        self.info_params = info_params
+
+    def apply(self, ori_ts, pred):
+        pred = pred[0]
+        if ori_ts.get("past_target", None) is not None:
+            ts = ori_ts["past_target"]
+        elif ori_ts.get("observed_cov_numeric", None) is not None:
+            ts = ori_ts["observed_cov_numeric"]
+        elif ori_ts.get("known_cov_numeric", None) is not None:
+            ts = ori_ts["known_cov_numeric"]
+        elif ori_ts.get("static_cov_numeric", None) is not None:
+            ts = ori_ts["static_cov_numeric"]
+        else:
+            raise ValueError("No value in ori_ts")
+
+        column_name = (
+            self.info_params["target_cols"]
+            if "target_cols" in self.info_params
+            else self.info_params["feature_cols"]
+        )
+        if isinstance(self.info_params["freq"], str):
+            past_target_index = ts.index
+            if past_target_index.freq is None:
+                past_target_index.freq = pd.infer_freq(ts.index)
+            future_target_index = pd.date_range(
+                past_target_index[-1] + past_target_index.freq,
+                periods=pred.shape[0],
+                freq=self.info_params["freq"],
+                name=self.info_params["time_col"],
+            )
+        elif isinstance(self.info_params["freq"], int):
+            # step by freq so the future index has exactly pred.shape[0] entries
+            start_idx = max(ts.index) + self.info_params["freq"]
+            stop_idx = start_idx + pred.shape[0] * self.info_params["freq"]
+            future_target_index = pd.RangeIndex(
+                start=start_idx,
+                stop=stop_idx,
+                step=self.info_params["freq"],
+                name=self.info_params["time_col"],
+            )
+
+        future_target = pd.DataFrame(
+            np.reshape(pred, newshape=[pred.shape[0], -1]),
+            index=future_target_index,
+            columns=column_name,
+        )
+        return {"pred": future_target}
+
+
+class GetAnomaly(BaseComponent):
+
+    INPUT_KEYS = ["ori_ts", "pred_ts"]
+    OUTPUT_KEYS = ["pred_ts"]
+    DEAULT_INPUTS = {"ori_ts": "ori_ts", "pred_ts": "pred_ts"}
+    DEAULT_OUTPUTS = {"pred_ts": "pred_ts"}
+
+    def __init__(self, model_threshold, info_params):
+        super().__init__()
+        self.model_threshold = model_threshold
+        self.info_params = info_params
+
+    def apply(self, ori_ts, pred_ts):
+        if ori_ts.get("past_target", None) is not None:
+            ts = ori_ts["past_target"]
+        elif ori_ts.get("observed_cov_numeric", None) is not None:
+            ts = ori_ts["observed_cov_numeric"]
+        elif ori_ts.get("known_cov_numeric", None) is not None:
+            ts = ori_ts["known_cov_numeric"]
+        elif ori_ts.get("static_cov_numeric", None) is not None:
+            ts = ori_ts["static_cov_numeric"]
+        else:
+            raise ValueError("No value in ori_ts")
+        column_name = (
+            self.info_params["target_cols"]
+            if "target_cols" in self.info_params
+            else self.info_params["feature_cols"]
+        )
+
+        anomaly_score = np.mean(np.square(pred_ts - np.array(ts)), axis=-1)
+        anomaly_label = (anomaly_score >= self.model_threshold) + 0
+
+        past_target_index = ts.index
+        past_target_index.name = self.info_params["time_col"]
+        anomaly_label = pd.DataFrame(
+            np.reshape(anomaly_label, newshape=[pred_ts.shape[0], -1]),
+            index=past_target_index,
+            columns=["label"],
+        )
+        return {"pred_ts": anomaly_label}

+ 424 - 0
paddlex/inference/components/transforms/ts/ts_functions.py

@@ -0,0 +1,424 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
+import numpy as np
+import pandas as pd
+import joblib
+import chinese_calendar
+from pandas.tseries.offsets import DateOffset, Easter, Day
+from pandas.tseries import holiday as hd
+from sklearn.preprocessing import StandardScaler
+
+
+MAX_WINDOW = 183 + 17
+EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
+NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
+SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
+MothersDay = hd.Holiday(
+    "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
+)
+IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
+ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
+ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
+NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
+BlackFriday = hd.Holiday(
+    "Black Friday",
+    month=11,
+    day=1,
+    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
+)
+CyberMonday = hd.Holiday(
+    "Cyber Monday",
+    month=11,
+    day=1,
+    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
+)
+
+HOLIDAYS = [
+    hd.EasterMonday,
+    hd.GoodFriday,
+    hd.USColumbusDay,
+    hd.USLaborDay,
+    hd.USMartinLutherKingJr,
+    hd.USMemorialDay,
+    hd.USPresidentsDay,
+    hd.USThanksgivingDay,
+    EasterSunday,
+    NewYearsDay,
+    SuperBowl,
+    MothersDay,
+    IndependenceDay,
+    ChristmasEve,
+    ChristmasDay,
+    NewYearsEve,
+    BlackFriday,
+    CyberMonday,
+]
+
+
+def _cal_year(
+    x: np.datetime64,
+):
+    return x.year
+
+
+def _cal_month(
+    x: np.datetime64,
+):
+    return x.month
+
+
+def _cal_day(
+    x: np.datetime64,
+):
+    return x.day
+
+
+def _cal_hour(
+    x: np.datetime64,
+):
+    return x.hour
+
+
+def _cal_weekday(
+    x: np.datetime64,
+):
+    return x.dayofweek
+
+
+def _cal_quarter(
+    x: np.datetime64,
+):
+    return x.quarter
+
+
+def _cal_hourofday(
+    x: np.datetime64,
+):
+    return x.hour / 23.0 - 0.5
+
+
+def _cal_dayofweek(
+    x: np.datetime64,
+):
+    return x.dayofweek / 6.0 - 0.5
+
+
+def _cal_dayofmonth(
+    x: np.datetime64,
+):
+    return x.day / 30.0 - 0.5
+
+
+def _cal_dayofyear(
+    x: np.datetime64,
+):
+    return x.dayofyear / 364.0 - 0.5
+
+
+def _cal_weekofyear(
+    x: np.datetime64,
+):
+    return x.weekofyear / 51.0 - 0.5
+
+
+def _cal_holiday(
+    x: np.datetime64,
+):
+    return float(chinese_calendar.is_holiday(x))
+
+
+def _cal_workday(
+    x: np.datetime64,
+):
+    return float(chinese_calendar.is_workday(x))
+
+
+def _cal_minuteofhour(
+    x: np.datetime64,
+):
+    return x.minute / 59.0 - 0.5
+
+
+def _cal_monthofyear(
+    x: np.datetime64,
+):
+    return x.month / 11.0 - 0.5
+
+
+CAL_DATE_METHOD = {
+    "year": _cal_year,
+    "month": _cal_month,
+    "day": _cal_day,
+    "hour": _cal_hour,
+    "weekday": _cal_weekday,
+    "quarter": _cal_quarter,
+    "minuteofhour": _cal_minuteofhour,
+    "monthofyear": _cal_monthofyear,
+    "hourofday": _cal_hourofday,
+    "dayofweek": _cal_dayofweek,
+    "dayofmonth": _cal_dayofmonth,
+    "dayofyear": _cal_dayofyear,
+    "weekofyear": _cal_weekofyear,
+    "is_holiday": _cal_holiday,
+    "is_workday": _cal_workday,
+}
+
+
+def load_from_one_dataframe(
+    data: Union[pd.DataFrame, pd.Series],
+    time_col: Optional[str] = None,
+    value_cols: Optional[Union[List[str], str]] = None,
+    freq: Optional[Union[str, int]] = None,
+    drop_tail_nan: bool = False,
+    dtype: Optional[Union[type, Dict[str, type]]] = None,
+):
+
+    series_data = None
+    if value_cols is None:
+        if isinstance(data, pd.Series):
+            series_data = data.copy()
+        else:
+            series_data = data.loc[:, data.columns != time_col].copy()
+    else:
+        series_data = data.loc[:, value_cols].copy()
+
+    if time_col:
+        if time_col not in data.columns:
+            raise ValueError(
+                "The time column: {} doesn't exist in the `data`!".format(time_col)
+            )
+        time_col_vals = data.loc[:, time_col]
+    else:
+        time_col_vals = data.index
+
+    if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
+        time_col_vals = time_col_vals.astype(str)
+
+    if np.issubdtype(time_col_vals.dtype, np.integer):
+        if freq:
+            if not isinstance(freq, int) or freq < 1:
+                raise ValueError(
+                    "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
+                )
+        else:
+            freq = 1
+        start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
+        if (stop_idx - start_idx) / freq != len(data):
+            raise ValueError("The number of rows doesn't match with the RangeIndex!")
+        time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
+    elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
+        time_col_vals.dtype, np.datetime64
+    ):
+        time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
+        time_index = pd.DatetimeIndex(time_col_vals)
+        if freq:
+            if not isinstance(freq, str):
+                raise ValueError(
+                    "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
+                )
+        else:
+            # If freq is not provided and automatic inference fails, raise an exception
+            freq = pd.infer_freq(time_index)
+            if freq is None:
+                raise ValueError(
+                    "Failed to infer the `freq`. A valid `freq` is required."
+                )
+            if freq[0] == "-":
+                freq = freq[1:]
+    else:
+        raise ValueError("The type of `time_col` is invalid.")
+    if isinstance(series_data, pd.Series):
+        series_data = series_data.to_frame()
+    series_data.set_index(time_index, inplace=True)
+    series_data.sort_index(inplace=True)
+    return series_data
+
+
+def load_from_dataframe(
+    df: pd.DataFrame,
+    group_id: str = None,
+    time_col: Optional[str] = None,
+    target_cols: Optional[Union[List[str], str]] = None,
+    label_col: Optional[Union[List[str], str]] = None,
+    observed_cov_cols: Optional[Union[List[str], str]] = None,
+    feature_cols: Optional[Union[List[str], str]] = None,
+    known_cov_cols: Optional[Union[List[str], str]] = None,
+    static_cov_cols: Optional[Union[List[str], str]] = None,
+    freq: Optional[Union[str, int]] = None,
+    fill_missing_dates: bool = False,
+    fillna_method: str = "pre",
+    fillna_window_size: int = 10,
+    **kwargs,
+):
+
+    dfs = []  # separate multiple groups
+    if group_id is not None:
+        group_unique = df[group_id].unique()
+        for column in group_unique:
+            dfs.append(df[df[group_id].isin([column])])
+    else:
+        dfs = [df]
+    res = []
+    if label_col:
+        # a plain string names a single column; only reject multi-element lists
+        if not isinstance(label_col, str) and len(label_col) > 1:
+            raise ValueError("The length of label_col must be 1.")
+        target_cols = label_col
+    if feature_cols:
+        observed_cov_cols = feature_cols
+    for df in dfs:
+        target = None
+        observed_cov = None
+        known_cov = None
+        static_cov = dict()
+        if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
+            target = load_from_one_dataframe(
+                df,
+                time_col,
+                [a for a in df.columns if a != time_col],
+                freq,
+            )
+
+        else:
+            if target_cols:
+                target = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    target_cols,
+                    freq,
+                )
+
+            if observed_cov_cols:
+                observed_cov = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    observed_cov_cols,
+                    freq,
+                )
+
+            if known_cov_cols:
+                known_cov = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    known_cov_cols,
+                    freq,
+                )
+
+            if static_cov_cols:
+                if isinstance(static_cov_cols, str):
+                    static_cov_cols = [static_cov_cols]
+                for col in static_cov_cols:
+                    if col not in df.columns or len(np.unique(df[col])) != 1:
+                        raise ValueError(
+                            "static cov cals data is not in columns or schema is not right!"
+                        )
+                    static_cov[col] = df[col].iloc[0]
+        res.append(
+            {
+                "past_target": target,
+                "observed_cov_numeric": observed_cov,
+                "known_cov_numeric": known_cov,
+                "static_cov_numeric": static_cov,
+            }
+        )
+    return res[0]
+
+
+def _distance_to_holiday(holiday):
+    def _distance_to_day(index):
+        holiday_date = holiday.dates(
+            index - pd.Timedelta(days=MAX_WINDOW),
+            index + pd.Timedelta(days=MAX_WINDOW),
+        )
+        assert (
+            len(holiday_date) != 0
+        ), f"No closest holiday for the date index {index} found."
+        # It sometimes returns two dates if it is exactly half a year after the
+        # holiday. In this case, the smaller distance (182 days) is returned.
+        return float((index - holiday_date[0]).days)
+
+    return _distance_to_day
+
+
+def time_feature(dataset, freq, feature_cols, extend_points, inplace: bool = False):
+    """
+    Transform the time column into time features.
+
+    Args:
+        dataset(dict): Dataset dict to be transformed, as built by load_from_dataframe.
+        freq(str): Frequency of the time series; inferred from the index if None.
+        feature_cols(list): Names of the time features to compute.
+        extend_points(int): Number of future points to extend the time index by.
+        inplace(bool): Whether to perform the transformation inplace. default=False
+
+    Returns:
+        dict
+    """
+    new_ts = dataset
+    if not inplace:
+        new_ts = dataset.copy()
+    # Get known_cov
+    kcov = new_ts["known_cov_numeric"]
+    if not kcov:
+        tf_kcov = new_ts["past_target"].index.to_frame()
+    else:
+        tf_kcov = kcov.index.to_frame()
+    time_col = tf_kcov.columns[0]
+    if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
+        raise ValueError(
+            "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
+        )
+    if not kcov:
+        freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
+        extend_time = pd.date_range(
+            start=tf_kcov[time_col][-1],
+            freq=freq,
+            periods=extend_points + 1,
+            closed="right",
+            name=time_col,
+        ).to_frame()
+        tf_kcov = pd.concat([tf_kcov, extend_time])
+
+    for k in feature_cols:
+        if k != "holidays":
+            v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
+            v.index = tf_kcov[time_col]
+
+            if new_ts["known_cov_numeric"] is None:
+                new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
+            else:
+                new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
+                    new_ts["known_cov_numeric"].index
+                )
+
+        else:
+            holidays_col = []
+            for i, H in enumerate(HOLIDAYS):
+                v = tf_kcov[time_col].apply(_distance_to_holiday(H))
+                v.index = tf_kcov[time_col]
+                holidays_col.append(k + "_" + str(i))
+                if new_ts["known_cov_numeric"] is None:
+                    new_ts["known_cov_numeric"] = pd.DataFrame(
+                        v.rename(k + "_" + str(i)), index=v.index
+                    )
+                else:
+                    new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
+                        new_ts["known_cov_numeric"].index
+                    )
+
+            scaler = StandardScaler()
+            scaler.fit(new_ts["known_cov_numeric"][holidays_col])
+            new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
+                new_ts["known_cov_numeric"][holidays_col]
+            )
+    return new_ts
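
Each `_cal_*` helper maps a timestamp to a roughly zero-centered value in [-0.5, 0.5], which is what the normalized models expect. A quick check of a few of them on a fixed date:

```python
import pandas as pd

ts = pd.Timestamp("2024-07-04 12:00")  # a Thursday
print(ts.hour / 23.0 - 0.5)       # _cal_hourofday  ->  0.0217...
print(ts.dayofweek / 6.0 - 0.5)   # _cal_dayofweek  ->  0.0
print(ts.day / 30.0 - 0.5)        # _cal_dayofmonth -> -0.3666...
```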

+ 3 - 1
paddlex/inference/predictors/__init__.py

@@ -14,6 +14,7 @@
 
 
 from pathlib import Path
+from .official_models import official_models
 
 from .base import BasePredictor, BasicPredictor
 from .image_classification import ClasPredictor
@@ -24,7 +25,8 @@ from .object_detection import DetPredictor
 from .instance_segmentation import InstanceSegPredictor
 from .semantic_segmentation import SegPredictor
 from .general_recognition import ShiTuRecPredictor
-from .official_models import official_models
+from .ts_fc import TSFcPredictor
+from .ts_cls import TSClsPredictor
 
 
 def create_predictor(model: str, device: str = None, *args, **kwargs) -> BasePredictor:

+ 45 - 0
paddlex/inference/predictors/official_models.py

@@ -18,6 +18,7 @@ from ...utils import logging
 from ...utils.cache import CACHE_DIR
 from ...utils.download import download_and_extract
 
+
 OFFICIAL_MODELS = {
     "ResNet18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/ResNet18_infer.tar",
     "ResNet18_vd": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/ResNet18_vd_infer.tar",
@@ -139,6 +140,11 @@ CLIP_vit_large_patch14_224_infer.tar",
 Deeplabv3_Plus-R50_infer.tar",
     "Deeplabv3_Plus-R101": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 Deeplabv3_Plus-R101_infer.tar",
+    "PP-ShiTuV2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/PP-ShiTuV2_rec_infer.tar",
+    "PP-ShiTuV2_rec_CLIP_vit_base": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+PP-ShiTuV2_rec_CLIP_vit_base_infer.tar",
+    "PP-ShiTuV2_rec_CLIP_vit_large": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+PP-ShiTuV2_rec_CLIP_vit_large_infer.tar",
     "PP-LiteSeg-T": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/PP-LiteSeg-T_infer.tar",
     "OCRNet_HRNet-W48": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/OCRNet_HRNet-W48_infer.tar",
     "OCRNet_HRNet-W18": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/OCRNet_HRNet-W18_infer.tar",
@@ -156,6 +162,20 @@ Deeplabv3_Plus-R101_infer.tar",
     "Mask-RT-DETR-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Mask-RT-DETR-L_infer.tar",
     "PP-OCRv4_server_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 PP-OCRv4_server_rec_infer.tar",
+    "Mask-RT-DETR-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Mask-RT-DETR-S_infer.tar",
+    "Mask-RT-DETR-M": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Mask-RT-DETR-M_infer.tar",
+    "Mask-RT-DETR-X": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Mask-RT-DETR-X_infer.tar",
+    "SOLOv2": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/SOLOv2_infer.tar",
+    "MaskRCNN-ResNet50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNet50_infer.tar",
+    "MaskRCNN-ResNet50-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNet50-FPN_infer.tar",
+    "MaskRCNN-ResNet50-vd-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNet50-vd-FPN_infer.tar",
+    "MaskRCNN-ResNet50-vd-SSLDv2-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNet50-vd-SSLDv2_infer.tar",
+    "MaskRCNN-ResNet101-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNet101-FPN_infer.tar",
+    "MaskRCNN-ResNet101-vd-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNet101-vd-FPN_infer.tar",
+    "MaskRCNN-ResNeXt101-vd-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/MaskRCNN-ResNeXt101-vd-FPN_infer.tar",
+    "Cascade-MaskRCNN-ResNet50-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Cascade-MaskRCNN-ResNet50-FPN_infer.tar",
+    "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN_infer.tar",
+    "PP-YOLOE_seg-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/PP-YOLOE_seg-S_infer.tar",
     "PP-OCRv4_mobile_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 PP-OCRv4_mobile_rec_infer.tar",
     "PP-OCRv4_server_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
@@ -170,6 +190,31 @@ openatom_rec_svtrv2_ch_infer.tar",
     "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/SLANet_infer.tar",
     "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/LaTeX_OCR_rec_infer.tar",
     "UVDoc": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/UVDoc_infer.tar",
+    "FasterRCNN-ResNet34-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet34-FPN_infer.tar",
+    "FasterRCNN-ResNet50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet50_infer.tar",
+    "FasterRCNN-ResNet50-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet50-FPN_infer.tar",
+    "FasterRCNN-ResNet50-vd-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet50-vd-FPN_infer.tar",
+    "FasterRCNN-ResNet50-vd-SSLDv2-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet50-vd-SSLDv2-FPN_infer.tar",
+    "FasterRCNN-ResNet101": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet101_infer.tar",
+    "FasterRCNN-ResNet101-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet101-FPN_infer.tar",
+    "FasterRCNN-ResNeXt101-vd-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNeXt101-vd-FPN_infer.tar",
+    "FasterRCNN-Swin-Tiny-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-Swin-Tiny-FPN_infer.tar",
+    "Cascade-FasterRCNN-ResNet50-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Cascade-FasterRCNN-ResNet50-FPN_infer.tar",
+    "Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN_infer.tar",
+    "UVDoc": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/UVDoc_infer.tar",
+    "DLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/DLinear_infer.tar",
+    "NLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/NLinear_infer.tar",
+    "RLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/RLinear_infer.tar",
+    "Nonstationary": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Nonstationary_infer.tar",
+    "TimesNet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/TimesNet_infer.tar",
+    "TiDE": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/TiDE_infer.tar",
+    "PatchTST": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/PatchTST_infer.tar",
+    "DLinear_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/DLinear_ad_infer.tar",
+    "AutoEncoder_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/AutoEncoder_ad_infer.tar",
+    "Nonstationary_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/Nonstationary_ad_infer.tar",
+    "PatchTST_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/PatchTST_ad_infer.tar",
+    "TimesNet_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/TimesNet_ad_infer.tar",
+    "TimesNet_cls": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/TimesNet_cls_infer.tar",
 }
 
 

+ 88 - 0
paddlex/inference/predictors/ts.py

@@ -0,0 +1,88 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ...utils.func_register import FuncRegister
+from ...modules.ts_forecast.model_list import MODELS
+from ..components import *
+from ..results import TSResult
+from ..utils.process_hook import batchable_method
+from .base import BasicPredictor
+
+
+class TSPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    def _check_args(self, kwargs):
+        pass
+
+    def _build_components(self):
+        preprocess = self._build_preprocess()
+        predictor = TSPPPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        postprocess = self._build_postprocess()
+        return {**preprocess, "predictor": predictor, **postprocess}
+
+    def _build_preprocess(self):
+        if not self.config.get("info_params", None):
+            raise Exception("info_params is not found in config file")
+
+        ops = {}
+        ops["ReadTS"] = ReadTS()
+        ops["TSCutOff"] = TSCutOff(self.config["size"])
+
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            ops["TSNormalize"] = TSNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+
+        ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+
+        if self.config.get("time_feat", None):
+            ops["TimeFeature"] = TimeFeature(
+                self.config["info_params"],
+                self.config["size"],
+                self.config["holiday"],
+            )
+        ops["TStoArray"] = TStoArray(self.config["input_data"])
+        return ops
+
+    def _build_postprocess(self):
+        if not self.config.get("info_params", None):
+            raise Exception("info_params is not found in config file")
+
+        ops = {}
+        ops["ArraytoTS"] = ArraytoTS(self.config["info_params"])
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            ops["TSDeNormalize"] = TSDeNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+        return ops
+
+    @batchable_method
+    def _pack_res(self, data):
+        return {
+            "result": TSResult({"ts_path": data["ts_path"], "forecast": data["pred"]})
+        }
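
Read end to end, the forecasting pipeline assembled here is ReadTS → TSCutOff → (TSNormalize) → BuildTSDataset → (TimeFeature) → TStoArray → TSPPPredictor → ArraytoTS → (TSDeNormalize); the parenthesized components are only added when the exported config enables `scale` or `time_feat`. `TSFcPredictor` below builds the same chain, while `TSClsPredictor` swaps TSCutOff/TimeFeature for BuildPadMask and has no postprocess stage. Note that this `ts.py` module is not imported in `predictors/__init__.py`; only `ts_fc` and `ts_cls` are registered there.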

+ 65 - 0
paddlex/inference/predictors/ts_cls.py

@@ -0,0 +1,65 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from ...modules.ts_classification.model_list import MODELS
+from ..components import *
+from ..results import TSClsResult
+from ..utils.process_hook import batchable_method
+from .base import BasicPredictor
+
+
+class TSClsPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    def _check_args(self, kwargs):
+        pass
+
+    def _build_components(self):
+        preprocess = self._build_preprocess()
+        predictor = TSPPPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        return {**preprocess, "predictor": predictor}
+
+    def _build_preprocess(self):
+        if not self.config.get("info_params", None):
+            raise Exception("info_params is not found in config file")
+
+        ops = {}
+        ops["ReadTS"] = ReadTS()
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            ops["TSNormalize"] = TSNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+
+        ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+        ops["BuildPadMask"] = BuildPadMask(self.config["input_data"])
+        ops["TStoArray"] = TStoArray(self.config["input_data"])
+
+        return ops
+
+    @batchable_method
+    def _pack_res(self, data):
+        return {
+            "result": TSClsResult(
+                {"ts_path": data["ts_path"], "forecast": data["pred"]}
+            )
+        }

+ 87 - 0
paddlex/inference/predictors/ts_fc.py

@@ -0,0 +1,87 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ...modules.ts_forecast.model_list import MODELS
+from ..components import *
+from ..results import TSFcResult
+from ..utils.process_hook import batchable_method
+from .base import BasicPredictor
+
+
+class TSFcPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    def _check_args(self, kwargs):
+        pass
+
+    def _build_components(self):
+        preprocess = self._build_preprocess()
+        predictor = TSPPPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        postprocess = self._build_postprocess()
+        return {**preprocess, "predictor": predictor, **postprocess}
+
+    def _build_preprocess(self):
+        if not self.config.get("info_params", None):
+            raise Exception("info_params is not found in config file")
+
+        ops = {}
+        ops["ReadTS"] = ReadTS()
+        ops["TSCutOff"] = TSCutOff(self.config["size"])
+
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            ops["TSNormalize"] = TSNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+
+        ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+
+        if self.config.get("time_feat", None):
+            ops["TimeFeature"] = TimeFeature(
+                self.config["info_params"],
+                self.config["size"],
+                self.config["holiday"],
+            )
+        ops["TStoArray"] = TStoArray(self.config["input_data"])
+        return ops
+
+    def _build_postprocess(self):
+        if not self.config.get("info_params", None):
+            raise Exception("info_params is not found in config file")
+
+        ops = {}
+        ops["ArraytoTS"] = ArraytoTS(self.config["info_params"])
+        if self.config.get("scale", None):
+            scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+            if not os.path.exists(scaler_file_path):
+                raise Exception(f"Cannot find scaler file: {scaler_file_path}")
+            ops["TSDeNormalize"] = TSDeNormalize(
+                scaler_file_path, self.config["info_params"]
+            )
+        return ops
+
+    @batchable_method
+    def _pack_res(self, data):
+        return {
+            "result": TSFcResult({"ts_path": data["ts_path"], "forecast": data["pred"]})
+        }

+ 1 - 0
paddlex/inference/results/__init__.py

@@ -21,3 +21,4 @@ from .ocr import OCRResult
 from .det import DetResult
 from .seg import SegResult
 from .instance_seg import InstanceSegResult
+from .ts import TSFcResult, TSClsResult

+ 1 - 2
paddlex/inference/results/det.py

@@ -84,7 +84,6 @@ class DetResult(BaseResult):
 
     def __init__(self, data):
         super().__init__(data)
-        self.data = data
         # We use pillow backend to save both numpy arrays and PIL Image objects
         self._img_reader.set_backend("pillow")
         self._img_writer.set_backend("pillow")
@@ -93,7 +92,7 @@ class DetResult(BaseResult):
         """apply"""
         boxes = self["boxes"]
         img_path = self["img_path"]
-        labels = self.data["labels"]
+        labels = self["labels"]
         file_name = os.path.basename(img_path)
 
         image = self._img_reader.read(img_path)

+ 1 - 2
paddlex/inference/results/instance_seg.py

@@ -65,7 +65,6 @@ class InstanceSegResult(BaseResult):
 
     def __init__(self, data):
         super().__init__(data)
-        self.data = data
         # We use pillow backend to save both numpy arrays and PIL Image objects
         self._img_reader.set_backend("pillow")
         self._img_writer.set_backend("pillow")
@@ -75,7 +74,7 @@ class InstanceSegResult(BaseResult):
         boxes = self["boxes"]
         masks = self["masks"]
         img_path = self["img_path"]
-        labels = self.data["labels"]
+        labels = self["labels"]
         file_name = os.path.basename(img_path)
 
         image = self._img_reader.read(img_path)

+ 60 - 0
paddlex/inference/results/ts.py

@@ -0,0 +1,60 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import numpy as np
+import pandas as pd
+
+from ..utils.io import TSWriter
+from .base import BaseResult
+
+
+class TSFcResult(BaseResult):
+
+    def __init__(self, data):
+        super().__init__(data)
+        self._writer = TSWriter(backend="pandas")
+
+    def save_to_csv(self, save_path):
+        """write ts"""
+        if not save_path.endswith(".csv"):
+            save_path = Path(save_path) / f"{Path(self['ts_path']).stem}.csv"
+        self._writer.write(save_path, self["forecast"])
+
+
+class TSClsResult(BaseResult):
+
+    def __init__(self, data):
+        super().__init__(
+            {"ts_path": data["ts_path"], "classification": self.process_data(data)}
+        )
+        self._writer = TSWriter(backend="pandas")
+
+    def process_data(self, data):
+        """apply"""
+        pred_ts = data["forecast"][0]
+        pred_ts -= np.max(pred_ts, axis=-1, keepdims=True)
+        pred_ts = np.exp(pred_ts) / np.sum(np.exp(pred_ts), axis=-1, keepdims=True)
+        classid = np.argmax(pred_ts, axis=-1)
+        pred_score = pred_ts[classid]
+        result = {"classid": [classid], "score": [pred_score]}
+        result = pd.DataFrame.from_dict(result)
+        result.index.name = "sample"
+        return result
+
+    def save_to_csv(self, save_path):
+        """write ts"""
+        if not save_path.endswith(".csv"):
+            save_path = Path(save_path) / f"{Path(self['ts_path']).stem}.csv"
+        self._writer.write(save_path, self["classification"])
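
`process_data` applies a numerically stable softmax to the raw class scores before taking the argmax. The same computation on made-up logits:

```python
import numpy as np

logits = np.array([1.0, 3.0, 0.5])
logits -= np.max(logits, axis=-1, keepdims=True)  # stabilize before exp
probs = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
classid = np.argmax(probs, axis=-1)
print(classid, probs[classid])  # 1 0.8214...
```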

+ 2 - 9
paddlex/inference/utils/io/__init__.py

@@ -13,12 +13,5 @@
 # limitations under the License.
 
 
-from .readers import ImageReader, VideoReader, ReaderType
-from .writers import (
-    ImageWriter,
-    TextWriter,
-    JsonWriter,
-    WriterType,
-    HtmlWriter,
-    XlsxWriter,
-)
+from .readers import ReaderType, ImageReader, VideoReader, TSReader
+from .writers import (
+    WriterType,
+    ImageWriter,
+    TextWriter,
+    JsonWriter,
+    TSWriter,
+    HtmlWriter,
+    XlsxWriter,
+)

+ 44 - 1
paddlex/inference/utils/io/readers.py

@@ -17,8 +17,9 @@ import enum
 import itertools
 import cv2
 from PIL import Image, ImageOps
+import pandas as pd
 
-__all__ = ["ImageReader", "VideoReader", "ReaderType"]
+__all__ = ["ReaderType", "ImageReader", "VideoReader", "TSReader"]
 
 
 class ReaderType(enum.Enum):
@@ -27,6 +28,8 @@ class ReaderType(enum.Enum):
     IMAGE = 1
     GENERATIVE = 2
     POINT_CLOUD = 3
+    JSON = 4
+    TS = 5
 
 
 class _BaseReader(object):
@@ -236,3 +239,43 @@ class OpenCVVideoReaderBackend(_VideoReaderBackend):
         if self._cap is not None:
             self._cap_release()
             self._cap = None
+
+
+class TSReader(_BaseReader):
+    """TSReader"""
+
+    def __init__(self, backend="pandas", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def read(self, in_path):
+        """read the image file from path"""
+        arr = self._backend.read_file(in_path)
+        return arr
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "pandas":
+            return PandasTSReaderBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return ReaderType.TS
+
+
+class _TSReaderBackend(_BaseReaderBackend):
+    """_TSReaderBackend"""
+
+    pass
+
+
+class PandasTSReaderBackend(_TSReaderBackend):
+    """PandasTSReaderBackend"""
+
+    def __init__(self):
+        super().__init__()
+
+    def read_file(self, in_path):
+        """read image file from path by OpenCV"""
+        return pd.read_csv(in_path)

+ 47 - 8
paddlex/inference/utils/io/writers.py

@@ -21,16 +21,11 @@ from pathlib import Path
 import cv2
 import numpy as np
 from PIL import Image
+import pandas as pd
 from .tablepyxl import document_to_xl
 
-__all__ = [
-    "ImageWriter",
-    "TextWriter",
-    "JsonWriter",
-    "WriterType",
-    "HtmlWriter",
-    "XlsxWriter",
-]
+
+__all__ = ["WriterType", "ImageWriter", "TextWriter", "JsonWriter", "TSWriter", "HtmlWriter", "XlsxWriter"]
 
 
 class WriterType(enum.Enum):
@@ -42,6 +37,7 @@ class WriterType(enum.Enum):
     JSON = 4
     HTML = 5
     XLSX = 6
+    TS = 7
 
 
 class _BaseWriter(object):
@@ -292,3 +288,46 @@ class UJsonWriterBackend(_BaseJsonWriterBackend):
     # TODO
     def _write_obj(self, out_path, obj, **bk_args):
         raise NotImplementedError
+
+
+class TSWriter(_BaseWriter):
+    """TSWriter"""
+
+    def __init__(self, backend="pandas", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj):
+        """write"""
+        return self._backend.write_obj(out_path, obj)
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "pandas":
+            return PandasTSWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.TS
+
+
+class _TSWriterBackend(_BaseWriterBackend):
+    """_TSWriterBackend"""
+
+    pass
+
+
+class PandasTSWriterBackend(_TSWriterBackend):
+    """PILImageWriterBackend"""
+
+    def __init__(self):
+        super().__init__()
+
+    def _write_obj(self, out_path, obj):
+        """write image object by PIL"""
+        if isinstance(obj, pd.DataFrame):
+            ts = obj
+        else:
+            raise TypeError("Unsupported object type")
+        return ts.to_csv(out_path)
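
A hypothetical round trip through the new pandas-backed IO helpers (the paths are made up; `TSWriter` raises `TypeError` for anything that is not a DataFrame):

```python
import pandas as pd
from paddlex.inference.utils.io import TSReader, TSWriter

reader = TSReader(backend="pandas")
writer = TSWriter(backend="pandas")

df = reader.read("input.csv")   # pandas.read_csv under the hood
writer.write("output.csv", df)  # DataFrame.to_csv under the hood
```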