Procházet zdrojové kódy

update ts export infer

Sunflower7788 před 1 rokem
rodič
revize
2e8c9b6b9d
33 změnil soubory, kde provedl 2141 přidání a 175 odebrání
  1. 2 2
      paddlex/modules/base/predictor/io/__init__.py
  2. 43 1
      paddlex/modules/base/predictor/io/readers.py
  3. 46 1
      paddlex/modules/base/predictor/io/writers.py
  4. 1 0
      paddlex/modules/base/predictor/transforms/__init__.py
  5. 500 0
      paddlex/modules/base/predictor/transforms/ts_common.py
  6. 424 0
      paddlex/modules/base/predictor/transforms/ts_functions.py
  7. 1 0
      paddlex/modules/base/predictor/utils/paddle_inference_predictor.py
  8. 2 8
      paddlex/modules/ts_anomaly_detection/predictor/__init__.py
  9. 27 0
      paddlex/modules/ts_anomaly_detection/predictor/keys.py
  10. 97 0
      paddlex/modules/ts_anomaly_detection/predictor/predictor.py
  11. 73 0
      paddlex/modules/ts_anomaly_detection/predictor/transforms.py
  12. 87 0
      paddlex/modules/ts_anomaly_detection/predictor/utils.py
  13. 59 5
      paddlex/modules/ts_anomaly_detection/trainer.py
  14. 2 9
      paddlex/modules/ts_classification/predictor/__init__.py
  15. 27 0
      paddlex/modules/ts_classification/predictor/keys.py
  16. 103 0
      paddlex/modules/ts_classification/predictor/predictor.py
  17. 79 0
      paddlex/modules/ts_classification/predictor/transforms.py
  18. 66 0
      paddlex/modules/ts_classification/predictor/utils.py
  19. 58 5
      paddlex/modules/ts_classification/trainer.py
  20. 0 121
      paddlex/modules/ts_forecast/predictor.py
  21. 17 0
      paddlex/modules/ts_forecast/predictor/__init__.py
  22. 27 0
      paddlex/modules/ts_forecast/predictor/keys.py
  23. 108 0
      paddlex/modules/ts_forecast/predictor/predictor.py
  24. 73 0
      paddlex/modules/ts_forecast/predictor/transforms.py
  25. 95 0
      paddlex/modules/ts_forecast/predictor/utils.py
  26. 58 5
      paddlex/modules/ts_forecast/trainer.py
  27. 5 5
      paddlex/repo_apis/PaddleTS_api/ts_ad/register.py
  28. 9 1
      paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py
  29. 26 2
      paddlex/repo_apis/PaddleTS_api/ts_base/model.py
  30. 9 1
      paddlex/repo_apis/PaddleTS_api/ts_base/runner.py
  31. 1 1
      paddlex/repo_apis/PaddleTS_api/ts_cls/register.py
  32. 9 1
      paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py
  33. 7 7
      paddlex/repo_apis/PaddleTS_api/ts_fc/register.py

+ 2 - 2
paddlex/modules/base/predictor/io/__init__.py

@@ -13,5 +13,5 @@
 # limitations under the License.
 
 
-from .readers import ImageReader, VideoReader, ReaderType
-from .writers import ImageWriter, TextWriter, WriterType
+from .readers import ImageReader, VideoReader, ReaderType, TSReader
+from .writers import ImageWriter, TextWriter, WriterType, TSWriter

+ 43 - 1
paddlex/modules/base/predictor/io/readers.py

@@ -16,9 +16,10 @@
 import enum
 import itertools
 import cv2
+import pandas as pd
 from PIL import Image, ImageOps
 
-__all__ = ["ImageReader", "VideoReader", "ReaderType"]
+__all__ = ["ImageReader", "VideoReader", "ReaderType", "TSReader"]
 
 
 class ReaderType(enum.Enum):
@@ -27,6 +28,7 @@ class ReaderType(enum.Enum):
     IMAGE = 1
     GENERATIVE = 2
     POINT_CLOUD = 3
+    TS = 4
 
 
 class _BaseReader(object):
@@ -231,3 +233,43 @@ class OpenCVVideoReaderBackend(_VideoReaderBackend):
         if self._cap is not None:
             self._cap_release()
             self._cap = None
