|
|
@@ -17,65 +17,55 @@ import os
|
|
|
from ...modules.ts_forecast.model_list import MODELS
|
|
|
from ..components import *
|
|
|
from ..results import TSFcResult
|
|
|
-from ..utils.process_hook import batchable_method
|
|
|
-from .base import BasicPredictor
|
|
|
+from .base import TSPredictor
|
|
|
|
|
|
|
|
|
-class TSFcPredictor(BasicPredictor):
|
|
|
+class TSFcPredictor(TSPredictor):
|
|
|
|
|
|
entities = MODELS
|
|
|
|
|
|
def _build_components(self):
|
|
|
- preprocess = self._build_preprocess()
|
|
|
- predictor = TSPPPredictor(
|
|
|
- model_dir=self.model_dir,
|
|
|
- model_prefix=self.MODEL_FILE_PREFIX,
|
|
|
- option=self.pp_option,
|
|
|
- )
|
|
|
- postprocess = self._build_postprocess()
|
|
|
- return {**preprocess, "predictor": predictor, **postprocess}
|
|
|
-
|
|
|
- def _build_preprocess(self):
|
|
|
if not self.config.get("info_params", None):
|
|
|
raise Exception("info_params is not found in config file")
|
|
|
|
|
|
- ops = {}
|
|
|
- ops["ReadTS"] = ReadTS()
|
|
|
- ops["TSCutOff"] = TSCutOff(self.config["size"])
|
|
|
+ self._add_component(ReadTS())
|
|
|
+ self._add_component(TSCutOff(self.config["size"]))
|
|
|
|
|
|
if self.config.get("scale", None):
|
|
|
scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
|
|
|
if not os.path.exists(scaler_file_path):
|
|
|
raise Exception(f"Cannot find scaler file: {scaler_file_path}")
|
|
|
- ops["TSNormalize"] = TSNormalize(
|
|
|
- scaler_file_path, self.config["info_params"]
|
|
|
+ self._add_component(
|
|
|
+ TSNormalize(scaler_file_path, self.config["info_params"])
|
|
|
)
|
|
|
|
|
|
- ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
|
|
|
+ self._add_component(BuildTSDataset(self.config["info_params"]))
|
|
|
|
|
|
if self.config.get("time_feat", None):
|
|
|
- ops["TimeFeature"] = TimeFeature(
|
|
|
- self.config["info_params"],
|
|
|
- self.config["size"],
|
|
|
- self.config["holiday"],
|
|
|
+ self._add_component(
|
|
|
+ TimeFeature(
|
|
|
+ self.config["info_params"],
|
|
|
+ self.config["size"],
|
|
|
+ self.config["holiday"],
|
|
|
+ )
|
|
|
)
|
|
|
- ops["TStoArray"] = TStoArray(self.config["input_data"])
|
|
|
- return ops
|
|
|
+ self._add_component(TStoArray(self.config["input_data"]))
|
|
|
|
|
|
- def _build_postprocess(self):
|
|
|
- if not self.config.get("info_params", None):
|
|
|
- raise Exception("info_params is not found in config file")
|
|
|
+ predictor = TSPPPredictor(
|
|
|
+ model_dir=self.model_dir,
|
|
|
+ model_prefix=self.MODEL_FILE_PREFIX,
|
|
|
+ option=self.pp_option,
|
|
|
+ )
|
|
|
+ self._add_component(("Predictor", predictor))
|
|
|
|
|
|
- ops = {}
|
|
|
- ops["ArraytoTS"] = ArraytoTS(self.config["info_params"])
|
|
|
+ self._add_component(ArraytoTS(self.config["info_params"]))
|
|
|
if self.config.get("scale", None):
|
|
|
scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
|
|
|
if not os.path.exists(scaler_file_path):
|
|
|
raise Exception(f"Cannot find scaler file: {scaler_file_path}")
|
|
|
- ops["TSDeNormalize"] = TSDeNormalize(
|
|
|
- scaler_file_path, self.config["info_params"]
|
|
|
+ self._add_component(
|
|
|
+ TSDeNormalize(scaler_file_path, self.config["info_params"])
|
|
|
)
|
|
|
- return ops
|
|
|
|
|
|
def _pack_res(self, single):
|
|
|
return TSFcResult({"ts_path": single["ts_path"], "forecast": single["pred"]})
|