| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import os
- from typing import Any, Dict, List, Tuple, Union
- import pandas as pd
- from ....modules.ts_forecast.model_list import MODELS
- from ...common.batch_sampler import TSBatchSampler
- from ...common.reader import ReadTS
- from ..base import BasePredictor
- from ..common import (
- BuildTSDataset,
- TimeFeature,
- TSCutOff,
- TSNormalize,
- TStoArray,
- TStoBatch,
- )
- from .processors import ArraytoTS, TSDeNormalize
- from .result import TSFcResult
class TSFcPredictor(BasePredictor):
    """Time-series forecasting predictor built on ``BasePredictor``.

    Builds a pipeline of time-series preprocessors, a static inference
    engine, and postprocessors from ``self.config``. Normalization is enabled
    by the ``scale`` config entry and time features by ``time_feat``.
    """

    entities = MODELS

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the predictor and build its processing pipeline.

        Args:
            *args: Positional arguments forwarded to ``BasePredictor``.
            **kwargs: Keyword arguments forwarded to ``BasePredictor``.
        """
        super().__init__(*args, **kwargs)
        self.preprocessors, self.infer, self.postprocessors = self._build()

    def _build_batch_sampler(self) -> TSBatchSampler:
        """Build the batch sampler used to group time-series inputs.

        Returns:
            TSBatchSampler: A new ``TSBatchSampler`` instance.
        """
        return TSBatchSampler()

    def _get_result_class(self) -> type:
        """Return the result class used to wrap predictions.

        Returns:
            type: The ``TSFcResult`` class.
        """
        return TSFcResult

    def _get_scaler_file_path(self) -> str:
        """Return the path of ``scaler.pkl`` inside ``self.model_dir``.

        Returns:
            str: Path to the scaler file, verified to exist.

        Raises:
            FileNotFoundError: If ``scaler.pkl`` is not present in the model
                directory. (A subclass of ``Exception``, so callers catching
                ``Exception`` are unaffected.)
        """
        scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
        if not os.path.exists(scaler_file_path):
            raise FileNotFoundError(f"Cannot find scaler file: {scaler_file_path}")
        return scaler_file_path

    def _build(self) -> Tuple[Dict[str, Any], Any, Dict[str, Any]]:
        """Build the preprocessors, inference engine, and postprocessors.

        Config keys read: ``size``, ``info_params``, ``input_data``, and
        optionally ``scale`` (enables TSNormalize / TSDeNormalize) and
        ``time_feat`` (enables TimeFeature, which also reads ``holiday``).

        Returns:
            tuple: ``(preprocessors, infer, postprocessors)`` where the
            processor collections are dicts keyed by stage name.

        Raises:
            FileNotFoundError: If ``scale`` is enabled but ``scaler.pkl`` is
                missing from the model directory.
        """
        preprocessors = {
            "ReadTS": ReadTS(),
            "TSCutOff": TSCutOff(self.config["size"]),
        }

        # Resolve and validate the scaler once; it is shared by the
        # normalize (pre) and de-normalize (post) stages.
        scaler_file_path = (
            self._get_scaler_file_path() if self.config.get("scale") else None
        )
        if scaler_file_path is not None:
            preprocessors["TSNormalize"] = TSNormalize(
                scaler_file_path, self.config["info_params"]
            )
        preprocessors["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
        if self.config.get("time_feat"):
            preprocessors["TimeFeature"] = TimeFeature(
                self.config["info_params"],
                self.config["size"],
                self.config["holiday"],
            )
        preprocessors["TStoArray"] = TStoArray(self.config["input_data"])
        preprocessors["TStoBatch"] = TStoBatch()

        infer = self.create_static_infer()

        postprocessors = {"ArraytoTS": ArraytoTS(self.config["info_params"])}
        if scaler_file_path is not None:
            postprocessors["TSDeNormalize"] = TSDeNormalize(
                scaler_file_path, self.config["info_params"]
            )
        return preprocessors, infer, postprocessors

    def process(self, batch_data: Any) -> Dict[str, Any]:
        """Run one batch through preprocessing, inference, and postprocessing.

        Args:
            batch_data: A batch object exposing ``instances`` (the raw inputs
                handed to ``ReadTS``, e.g. file paths or ``pd.DataFrame``
                objects) and ``input_paths``. Presumably produced by the
                ``TSBatchSampler`` from ``_build_batch_sampler`` — confirm.

        Returns:
            dict: Per-batch lists under the keys ``"input_path"``,
            ``"input_ts"``, ``"cutoff_ts"``, and ``"forecast"``.
        """
        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data.instances)
        # Deep copy taken before TSCutOff runs; presumably guards the raw
        # series against in-place modification by later stages — confirm.
        batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)

        batch_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
        if "TSNormalize" in self.preprocessors:
            batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_ts)
        # Kept under its own name: ArraytoTS needs this dataset to map the
        # raw predictions back onto time series.
        batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_ts)

        batch_ts = batch_input_ts
        if "TimeFeature" in self.preprocessors:
            batch_ts = self.preprocessors["TimeFeature"](ts_list=batch_ts)
        batch_ts = self.preprocessors["TStoArray"](ts_list=batch_ts)
        x = self.preprocessors["TStoBatch"](ts_list=batch_ts)

        batch_preds = self.infer(x=x)

        batch_ts_preds = self.postprocessors["ArraytoTS"](
            ori_ts_list=batch_input_ts, pred_list=batch_preds
        )
        if "TSDeNormalize" in self.postprocessors:
            batch_ts_preds = self.postprocessors["TSDeNormalize"](
                preds_list=batch_ts_preds
            )

        # NOTE(review): "input_ts" holds the list that was handed to TSCutOff,
        # while "cutoff_ts" holds the deep copy taken *before* cutoff. The names
        # look swapped relative to their contents. Kept unchanged because the
        # consumer (TSFcResult) is not visible here; confirm before changing.
        return {
            "input_path": batch_data.input_paths,
            "input_ts": batch_raw_ts,
            "cutoff_ts": batch_raw_ts_ori,
            "forecast": batch_ts_preds,
        }
|