+
+
+class TSReader(_BaseReader):
+    """TSReader"""
+
+    def __init__(self, backend="pandas", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def read(self, in_path):
+        """read the image file from path"""
+        arr = self._backend.read_file(in_path)
+        return arr
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "pandas":
+            return PandasTSReaderBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return ReaderType.TS
+
+
+class _TSReaderBackend(_BaseReaderBackend):
+    """_TSReaderBackend"""
+
+    pass
+
+
+class PandasTSReaderBackend(_TSReaderBackend):
+    """PandasTSReaderBackend"""
+
+    def __init__(self):
+        super().__init__()
+
+    def read_file(self, in_path):
+        """read image file from path by OpenCV"""
+        return pd.read_csv(in_path)

+ 46 - 1
paddlex/modules/base/predictor/io/writers.py

@@ -17,10 +17,11 @@ import os
 import enum
 
 import cv2
+import pandas as pd
 import numpy as np
 from PIL import Image
 
-__all__ = ["ImageWriter", "TextWriter", "WriterType"]
+__all__ = ["ImageWriter", "TextWriter", "WriterType", "TSWriter"]
 
 
 class WriterType(enum.Enum):
@@ -29,6 +30,7 @@ class WriterType(enum.Enum):
     IMAGE = 1
     VIDEO = 2
     TEXT = 3
+    TS = 4
 
 
 class _BaseWriter(object):
@@ -175,3 +177,46 @@ class PILImageWriterBackend(_ImageWriterBackend):
         else:
             raise TypeError("Unsupported object type")
         return img.save(out_path, format=self.format)
+
+
+class TSWriter(_BaseWriter):
+    """TSWriter"""
+
+    def __init__(self, backend="pandas", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj):
+        """write"""
+        return self._backend.write_obj(out_path, obj)
+
+    def _init_backend(self, bk_type, bk_args):
+        """init backend"""
+        if bk_type == "pandas":
+            return PandasTSWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.TS
+
+
+class _TSWriterBackend(_BaseWriterBackend):
+    """_TSWriterBackend"""
+
+    pass
+
+
+class PandasTSWriterBackend(_TSWriterBackend):
+    """PILImageWriterBackend"""
+
+    def __init__(self):
+        super().__init__()
+
+    def _write_obj(self, out_path, obj):
+        """write image object by PIL"""
+        if isinstance(obj, pd.DataFrame):
+            ts = obj
+        else:
+            raise TypeError("Unsupported object type")
+        return ts.to_csv(out_path)

+ 1 - 0
paddlex/modules/base/predictor/transforms/__init__.py

@@ -14,3 +14,4 @@
 
 
 from . import image_common
+from . import ts_common

+ 500 - 0
paddlex/modules/base/predictor/transforms/ts_common.py

@@ -0,0 +1,500 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+import joblib
+import numpy as np
+import pandas as pd
+
+from .....utils.download import download
+from .....utils.cache import CACHE_DIR
+from ..transform import BaseTransform
+from ..io.readers import TSReader
+from ..io.writers import TSWriter
+from .ts_functions import load_from_dataframe, time_feature
+
+
+__all__ = [
+    "ReadTS",
+    "BuildTSDataset",
+    "TSCutOff",
+    "TSNormalize",
+    "TimeFeature",
+    "TStoArray",
+    "BuildPadMask",
+]
+
+
+class ReadTS(BaseTransform):
+    """Load image from the file."""
+
+    def __init__(self):
+        """
+        Initialize the instance.
+
+        Args:
+            format (str, optional): Target color format to convert the image to.
+                Choices are 'BGR', 'RGB', and 'GRAY'. Default: 'BGR'.
+        """
+        super().__init__()
+        self._reader = TSReader(backend="pandas")
+        self._writer = TSWriter(backend="pandas")
+
+    def apply(self, data):
+        """apply"""
+        if "ts" in data:
+            ts = data["ts"]
+            ts_path = (Path(CACHE_DIR) / "predict_input" / "tmp_ts.csv").as_posix()
+            self._writer.write(ts_path, ts)
+            data["input_path"] = ts_path
+            data["original_ts"] = ts
+            return data
+
+        elif "input_path" not in data:
+            raise KeyError(f"Key {repr('input_path')} is required, but not found.")
+
+        ts_path = data["input_path"]
+        # XXX: auto download for url
+        ts_path = self._download_from_url(ts_path)
+        blob = self._reader.read(ts_path)
+
+        data["input_path"] = ts_path
+        data["ts"] = blob
+        data["original_ts"] = blob
+        return data
+
+    def _download_from_url(self, in_path):
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # input_path: Path of the image.
+        return [["input_path"], ["ts"]]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        # original_image: Original image in hw or hwc format.
+        # original_image_size: Width and height of the original image.
+        return ["ts", "original_ts"]
+
+
+class TSCutOff(BaseTransform):
+    """Reorder the dimensions of the image from HWC to CHW."""
+
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def apply(self, data):
+        df = data["ts"].copy()
+        skip_len = self.size.get("skip_chunk_len", 0)
+        if len(df) < self.size["in_chunk_len"] + skip_len:
+            raise ValueError(
+                f"The length of the input data is {len(df)}, but it should be at least {self.size['in_chunk_len'] + self.size['skip_chunk_len']} for training."
+            )
+
+        df = df[-(self.size["in_chunk_len"] + skip_len) :]
+        data["ts"] = df
+        data["original_ts"] = df
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hwc format.
+        return ["ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in chw format.
+        return ["ts"]
+
+
+class TSNormalize(BaseTransform):
+    """Flip the image vertically or horizontally."""
+
+    def __init__(self, scale_path, params_info):
+        """
+        Initialize the instance.
+
+        Args:
+            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
+                flipping. Default: 'H'.
+        """
+        super().__init__()
+        self.scaler = joblib.load(scale_path)
+        self.params_info = params_info
+
+    def apply(self, data):
+        """apply"""
+        df = data["ts"].copy()
+        if self.params_info.get("target_cols", None) is not None:
+            df[self.params_info["target_cols"]] = self.scaler.transform(
+                df[self.params_info["target_cols"]]
+            )
+        if self.params_info.get("feature_cols", None) is not None:
+            df[self.params_info["feature_cols"]] = self.scaler.transform(
+                df[self.params_info["feature_cols"]]
+            )
+
+        data["ts"] = df
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+
+class TSDeNormalize(BaseTransform):
+    """Flip the image vertically or horizontally."""
+
+    def __init__(self, scale_path, params_info):
+        """
+        Initialize the instance.
+
+        Args:
+            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
+                flipping. Default: 'H'.
+        """
+        super().__init__()
+        self.scaler = joblib.load(scale_path)
+        self.params_info = params_info
+
+    def apply(self, data):
+        """apply"""
+        future_target = data["pred_ts"].copy()
+        scale_cols = future_target.columns.values.tolist()
+        future_target[scale_cols] = self.scaler.inverse_transform(
+            future_target[scale_cols]
+        )
+        data["pred_ts"] = future_target
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["pred_ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["pred_ts"]
+
+
+class BuildTSDataset(BaseTransform):
+    """bulid the ts."""
+
+    def __init__(self, params_info):
+        """
+        Initialize the instance.
+
+        Args:
+            mode (str, optional): 'H' for horizontal flipping and 'V' for vertical
+                flipping. Default: 'H'.
+        """
+        super().__init__()
+        self.params_info = params_info
+
+    def apply(self, data):
+        """apply"""
+        df = data["ts"].copy()
+        tsdata = load_from_dataframe(df, **self.params_info)
+        data["ts"] = tsdata
+        data["original_ts"] = tsdata
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+
+class TimeFeature(BaseTransform):
+    """Normalize the image."""
+
+    def __init__(self, params_info, size, holiday=False):
+        """
+        Initialize the instance.
+        """
+        super().__init__()
+        self.freq = params_info["freq"]
+        self.size = size
+        self.holiday = holiday
+
+    def apply(self, data):
+        """apply"""
+        ts = data["ts"].copy()
+        if not self.holiday:
+            ts = time_feature(
+                ts,
+                self.freq,
+                ["hourofday", "dayofmonth", "dayofweek", "dayofyear"],
+                self.size["out_chunk_len"],
+            )
+        else:
+            ts = time_feature(
+                ts,
+                self.freq,
+                [
+                    "minuteofhour",
+                    "hourofday",
+                    "dayofmonth",
+                    "dayofweek",
+                    "dayofyear",
+                    "monthofyear",
+                    "weekofyear",
+                    "holidays",
+                ],
+                self.size["out_chunk_len"],
+            )
+        data["ts"] = ts
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+
+class BuildPadMask(BaseTransform):
+
+    def __init__(self, input_data):
+
+        super().__init__()
+        self.input_data = input_data
+
+    def apply(self, data):
+        """apply"""
+        df = data["ts"].copy()
+
+        if "features" in self.input_data:
+            df["features"] = df["past_target"]
+
+        if "pad_mask" in self.input_data:
+            target_dim = len(df["features"])
+            max_length = self.input_data["pad_mask"][-1]
+            if max_length > 0:
+                ones = np.ones(max_length, dtype=np.int32)
+                if max_length != target_dim:
+                    target_ndarray = np.array(df["features"]).astype(np.float32)
+                    target_ndarray_final = np.zeros(
+                        [max_length, target_dim], dtype=np.int32
+                    )
+                    end = min(target_dim, max_length)
+                    target_ndarray_final[:end, :] = target_ndarray
+                    df["features"] = target_ndarray_final
+                    ones[end:] = 0.0
+                    df["pad_mask"] = ones
+                else:
+                    df["pad_mask"] = ones
+        data["ts"] = df
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+
+class TStoArray(BaseTransform):
+
+    def __init__(self, input_data):
+
+        super().__init__()
+        self.input_data = input_data
+
+    def apply(self, data):
+        """apply"""
+        df = data["ts"].copy()
+        ts_list = []
+        input_name = list(self.input_data.keys())
+        input_name.sort()
+        for key in input_name:
+            ts_list.append(np.array(df[key]).astype("float32"))
+
+        data["ts"] = ts_list
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["ts"]
+
+
+class ArraytoTS(BaseTransform):
+
+    def __init__(self, info_params):
+
+        super().__init__()
+        self.info_params = info_params
+
+    def apply(self, data):
+        """apply"""
+        output_data = data["pred_ts"].copy()
+        if data["original_ts"].get("past_target", None) is not None:
+            ts = data["original_ts"]["past_target"]
+        elif data["original_ts"].get("observed_cov_numeric", None) is not None:
+            ts = data["original_ts"]["observed_cov_numeric"]
+        elif data["original_ts"].get("known_cov_numeric", None) is not None:
+            ts = data["original_ts"]["known_cov_numeric"]
+        elif data["original_ts"].get("static_cov_numeric", None) is not None:
+            ts = data["original_ts"]["static_cov_numeric"]
+        else:
+            raise ValueError("No value in original_ts")
+
+        column_name = (
+            self.info_params["target_cols"]
+            if "target_cols" in self.info_params
+            else self.info_params["feature_cols"]
+        )
+        if isinstance(self.info_params["freq"], str):
+            past_target_index = ts.index
+            if past_target_index.freq is None:
+                past_target_index.freq = pd.infer_freq(ts.index)
+            future_target_index = pd.date_range(
+                past_target_index[-1] + past_target_index.freq,
+                periods=output_data.shape[0],
+                freq=self.info_params["freq"],
+                name=self.info_params["time_col"],
+            )
+        elif isinstance(self.info_params["freq"], int):
+            start_idx = max(ts.index) + 1
+            stop_idx = start_idx + output_data.shape[0]
+            future_target_index = pd.RangeIndex(
+                start=start_idx,
+                stop=stop_idx,
+                step=self.info_params["freq"],
+                name=self.info_params["time_col"],
+            )
+
+        future_target = pd.DataFrame(
+            np.reshape(output_data, newshape=[output_data.shape[0], -1]),
+            index=future_target_index,
+            columns=column_name,
+        )
+        data["pred_ts"] = future_target
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["pred_ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["pred_ts"]
+
+
+class GetAnomaly(BaseTransform):
+
+    def __init__(self, model_threshold, info_params):
+
+        super().__init__()
+        self.model_threshold = model_threshold
+        self.info_params = info_params
+
+    def apply(self, data):
+        """apply"""
+        output_data = data["pred_ts"].copy()
+        if data["original_ts"].get("past_target", None) is not None:
+            ts = data["original_ts"]["past_target"]
+        elif data["original_ts"].get("observed_cov_numeric", None) is not None:
+            ts = data["original_ts"]["observed_cov_numeric"]
+        elif data["original_ts"].get("known_cov_numeric", None) is not None:
+            ts = data["original_ts"]["known_cov_numeric"]
+        elif data["original_ts"].get("static_cov_numeric", None) is not None:
+            ts = data["original_ts"]["static_cov_numeric"]
+        else:
+            raise ValueError("No value in original_ts")
+        column_name = (
+            self.info_params["target_cols"]
+            if "target_cols" in self.info_params
+            else self.info_params["feature_cols"]
+        )
+
+        anomaly_score = np.mean(np.square(output_data - np.array(ts)), axis=-1)
+        anomaly_label = (anomaly_score >= self.model_threshold) + 0
+
+        past_target_index = ts.index
+        past_target_index.name = self.info_params["time_col"]
+        anomaly_label = pd.DataFrame(
+            np.reshape(anomaly_label, newshape=[output_data.shape[0], -1]),
+            index=past_target_index,
+            columns=["label"],
+        )
+        data["pred_ts"] = anomaly_label
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # image: Image in hw or hwc format.
+        return ["pred_ts"]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        # image: Image in hw or hwc format.
+        return ["pred_ts"]

+ 424 - 0
paddlex/modules/base/predictor/transforms/ts_functions.py

@@ -0,0 +1,424 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict
+import numpy as np
+import pandas as pd
+import joblib
+import chinese_calendar
+from pandas.tseries.offsets import DateOffset, Easter, Day
+from pandas.tseries import holiday as hd
+from sklearn.preprocessing import StandardScaler
+
+
+MAX_WINDOW = 183 + 17
+EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
+NewYearsDay = hd.Holiday("New Years Day", month=1, day=1)
+SuperBowl = hd.Holiday("Superbowl", month=2, day=1, offset=DateOffset(weekday=hd.SU(1)))
+MothersDay = hd.Holiday(
+    "Mothers Day", month=5, day=1, offset=DateOffset(weekday=hd.SU(2))
+)
+IndependenceDay = hd.Holiday("Independence Day", month=7, day=4)
+ChristmasEve = hd.Holiday("Christmas", month=12, day=24)
+ChristmasDay = hd.Holiday("Christmas", month=12, day=25)
+NewYearsEve = hd.Holiday("New Years Eve", month=12, day=31)
+BlackFriday = hd.Holiday(
+    "Black Friday",
+    month=11,
+    day=1,
+    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(1)],
+)
+CyberMonday = hd.Holiday(
+    "Cyber Monday",
+    month=11,
+    day=1,
+    offset=[pd.DateOffset(weekday=hd.TH(4)), Day(4)],
+)
+
+HOLIDAYS = [
+    hd.EasterMonday,
+    hd.GoodFriday,
+    hd.USColumbusDay,
+    hd.USLaborDay,
+    hd.USMartinLutherKingJr,
+    hd.USMemorialDay,
+    hd.USPresidentsDay,
+    hd.USThanksgivingDay,
+    EasterSunday,
+    NewYearsDay,
+    SuperBowl,
+    MothersDay,
+    IndependenceDay,
+    ChristmasEve,
+    ChristmasDay,
+    NewYearsEve,
+    BlackFriday,
+    CyberMonday,
+]
+
+
+def _cal_year(
+    x: np.datetime64,
+):
+    return x.year
+
+
+def _cal_month(
+    x: np.datetime64,
+):
+    return x.month
+
+
+def _cal_day(
+    x: np.datetime64,
+):
+    return x.day
+
+
+def _cal_hour(
+    x: np.datetime64,
+):
+    return x.hour
+
+
+def _cal_weekday(
+    x: np.datetime64,
+):
+    return x.dayofweek
+
+
+def _cal_quarter(
+    x: np.datetime64,
+):
+    return x.quarter
+
+
+def _cal_hourofday(
+    x: np.datetime64,
+):
+    return x.hour / 23.0 - 0.5
+
+
+def _cal_dayofweek(
+    x: np.datetime64,
+):
+    return x.dayofweek / 6.0 - 0.5
+
+
+def _cal_dayofmonth(
+    x: np.datetime64,
+):
+    return x.day / 30.0 - 0.5
+
+
+def _cal_dayofyear(
+    x: np.datetime64,
+):
+    return x.dayofyear / 364.0 - 0.5
+
+
+def _cal_weekofyear(
+    x: np.datetime64,
+):
+    return x.weekofyear / 51.0 - 0.5
+
+
+def _cal_holiday(
+    x: np.datetime64,
+):
+    return float(chinese_calendar.is_holiday(x))
+
+
+def _cal_workday(
+    x: np.datetime64,
+):
+    return float(chinese_calendar.is_workday(x))
+
+
+def _cal_minuteofhour(
+    x: np.datetime64,
+):
+    return x.minute / 59 - 0.5
+
+
+def _cal_monthofyear(
+    x: np.datetime64,
+):
+    return x.month / 11.0 - 0.5
+
+
+CAL_DATE_METHOD = {
+    "year": _cal_year,
+    "month": _cal_month,
+    "day": _cal_day,
+    "hour": _cal_hour,
+    "weekday": _cal_weekday,
+    "quarter": _cal_quarter,
+    "minuteofhour": _cal_minuteofhour,
+    "monthofyear": _cal_monthofyear,
+    "hourofday": _cal_hourofday,
+    "dayofweek": _cal_dayofweek,
+    "dayofmonth": _cal_dayofmonth,
+    "dayofyear": _cal_dayofyear,
+    "weekofyear": _cal_weekofyear,
+    "is_holiday": _cal_holiday,
+    "is_workday": _cal_workday,
+}
+
+
+def load_from_one_dataframe(
+    data: Union[pd.DataFrame, pd.Series],
+    time_col: Optional[str] = None,
+    value_cols: Optional[Union[List[str], str]] = None,
+    freq: Optional[Union[str, int]] = None,
+    drop_tail_nan: bool = False,
+    dtype: Optional[Union[type, Dict[str, type]]] = None,
+):
+
+    series_data = None
+    if value_cols is None:
+        if isinstance(data, pd.Series):
+            series_data = data.copy()
+        else:
+            series_data = data.loc[:, data.columns != time_col].copy()
+    else:
+        series_data = data.loc[:, value_cols].copy()
+
+    if time_col:
+        if time_col not in data.columns:
+            raise ValueError(
+                "The time column: {} doesn't exist in the `data`!".format(time_col)
+            )
+        time_col_vals = data.loc[:, time_col]
+    else:
+        time_col_vals = data.index
+
+    if np.issubdtype(time_col_vals.dtype, np.integer) and isinstance(freq, str):
+        time_col_vals = time_col_vals.astype(str)
+
+    if np.issubdtype(time_col_vals.dtype, np.integer):
+        if freq:
+            if not isinstance(freq, int) or freq < 1:
+                raise ValueError(
+                    "The type of `freq` should be `int` when the type of `time_col` is `RangeIndex`."
+                )
+        else:
+            freq = 1
+        start_idx, stop_idx = min(time_col_vals), max(time_col_vals) + freq
+        if (stop_idx - start_idx) / freq != len(data):
+            raise ValueError("The number of rows doesn't match with the RangeIndex!")
+        time_index = pd.RangeIndex(start=start_idx, stop=stop_idx, step=freq)
+    elif np.issubdtype(time_col_vals.dtype, np.object_) or np.issubdtype(
+        time_col_vals.dtype, np.datetime64
+    ):
+        time_col_vals = pd.to_datetime(time_col_vals, infer_datetime_format=True)
+        time_index = pd.DatetimeIndex(time_col_vals)
+        if freq:
+            if not isinstance(freq, str):
+                raise ValueError(
+                    "The type of `freq` should be `str` when the type of `time_col` is `DatetimeIndex`."
+                )
+        else:
+            # If freq is not provided and automatic inference fail, throw exception
+            freq = pd.infer_freq(time_index)
+            if freq is None:
+                raise ValueError(
+                    "Failed to infer the `freq`. A valid `freq` is required."
+                )
+            if freq[0] == "-":
+                freq = freq[1:]
+    else:
+        raise ValueError("The type of `time_col` is invalid.")
+    if isinstance(series_data, pd.Series):
+        series_data = series_data.to_frame()
+    series_data.set_index(time_index, inplace=True)
+    series_data.sort_index(inplace=True)
+    return series_data
+
+
+def load_from_dataframe(
+    df: pd.DataFrame,
+    group_id: str = None,
+    time_col: Optional[str] = None,
+    target_cols: Optional[Union[List[str], str]] = None,
+    label_col: Optional[Union[List[str], str]] = None,
+    observed_cov_cols: Optional[Union[List[str], str]] = None,
+    feature_cols: Optional[Union[List[str], str]] = None,
+    known_cov_cols: Optional[Union[List[str], str]] = None,
+    static_cov_cols: Optional[Union[List[str], str]] = None,
+    freq: Optional[Union[str, int]] = None,
+    fill_missing_dates: bool = False,
+    fillna_method: str = "pre",
+    fillna_window_size: int = 10,
+    **kwargs,
+):
+
+    dfs = []  # seperate multiple group
+    if group_id is not None:
+        group_unique = df[group_id].unique()
+        for column in group_unique:
+            dfs.append(df[df[group_id].isin([column])])
+    else:
+        dfs = [df]
+    res = []
+    if label_col:
+        if isinstance(label_col, str) and len(label_col) > 1:
+            raise ValueError("The length of label_col must be 1.")
+        target_cols = label_col
+    if feature_cols:
+        observed_cov_cols = feature_cols
+    for df in dfs:
+        target = None
+        observed_cov = None
+        known_cov = None
+        static_cov = dict()
+        if not any([target_cols, observed_cov_cols, known_cov_cols, static_cov_cols]):
+            target = load_from_one_dataframe(
+                df,
+                time_col,
+                [a for a in df.columns if a != time_col],
+                freq,
+            )
+
+        else:
+            if target_cols:
+                target = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    target_cols,
+                    freq,
+                )
+
+            if observed_cov_cols:
+                observed_cov = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    observed_cov_cols,
+                    freq,
+                )
+
+            if known_cov_cols:
+                known_cov = load_from_one_dataframe(
+                    df,
+                    time_col,
+                    known_cov_cols,
+                    freq,
+                )
+
+            if static_cov_cols:
+                if isinstance(static_cov_cols, str):
+                    static_cov_cols = [static_cov_cols]
+                for col in static_cov_cols:
+                    if col not in df.columns or len(np.unique(df[col])) != 1:
+                        raise ValueError(
+                            "static cov cals data is not in columns or schema is not right!"
+                        )
+                    static_cov[col] = df[col].iloc[0]
+        res.append(
+            {
+                "past_target": target,
+                "observed_cov_numeric": observed_cov,
+                "known_cov_numeric": known_cov,
+                "static_cov_numeric": static_cov,
+            }
+        )
+    return res[0]
+
+
+def _distance_to_holiday(holiday):
+    def _distance_to_day(index):
+        holiday_date = holiday.dates(
+            index - pd.Timedelta(days=MAX_WINDOW),
+            index + pd.Timedelta(days=MAX_WINDOW),
+        )
+        assert (
+            len(holiday_date) != 0
+        ), f"No closest holiday for the date index {index} found."
+        # It sometimes returns two dates if it is exactly half a year after the
+        # holiday. In this case, the smaller distance (182 days) is returned.
+        return float((index - holiday_date[0]).days)
+
+    return _distance_to_day
+
+
+def time_feature(dataset, freq, feature_cols, extend_points, inplace: bool = False):
+    """
+    Transform time column to time features.
+
+    Args:
+        dataset(TSDataset): Dataset to be transformed.
+        inplace(bool): Whether to perform the transformation inplace. default=False
+
+    Returns:
+        TSDataset
+    """
+    new_ts = dataset
+    if not inplace:
+        new_ts = dataset.copy()
+    # Get known_cov
+    kcov = new_ts["known_cov_numeric"]
+    if not kcov:
+        tf_kcov = new_ts["past_target"].index.to_frame()
+    else:
+        tf_kcov = kcov.index.to_frame()
+    time_col = tf_kcov.columns[0]
+    if np.issubdtype(tf_kcov[time_col].dtype, np.integer):
+        raise ValueError(
+            "The time_col can't be the type of numpy.integer, and it must be the type of numpy.datetime64"
+        )
+    if not kcov:
+        freq = freq if freq is not None else pd.infer_freq(tf_kcov[time_col])
+        extend_time = pd.date_range(
+            start=tf_kcov[time_col][-1],
+            freq=freq,
+            periods=extend_points + 1,
+            closed="right",
+            name=time_col,
+        ).to_frame()
+        tf_kcov = pd.concat([tf_kcov, extend_time])
+
+    for k in feature_cols:
+        if k != "holidays":
+            v = tf_kcov[time_col].apply(lambda x: CAL_DATE_METHOD[k](x))
+            v.index = tf_kcov[time_col]
+
+            if new_ts["known_cov_numeric"] is None:
+                new_ts["known_cov_numeric"] = pd.DataFrame(v.rename(k), index=v.index)
+            else:
+                new_ts["known_cov_numeric"][k] = v.rename(k).reindex(
+                    new_ts["known_cov_numeric"].index
+                )
+
+        else:
+            holidays_col = []
+            for i, H in enumerate(HOLIDAYS):
+                v = tf_kcov[time_col].apply(_distance_to_holiday(H))
+                v.index = tf_kcov[time_col]
+                holidays_col.append(k + "_" + str(i))
+                if new_ts["known_cov_numeric"] is None:
+                    new_ts["known_cov_numeric"] = pd.DataFrame(
+                        v.rename(k + "_" + str(i)), index=v.index
+                    )
+                else:
+                    new_ts["known_cov_numeric"][k + "_" + str(i)] = v.rename(k).reindex(
+                        new_ts["known_cov_numeric"].index
+                    )
+
+            scaler = StandardScaler()
+            scaler.fit(new_ts["known_cov_numeric"][holidays_col])
+            new_ts["known_cov_numeric"][holidays_col] = scaler.transform(
+                new_ts["known_cov_numeric"][holidays_col]
+            )
+    return new_ts

+ 1 - 0
paddlex/modules/base/predictor/utils/paddle_inference_predictor.py

@@ -124,6 +124,7 @@ No need to generate again."
 
         # Get input and output handlers
         input_names = predictor.get_input_names()
+        input_names.sort()
         input_handlers = []
         output_handlers = []
         for input_name in input_names:

+ 2 - 8
paddlex/modules/ts_anomaly_detection/predictor.py → paddlex/modules/ts_anomaly_detection/predictor/__init__.py

@@ -13,11 +13,5 @@
 # limitations under the License.
 
 
-from ..ts_forecast import TSFCPredictor
-from .model_list import MODELS
-
-
-class TSADPredictor(TSFCPredictor):
-    """TS Anomaly Detection Model Predictor"""
-
-    entities = MODELS
+from .predictor import TSADPredictor
+from . import transforms

+ 27 - 0
paddlex/modules/ts_anomaly_detection/predictor/keys.py

@@ -0,0 +1,27 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class TSFCKeys(object):
+    """
+    This class defines a set of keys used for communication of Seg predictors
+    and transforms. Both predictors and transforms accept a dict or a list of
+    dicts as input, and they get the objects of their interest from the dict, or
+    put the generated objects into the dict, all based on these keys.
+    """
+
+    # Common keys
+    TS = "ts"
+    TS_PATH = "input_path"
+    PRED = "pred_ts"

+ 97 - 0
paddlex/modules/ts_anomaly_detection/predictor/predictor.py

@@ -0,0 +1,97 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+
+from ....utils import logging
+from ...base.predictor.transforms import ts_common
+from ...base import BasePredictor
+from .keys import TSFCKeys as K
+from . import transforms as T
+from .utils import InnerConfig
+from ..model_list import MODELS
+
+
+class TSADPredictor(BasePredictor):
+    """SegPredictor"""
+
+    entities = MODELS
+
+    def __init__(
+        self,
+        model_name,
+        model_dir,
+        kernel_option,
+        output,
+        pre_transforms=None,
+        post_transforms=None,
+    ):
+        super().__init__(
+            model_name=model_name,
+            model_dir=model_dir,
+            kernel_option=kernel_option,
+            output=output,
+            pre_transforms=pre_transforms,
+            post_transforms=post_transforms,
+        )
+
+    def load_other_src(self):
+        """load the inner config file"""
+        infer_cfg_file_path = os.path.join(self.model_dir, "inference.yml")
+        if not os.path.exists(infer_cfg_file_path):
+            raise FileNotFoundError(f"Cannot find config file: {infer_cfg_file_path}")
+        return InnerConfig(infer_cfg_file_path, self.model_dir)
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [[K.TS], [K.TS_PATH]]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return [K.PRED]
+
+    def _run(self, batch_input):
+        """run"""
+        n = len(batch_input[0][K.TS])
+        input_ = [
+            np.stack([lst[i] for lst in [data[K.TS] for data in batch_input]], axis=0)
+            for i in range(n)
+        ]
+
+        outputs = self._predictor.predict(input_)
+        batch_output = outputs[0]
+        # In-place update
+        for dict_, output in zip(batch_input, batch_output):
+            dict_[K.PRED] = output
+        return batch_input
+
+    def _get_pre_transforms_from_config(self):
+        """_get_pre_transforms_from_config"""
+        logging.info(
+            f"Transformation operators for data preprocessing will be inferred from config file."
+        )
+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, ts_common.ReadTS())
+
+        return pre_transforms
+
+    def _get_post_transforms_from_config(self):
+        """_get_post_transforms_from_config"""
+        post_transforms = self.other_src.post_transforms
+        post_transforms.append(T.SaveTSResults(self.output))
+        return post_transforms

+ 73 - 0
paddlex/modules/ts_anomaly_detection/predictor/transforms.py

@@ -0,0 +1,73 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+from PIL import Image
+
+from ....utils import logging
+from ...base import BaseTransform
+from ...base.predictor.io.writers import TSWriter
+from .keys import TSFCKeys as K
+
+__all__ = ["SaveTSResults"]
+
+
+class SaveTSResults(BaseTransform):
+    """SaveSegResults"""
+
+    def __init__(self, save_dir):
+        super().__init__()
+        self.save_dir = save_dir
+        self._writer = TSWriter(backend="pandas")
+
+    def apply(self, data):
+        """apply"""
+        pred_ts = data[K.PRED]
+        file_name = os.path.basename(data[K.TS_PATH])
+
+        ts_save_path = os.path.join(self.save_dir, file_name)
+        self._write_ts(ts_save_path, pred_ts)
+
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [K.PRED]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return []
+
+    def _write_ts(self, path, ts):
+        """write ts"""
+        if os.path.exists(path):
+            logging.warning(f"{path} already exists. Overwriting it.")
+        self._writer.write(path, ts)
+
+    @staticmethod
+    def _add_suffix(path, suffix):
+        """add suffix"""
+        stem, ext = os.path.splitext(path)
+        return stem + suffix + ext
+
+    @staticmethod
+    def _replace_ext(path, new_ext):
+        """replace ext"""
+        stem, _ = os.path.splitext(path)
+        return stem + new_ext

+ 87 - 0
paddlex/modules/ts_anomaly_detection/predictor/utils.py

@@ -0,0 +1,87 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import codecs
+import yaml
+import os
+
+from ....utils import logging
+from ...base.predictor.transforms import ts_common
+
+
+class InnerConfig(object):
+    """Inner Config"""
+
+    def __init__(self, config_path, model_dir=None):
+        self.inner_cfg = self.load(config_path)
+        self.model_dir = model_dir
+
+    def load(self, config_path):
+        """load config"""
+        with codecs.open(config_path, "r", "utf-8") as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+        return dic
+
+    @property
+    def pre_transforms(self):
+        """read preprocess transforms from  config file"""
+
+        tfs = []
+        if self.inner_cfg.get("info_params", False):
+            tf = ts_common.TSCutOff(self.inner_cfg["size"])
+            tfs.append(tf)
+
+            if self.inner_cfg.get("scale", False):
+                scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+                if not os.path.exists(scaler_file_path):
+                    raise FileNotFoundError(
+                        f"Cannot find scaler file: {scaler_file_path}"
+                    )
+                tf = ts_common.TSNormalize(
+                    scaler_file_path, self.inner_cfg["info_params"]
+                )
+                tfs.append(tf)
+
+            tf = ts_common.BuildTSDataset(self.inner_cfg["info_params"])
+            tfs.append(tf)
+
+            if self.inner_cfg.get("time_feat", False):
+                tf = ts_common.TimeFeature(
+                    self.inner_cfg["info_params"],
+                    self.inner_cfg["size"],
+                    self.inner_cfg["holiday"],
+                )
+                tfs.append(tf)
+            tf = ts_common.TStoArray(self.inner_cfg["input_data"])
+
+            tfs.append(tf)
+        else:
+            raise ValueError("info_params is not found in config file")
+
+        return tfs
+
+    @property
+    def post_transforms(self):
+        """read preprocess transforms from  config file"""
+        tfs = []
+        if self.inner_cfg.get("info_params", False):
+            tf = ts_common.GetAnomaly(
+                self.inner_cfg["model_threshold"], self.inner_cfg["info_params"]
+            )
+            tfs.append(tf)
+        else:
+            raise ValueError("info_params is not found in config file")
+
+        return tfs

+ 59 - 5
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -139,23 +139,36 @@ class TSADTrainDeamon(BaseTrainDeamon):
 
     def update_result(self, result, train_output):
         """update every result"""
-        config = Path(train_output).joinpath("config.yaml")
-        if not config.exists():
+        train_output = Path(train_output).resolve()
+        config_path = Path(train_output).joinpath("config.yaml").resolve()
+        if not config_path.exists():
             return result
 
-        result["config"] = config
+        model_name = result["model_name"]
+        if (
+            model_name in self.config_recorder
+            and self.config_recorder[model_name] != config_path
+        ):
+            result["models"] = self.init_model_pkg()
+        result["config"] = config_path
+        self.config_recorder[model_name] = config_path
+
+        result["config"] = config_path
         result["train_log"] = self.update_train_log(train_output)
         result["visualdl_log"] = self.update_vdl_log(train_output)
         result["label_dict"] = self.update_label_dict(train_output)
-        self.update_models(result, train_output, "best")
+        model = self.get_model(result["model_name"], config_path)
+        self.update_models(result, model, train_output, "best")
+
         return result
 
-    def update_models(self, result, train_output, model_key):
+    def update_models(self, result, model, train_output, model_key):
         """update info of the models to be saved"""
         pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
         if pdparams.exists():
 
             score = self.get_score(Path(train_output).joinpath("score.json"))
+
             result["models"][model_key] = {
                 "score": "%.3f" % score,
                 "pdparams": pdparams,
@@ -167,6 +180,47 @@ class TSADTrainDeamon(BaseTrainDeamon):
                 "pdiparams": pdparams,
                 "pdiparams.info": "",
             }
+        self.update_inference_model(
+            model,
+            train_output,
+            train_output.joinpath(f"inference"),
+            result["models"][model_key],
+        )
+
+    def update_inference_model(
+        self, model, weight_path, export_save_dir, result_the_model
+    ):
+        """update inference model"""
+        export_save_dir.mkdir(parents=True, exist_ok=True)
+        export_result = model.export(weight_path=weight_path, save_dir=export_save_dir)
+
+        if export_result.returncode == 0:
+            inference_config = export_save_dir.joinpath("inference.yml")
+            if not inference_config.exists():
+                inference_config = ""
+            use_pir = (
+                hasattr(paddle.framework, "use_pir_api")
+                and paddle.framework.use_pir_api()
+            )
+            pdmodel = (
+                export_save_dir.joinpath("inference.json")
+                if use_pir
+                else export_save_dir.joinpath("inference.pdmodel")
+            )
+            pdiparams = export_save_dir.joinpath("inference.pdiparams")
+            pdiparams_info = (
+                "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
+            )
+        else:
+            inference_config = ""
+            pdmodel = ""
+            pdiparams = ""
+            pdiparams_info = ""
+
+        result_the_model["inference_config"] = inference_config
+        result_the_model["pdmodel"] = pdmodel
+        result_the_model["pdiparams"] = pdiparams
+        result_the_model["pdiparams.info"] = pdiparams_info
 
     def get_score(self, score_path):
         """get the score by pdstates file"""

+ 2 - 9
paddlex/modules/ts_classification/predictor.py → paddlex/modules/ts_classification/predictor/__init__.py

@@ -13,12 +13,5 @@
 # limitations under the License.
 
 
-from ..ts_forecast import TSFCPredictor
-from .model_list import MODELS
-from ...utils.errors import raise_unsupported_api_error
-
-
-class TSCLSPredictor(TSFCPredictor):
-    """TS Anomaly Detection Model Predictor"""
-
-    entities = MODELS
+from .predictor import TSCLSPredictor
+from . import transforms

+ 27 - 0
paddlex/modules/ts_classification/predictor/keys.py

@@ -0,0 +1,27 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class TSFCKeys(object):
+    """
+    This class defines a set of keys used for communication of Seg predictors
+    and transforms. Both predictors and transforms accept a dict or a list of
+    dicts as input, and they get the objects of their interest from the dict, or
+    put the generated objects into the dict, all based on these keys.
+    """
+
+    # Common keys
+    TS = "ts"
+    TS_PATH = "input_path"
+    PRED = "pred_ts"

+ 103 - 0
paddlex/modules/ts_classification/predictor/predictor.py

@@ -0,0 +1,103 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+
+from ....utils import logging
+from ...base.predictor.transforms import ts_common
+from ...base import BasePredictor
+from .keys import TSFCKeys as K
+from . import transforms as T
+from .utils import InnerConfig
+from ..model_list import MODELS
+
+
+class TSCLSPredictor(BasePredictor):
+    """SegPredictor"""
+
+    entities = MODELS
+
+    def __init__(
+        self,
+        model_name,
+        model_dir,
+        kernel_option,
+        output,
+        pre_transforms=None,
+        post_transforms=None,
+    ):
+        super().__init__(
+            model_name=model_name,
+            model_dir=model_dir,
+            kernel_option=kernel_option,
+            output=output,
+            pre_transforms=pre_transforms,
+            post_transforms=post_transforms,
+        )
+
+    def load_other_src(self):
+        """load the inner config file"""
+        infer_cfg_file_path = os.path.join(self.model_dir, "inference.yml")
+        if not os.path.exists(infer_cfg_file_path):
+            raise FileNotFoundError(f"Cannot find config file: {infer_cfg_file_path}")
+        return InnerConfig(infer_cfg_file_path, self.model_dir)
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [[K.TS], [K.TS_PATH]]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return [K.PRED]
+
+    def _run(self, batch_input):
+        """run"""
+        n = len(batch_input[0][K.TS])
+        input_ = [
+            np.stack([lst[i] for lst in [data[K.TS] for data in batch_input]], axis=0)
+            for i in range(n)
+        ]
+
+        outputs = self._predictor.predict(input_)
+        batch_output = outputs[0]
+        # In-place update
+        for dict_, output in zip(batch_input, batch_output):
+            dict_[K.PRED] = output
+        return batch_input
+
+    def _get_pre_transforms_from_config(self):
+        """_get_pre_transforms_from_config"""
+        # If `K.TS` (the decoded image) is found, return a default list of
+        # transformation operators for the input (if possible).
+        # If `K.TS` (the decoded image) is not found, `K.IM_PATH` (the image
+        # path) must be contained in the input. In this case, we infer
+        # transformation operators from the config file.
+        # In cases where the input contains both `K.TS` and `K.IM_PATH`,
+        # `K.TS` takes precedence over `K.IM_PATH`.
+        logging.info(
+            f"Transformation operators for data preprocessing will be inferred from config file."
+        )
+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, ts_common.ReadTS())
+
+        return pre_transforms
+
+    def _get_post_transforms_from_config(self):
+
+        return [T.SaveTSClsResults(self.output)]

+ 79 - 0
paddlex/modules/ts_classification/predictor/transforms.py

@@ -0,0 +1,79 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+import pandas as pd
+
+from ....utils import logging
+from ...base import BaseTransform
+from ...base.predictor.io.writers import TSWriter
+from .keys import TSFCKeys as K
+
+__all__ = ["SaveTSClsResults"]
+
+
+class SaveTSClsResults(BaseTransform):
+    """SaveSegResults"""
+
+    def __init__(self, save_dir):
+        super().__init__()
+        self.save_dir = save_dir
+        self._writer = TSWriter(backend="pandas")
+
+    def apply(self, data):
+        """apply"""
+        pred_ts = data[K.PRED]
+        pred_ts -= np.max(pred_ts, axis=-1, keepdims=True)
+        pred_ts = np.exp(pred_ts) / np.sum(np.exp(pred_ts), axis=-1, keepdims=True)
+        classid = np.argmax(pred_ts, axis=-1)
+        pred_score = pred_ts[classid]
+        result = {"classid": [classid], "score": [pred_score]}
+        result = pd.DataFrame.from_dict(result)
+        result.index.name = "sample"
+        file_name = os.path.basename(data[K.TS_PATH])
+        ts_save_path = os.path.join(self.save_dir, file_name)
+        self._write_ts(ts_save_path, result)
+
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [K.PRED]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return []
+
+    def _write_ts(self, path, ts):
+        """write ts"""
+        if os.path.exists(path):
+            logging.warning(f"{path} already exists. Overwriting it.")
+        self._writer.write(path, ts)
+
+    @staticmethod
+    def _add_suffix(path, suffix):
+        """add suffix"""
+        stem, ext = os.path.splitext(path)
+        return stem + suffix + ext
+
+    @staticmethod
+    def _replace_ext(path, new_ext):
+        """replace ext"""
+        stem, _ = os.path.splitext(path)
+        return stem + new_ext

+ 66 - 0
paddlex/modules/ts_classification/predictor/utils.py

@@ -0,0 +1,66 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import codecs
+import yaml
+import os
+
+from ....utils import logging
+from ...base.predictor.transforms import ts_common
+
+
+class InnerConfig(object):
+    """Inner Config"""
+
+    def __init__(self, config_path, model_dir=None):
+        self.inner_cfg = self.load(config_path)
+        self.model_dir = model_dir
+
+    def load(self, config_path):
+        """load config"""
+        with codecs.open(config_path, "r", "utf-8") as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+        return dic
+
+    @property
+    def pre_transforms(self):
+        """read preprocess transforms from  config file"""
+
+        tfs = []
+        if self.inner_cfg.get("info_params", False):
+
+            if self.inner_cfg.get("scale", False):
+                scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+                if not os.path.exists(scaler_file_path):
+                    raise FileNotFoundError(
+                        f"Cannot find scaler file: {scaler_file_path}"
+                    )
+                tf = ts_common.TSNormalize(
+                    scaler_file_path, self.inner_cfg["info_params"]
+                )
+                tfs.append(tf)
+
+            tf = ts_common.BuildTSDataset(self.inner_cfg["info_params"])
+            tfs.append(tf)
+
+            tf = ts_common.BuildPadMask(self.inner_cfg["input_data"])
+            tfs.append(tf)
+
+            tf = ts_common.TStoArray(self.inner_cfg["input_data"])
+            tfs.append(tf)
+        else:
+            raise ValueError("info_params is not found in config file")
+
+        return tfs

+ 58 - 5
paddlex/modules/ts_classification/trainer.py

@@ -134,18 +134,30 @@ class TSCLSTrainDeamon(BaseTrainDeamon):
 
     def update_result(self, result, train_output):
         """update every result"""
-        config = Path(train_output).joinpath("config.yaml")
-        if not config.exists():
+        train_output = Path(train_output).resolve()
+        config_path = Path(train_output).joinpath("config.yaml").resolve()
+        if not config_path.exists():
             return result
 
-        result["config"] = config
+        model_name = result["model_name"]
+        if (
+            model_name in self.config_recorder
+            and self.config_recorder[model_name] != config_path
+        ):
+            result["models"] = self.init_model_pkg()
+        result["config"] = config_path
+        self.config_recorder[model_name] = config_path
+
+        result["config"] = config_path
         result["train_log"] = self.update_train_log(train_output)
         result["visualdl_log"] = self.update_vdl_log(train_output)
         result["label_dict"] = self.update_label_dict(train_output)
-        self.update_models(result, train_output, "best")
+        model = self.get_model(result["model_name"], config_path)
+        self.update_models(result, model, train_output, "best")
+
         return result
 
-    def update_models(self, result, train_output, model_key):
+    def update_models(self, result, model, train_output, model_key):
         """update info of the models to be saved"""
         pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
         if pdparams.exists():
@@ -163,6 +175,47 @@ class TSCLSTrainDeamon(BaseTrainDeamon):
                 "pdiparams": pdparams,
                 "pdiparams.info": "",
             }
+        self.update_inference_model(
+            model,
+            train_output,
+            train_output.joinpath(f"inference"),
+            result["models"][model_key],
+        )
+
+    def update_inference_model(
+        self, model, weight_path, export_save_dir, result_the_model
+    ):
+        """update inference model"""
+        export_save_dir.mkdir(parents=True, exist_ok=True)
+        export_result = model.export(weight_path=weight_path, save_dir=export_save_dir)
+
+        if export_result.returncode == 0:
+            inference_config = export_save_dir.joinpath("inference.yml")
+            if not inference_config.exists():
+                inference_config = ""
+            use_pir = (
+                hasattr(paddle.framework, "use_pir_api")
+                and paddle.framework.use_pir_api()
+            )
+            pdmodel = (
+                export_save_dir.joinpath("inference.json")
+                if use_pir
+                else export_save_dir.joinpath("inference.pdmodel")
+            )
+            pdiparams = export_save_dir.joinpath("inference.pdiparams")
+            pdiparams_info = (
+                "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
+            )
+        else:
+            inference_config = ""
+            pdmodel = ""
+            pdiparams = ""
+            pdiparams_info = ""
+
+        result_the_model["inference_config"] = inference_config
+        result_the_model["pdmodel"] = pdmodel
+        result_the_model["pdiparams"] = pdiparams
+        result_the_model["pdiparams.info"] = pdiparams_info
 
     def get_score(self, score_path):
         """get the score by pdstates file"""

+ 0 - 121
paddlex/modules/ts_forecast/predictor.py

@@ -1,121 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from pathlib import Path
-import tarfile
-
-from typing import Union
-from ...utils import logging
-from ..base.build_model import build_model
-from ..base.predictor import BasePredictor
-from ...utils.errors import raise_unsupported_api_error, raise_model_not_found_error
-from .model_list import MODELS
-from ...utils.download import download
-from ...utils.cache import CACHE_DIR
-
-
-class TSFCPredictor(BasePredictor):
-    """TS Forecast Model Predictor"""
-
-    entities = MODELS
-
-    def __init__(self, model_name, model_dir, kernel_option, output):
-        """initialize"""
-        model_dir = self._download_from_url(model_dir)
-        self.model_dir = self.uncompress_tar_file(model_dir)
-
-        self.device = kernel_option.get_device()
-        self.output = output
-        config_path = self.get_config_path()
-        self.pdx_config, self.pdx_model = build_model(
-            model_name, config_path=config_path
-        )
-
-    def uncompress_tar_file(self, model_dir):
-        """unpackage the tar file containing training outputs and update weight path"""
-        if tarfile.is_tarfile(model_dir):
-            dest_path = Path(model_dir).parent
-            with tarfile.open(model_dir, "r") as tar:
-                tar.extractall(path=dest_path)
-            return dest_path / "best_accuracy.pdparams/best_model/model.pdparams"
-        return model_dir
-
-    def get_config_path(self) -> Union[str, None]:
-        """
-        get config path
-
-        Returns:
-            config_path (str): The path to the config
-
-        """
-        if Path(self.model_dir).exists():
-            config_path = Path(self.model_dir).parent.parent / "config.yaml"
-            if config_path.exists():
-                return config_path
-            else:
-                logging.warning(
-                    f"The config file(`{config_path}`) related to model weight file(`{self.model_dir}`) \
-is not exist, use default instead."
-                )
-        else:
-            raise_model_not_found_error(self.model_dir)
-        return None
-
-    def _download_from_url(self, in_path):
-        if in_path.startswith("http"):
-            file_name = Path(in_path).name
-            save_path = Path(CACHE_DIR) / "predict_input" / file_name
-            download(in_path, save_path, overwrite=True)
-            return save_path.as_posix()
-        return in_path
-
-    def predict(self, input):
-        """execute model predict"""
-        # self.update_config()
-        input["input_path"] = self._download_from_url(input["input_path"])
-        result = self.pdx_model.predict(**input, **self.get_predict_kwargs())
-        assert (
-            result.returncode == 0
-        ), f"Encountered an unexpected error({result.returncode}) in predicting!"
-        return result
-
-    def get_predict_kwargs(self) -> dict:
-        """get key-value arguments of model predict function
-
-        Returns:
-            dict: the arguments of predict function.
-        """
-        return {
-            "weight_path": self.model_dir,
-            "device": self.device,
-            "save_dir": self.output,
-        }
-
-    def _get_post_transforms_from_config(self):
-        pass
-
-    def _get_pre_transforms_from_config(self):
-        pass
-
-    def _run(self):
-        pass
-
-    def get_input_keys(self):
-        """get input keys"""
-        return ["input_path"]
-
-    def get_output_keys(self):
-        """get output keys"""
-        pass

+ 17 - 0
paddlex/modules/ts_forecast/predictor/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .predictor import TSFCPredictor
+from . import transforms

+ 27 - 0
paddlex/modules/ts_forecast/predictor/keys.py

@@ -0,0 +1,27 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class TSFCKeys(object):
+    """
+    This class defines a set of keys used for communication of Seg predictors
+    and transforms. Both predictors and transforms accept a dict or a list of
+    dicts as input, and they get the objects of their interest from the dict, or
+    put the generated objects into the dict, all based on these keys.
+    """
+
+    # Common keys
+    TS = "ts"
+    TS_PATH = "input_path"
+    PRED = "pred_ts"

+ 108 - 0
paddlex/modules/ts_forecast/predictor/predictor.py

@@ -0,0 +1,108 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+
+from ....utils import logging
+from ...base.predictor.transforms import ts_common
+from ...base import BasePredictor
+from .keys import TSFCKeys as K
+from . import transforms as T
+from .utils import InnerConfig
+from ..model_list import MODELS
+
+
+class TSFCPredictor(BasePredictor):
+    """SegPredictor"""
+
+    entities = MODELS
+
+    def __init__(
+        self,
+        model_name,
+        model_dir,
+        kernel_option,
+        output,
+        pre_transforms=None,
+        post_transforms=None,
+        has_prob_map=False,
+    ):
+        super().__init__(
+            model_name=model_name,
+            model_dir=model_dir,
+            kernel_option=kernel_option,
+            output=output,
+            pre_transforms=pre_transforms,
+            post_transforms=post_transforms,
+        )
+        self.has_prob_map = has_prob_map
+
+    def load_other_src(self):
+        """load the inner config file"""
+        infer_cfg_file_path = os.path.join(self.model_dir, "inference.yml")
+        if not os.path.exists(infer_cfg_file_path):
+            raise FileNotFoundError(f"Cannot find config file: {infer_cfg_file_path}")
+        return InnerConfig(infer_cfg_file_path, self.model_dir)
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [[K.TS], [K.TS_PATH]]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return [K.PRED]
+
+    def _run(self, batch_input):
+        """run"""
+        n = len(batch_input[0][K.TS])
+        input_ = [
+            np.stack([lst[i] for lst in [data[K.TS] for data in batch_input]], axis=0)
+            for i in range(n)
+        ]
+
+        outputs = self._predictor.predict(input_)
+        batch_output = outputs[0]
+        # In-place update
+        for dict_, output in zip(batch_input, batch_output):
+            dict_[K.PRED] = output
+        return batch_input
+
+    def _get_pre_transforms_from_config(self):
+        """_get_pre_transforms_from_config"""
+        # If `K.TS` (the decoded image) is found, return a default list of
+        # transformation operators for the input (if possible).
+        # If `K.TS` (the decoded image) is not found, `K.IM_PATH` (the image
+        # path) must be contained in the input. In this case, we infer
+        # transformation operators from the config file.
+        # In cases where the input contains both `K.TS` and `K.IM_PATH`,
+        # `K.TS` takes precedence over `K.IM_PATH`.
+        logging.info(
+            f"Transformation operators for data preprocessing will be inferred from config file."
+        )
+
+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, ts_common.ReadTS())
+
+        return pre_transforms
+
+    def _get_post_transforms_from_config(self):
+        """_get_post_transforms_from_config"""
+        post_transforms = self.other_src.post_transforms
+        post_transforms.append(T.SaveTSResults(self.output))
+        return post_transforms

+ 73 - 0
paddlex/modules/ts_forecast/predictor/transforms.py

@@ -0,0 +1,73 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+import numpy as np
+from PIL import Image
+
+from ....utils import logging
+from ...base import BaseTransform
+from ...base.predictor.io.writers import TSWriter
+from .keys import TSFCKeys as K
+
+__all__ = ["SaveTSResults"]
+
+
+class SaveTSResults(BaseTransform):
+    """SaveSegResults"""
+
+    def __init__(self, save_dir):
+        super().__init__()
+        self.save_dir = save_dir
+        self._writer = TSWriter(backend="pandas")
+
+    def apply(self, data):
+        """apply"""
+        pred_ts = data[K.PRED]
+        file_name = os.path.basename(data[K.TS_PATH])
+
+        ts_save_path = os.path.join(self.save_dir, file_name)
+        self._write_ts(ts_save_path, pred_ts)
+
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [K.PRED]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return []
+
+    def _write_ts(self, path, ts):
+        """write ts"""
+        if os.path.exists(path):
+            logging.warning(f"{path} already exists. Overwriting it.")
+        self._writer.write(path, ts)
+
+    @staticmethod
+    def _add_suffix(path, suffix):
+        """add suffix"""
+        stem, ext = os.path.splitext(path)
+        return stem + suffix + ext
+
+    @staticmethod
+    def _replace_ext(path, new_ext):
+        """replace ext"""
+        stem, _ = os.path.splitext(path)
+        return stem + new_ext

+ 95 - 0
paddlex/modules/ts_forecast/predictor/utils.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import codecs
+import yaml
+import os
+
+from ....utils import logging
+from ...base.predictor.transforms import ts_common
+
+
+class InnerConfig(object):
+    """Inner Config"""
+
+    def __init__(self, config_path, model_dir=None):
+        self.inner_cfg = self.load(config_path)
+        self.model_dir = model_dir
+
+    def load(self, config_path):
+        """load config"""
+        with codecs.open(config_path, "r", "utf-8") as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+        return dic
+
+    @property
+    def pre_transforms(self):
+        """read preprocess transforms from  config file"""
+
+        tfs = []
+        if self.inner_cfg.get("info_params", False):
+            tf = ts_common.TSCutOff(self.inner_cfg["size"])
+            tfs.append(tf)
+
+            if self.inner_cfg.get("scale", False):
+                scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+                if not os.path.exists(scaler_file_path):
+                    raise FileNotFoundError(
+                        f"Cannot find scaler file: {scaler_file_path}"
+                    )
+                tf = ts_common.TSNormalize(
+                    scaler_file_path, self.inner_cfg["info_params"]
+                )
+                tfs.append(tf)
+
+            tf = ts_common.BuildTSDataset(self.inner_cfg["info_params"])
+            tfs.append(tf)
+
+            if self.inner_cfg.get("time_feat", False):
+                tf = ts_common.TimeFeature(
+                    self.inner_cfg["info_params"],
+                    self.inner_cfg["size"],
+                    self.inner_cfg["holiday"],
+                )
+                tfs.append(tf)
+            tf = ts_common.TStoArray(self.inner_cfg["input_data"])
+
+            tfs.append(tf)
+        else:
+            raise ValueError("info_params is not found in config file")
+
+        return tfs
+
+    @property
+    def post_transforms(self):
+        """read preprocess transforms from  config file"""
+        tfs = []
+        if self.inner_cfg.get("info_params", False):
+            tf = ts_common.ArraytoTS(self.inner_cfg["info_params"])
+            tfs.append(tf)
+            if self.inner_cfg.get("scale", False):
+                scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
+                if not os.path.exists(scaler_file_path):
+                    raise FileNotFoundError(
+                        f"Cannot find scaler file: {scaler_file_path}"
+                    )
+                tf = ts_common.TSDeNormalize(
+                    scaler_file_path, self.inner_cfg["info_params"]
+                )
+                tfs.append(tf)
+        else:
+            raise ValueError("info_params is not found in config file")
+
+        return tfs

+ 58 - 5
paddlex/modules/ts_forecast/trainer.py

@@ -134,18 +134,30 @@ class TSFCTrainDeamon(BaseTrainDeamon):
 
     def update_result(self, result, train_output):
         """update every result"""
-        config = Path(train_output).joinpath("config.yaml")
-        if not config.exists():
+        train_output = Path(train_output).resolve()
+        config_path = Path(train_output).joinpath("config.yaml").resolve()
+        if not config_path.exists():
             return result
 
-        result["config"] = config
+        model_name = result["model_name"]
+        if (
+            model_name in self.config_recorder
+            and self.config_recorder[model_name] != config_path
+        ):
+            result["models"] = self.init_model_pkg()
+        result["config"] = config_path
+        self.config_recorder[model_name] = config_path
+
+        result["config"] = config_path
         result["train_log"] = self.update_train_log(train_output)
         result["visualdl_log"] = self.update_vdl_log(train_output)
         result["label_dict"] = self.update_label_dict(train_output)
-        self.update_models(result, train_output, "best")
+        model = self.get_model(result["model_name"], config_path)
+        self.update_models(result, model, train_output, "best")
+
         return result
 
-    def update_models(self, result, train_output, model_key):
+    def update_models(self, result, model, train_output, model_key):
         """update info of the models to be saved"""
         pdparams = Path(train_output).joinpath("best_accuracy.pdparams.tar")
         if pdparams.exists():
@@ -163,6 +175,47 @@ class TSFCTrainDeamon(BaseTrainDeamon):
                 "pdiparams": pdparams,
                 "pdiparams.info": "",
             }
+        self.update_inference_model(
+            model,
+            train_output,
+            train_output.joinpath(f"inference"),
+            result["models"][model_key],
+        )
+
+    def update_inference_model(
+        self, model, weight_path, export_save_dir, result_the_model
+    ):
+        """update inference model"""
+        export_save_dir.mkdir(parents=True, exist_ok=True)
+        export_result = model.export(weight_path=weight_path, save_dir=export_save_dir)
+
+        if export_result.returncode == 0:
+            inference_config = export_save_dir.joinpath("inference.yml")
+            if not inference_config.exists():
+                inference_config = ""
+            use_pir = (
+                hasattr(paddle.framework, "use_pir_api")
+                and paddle.framework.use_pir_api()
+            )
+            pdmodel = (
+                export_save_dir.joinpath("inference.json")
+                if use_pir
+                else export_save_dir.joinpath("inference.pdmodel")
+            )
+            pdiparams = export_save_dir.joinpath("inference.pdiparams")
+            pdiparams_info = (
+                "" if use_pir else export_save_dir.joinpath("inference.pdiparams.info")
+            )
+        else:
+            inference_config = ""
+            pdmodel = ""
+            pdiparams = ""
+            pdiparams_info = ""
+
+        result_the_model["inference_config"] = inference_config
+        result_the_model["pdmodel"] = pdmodel
+        result_the_model["pdiparams"] = pdiparams
+        result_the_model["pdiparams.info"] = pdiparams_info
 
     def get_score(self, score_path):
         """get the score by pdstates file"""

+ 5 - 5
paddlex/repo_apis/PaddleTS_api/ts_ad/register.py

@@ -44,7 +44,7 @@ register_model_info(
         "suite": "TSAnomaly",
         "config_path": TimesNetAD_CFG_PATH,
         "auto_compression_config_path": TimesNetAD_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx"],
             "dy2st": False,
@@ -64,7 +64,7 @@ register_model_info(
         "model_name": "AutoEncoder_ad",
         "suite": "TSAnomaly",
         "config_path": AE_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -86,7 +86,7 @@ register_model_info(
         "model_name": "DLinear_ad",
         "suite": "TSAnomaly",
         "config_path": DL_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -108,7 +108,7 @@ register_model_info(
         "model_name": "PatchTST_ad",
         "suite": "TSAnomaly",
         "config_path": PATCHTST_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -130,7 +130,7 @@ register_model_info(
         "model_name": "Nonstationary_ad",
         "suite": "TSAnomaly",
         "config_path": NS_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,

+ 9 - 1
paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py

@@ -103,7 +103,15 @@ class TSADRunner(BaseRunner):
 
     def export(self, config_path, cli_args, device):
         """export"""
-        raise_unsupported_api_error("export", self.__class__)
+        cmd = [
+            self.python,
+            "tools/export.py",
+            "--config",
+            config_path,
+            *cli_args,
+        ]
+        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+        return cp
 
     def infer(self, config_path, cli_args, device):
         """infer"""

+ 26 - 2
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -212,9 +212,33 @@ class TSModel(BaseModel):
             config.dump(config_path)
             return self.runner.predict(config_path, cli_args, device)
 
-    def export(self, weight_path: str, save_dir: str = None, **kwargs):
+    def export(
+        self, weight_path: str, save_dir: str = None, device: str = "gpu", **kwargs
+    ):
         """export"""
-        raise_unsupported_api_error("export", self.__class__)
+        weight_path = abspath(weight_path)
+        save_dir = abspath(save_dir)
+        cli_args = []
+
+        weight_path = abspath(weight_path)
+        cli_args.append(CLIArgument("--checkpoints", weight_path))
+        if save_dir is not None:
+            save_dir = abspath(save_dir)
+        else:
+            save_dir = abspath(os.path.join("output", "inference"))
+        cli_args.append(CLIArgument("--save_dir", save_dir))
+        if device is not None:
+            device_type, _ = self.runner.parse_device(device)
+            cli_args.append(CLIArgument("--device", device_type))
+
+        self._assert_empty_kwargs(kwargs)
+        with self._create_new_config_file() as config_path:
+            # Update YAML config file
+            config = self.config.copy()
+            config.update_pretrained_weights(weight_path)
+            config.dump(config_path)
+
+            return self.runner.export(config_path, cli_args, device)
 
     def infer(
         self,

+ 9 - 1
paddlex/repo_apis/PaddleTS_api/ts_base/runner.py

@@ -103,7 +103,15 @@ class TSRunner(BaseRunner):
 
     def export(self, config_path, cli_args, device):
         """export"""
-        raise_unsupported_api_error("export", self.__class__)
+        cmd = [
+            self.python,
+            "tools/export.py",
+            "--config",
+            config_path,
+            *cli_args,
+        ]
+        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+        return cp
 
     def infer(self, config_path, cli_args, device):
         """infer"""

+ 1 - 1
paddlex/repo_apis/PaddleTS_api/ts_cls/register.py

@@ -43,7 +43,7 @@ register_model_info(
         "model_name": "TimesNet_cls",
         "suite": "TSClassify",
         "config_path": TimesNetCLS_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,

+ 9 - 1
paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py

@@ -103,7 +103,15 @@ class TSCLSRunner(BaseRunner):
 
     def export(self, config_path, cli_args, device):
         """export"""
-        raise_unsupported_api_error("export", self.__class__)
+        cmd = [
+            self.python,
+            "tools/export.py",
+            "--config",
+            config_path,
+            *cli_args,
+        ]
+        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+        return cp
 
     def infer(self, config_path, cli_args, device):
         """infer"""

+ 7 - 7
paddlex/repo_apis/PaddleTS_api/ts_fc/register.py

@@ -43,7 +43,7 @@ register_model_info(
         "model_name": "DLinear",
         "suite": "LongForecast",
         "config_path": DLinear_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -64,7 +64,7 @@ register_model_info(
         "model_name": "RLinear",
         "suite": "LongForecast",
         "config_path": DLinear_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -85,7 +85,7 @@ register_model_info(
         "model_name": "NLinear",
         "suite": "LongForecast",
         "config_path": DLinear_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -107,7 +107,7 @@ register_model_info(
         "model_name": "TiDE",
         "suite": "LongForecast",
         "config_path": TiDE_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -129,7 +129,7 @@ register_model_info(
         "model_name": "PatchTST",
         "suite": "LongForecast",
         "config_path": PatchTST_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -148,7 +148,7 @@ register_model_info(
         "model_name": "Nonstationary",
         "suite": "LongForecast",
         "config_path": Nonstationary_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,
@@ -170,7 +170,7 @@ register_model_info(
         "model_name": "TimesNet",
         "suite": "LongForecast",
         "config_path": TimesNet_CFG_PATH,
-        "supported_apis": ["train", "evaluate", "predict"],
+        "supported_apis": ["train", "evaluate", "predict", "export"],
         "supported_train_opts": {
             "device": ["cpu", "gpu_n1cx", "xpu", "npu", "mlu"],
             "dy2st": False,