predictor.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple
  15. import numpy as np
  16. import pandas as pd
  17. import os
  18. import copy
  19. from ....modules.ts_classification.model_list import MODELS
  20. from ...common.batch_sampler import TSBatchSampler
  21. from ...common.reader import ReadTS
  22. from ..common import (
  23. TSCutOff,
  24. BuildTSDataset,
  25. TSNormalize,
  26. TimeFeature,
  27. TStoArray,
  28. TStoBatch,
  29. )
  30. from .processors import GetCls, BuildPadMask
  31. from ..base import BasePredictor
  32. from .result import TSClsResult
  33. class TSClsPredictor(BasePredictor):
  34. """TSClsPredictor that inherits from BasePredictor."""
  35. entities = MODELS
  36. def __init__(self, *args: List, **kwargs: Dict) -> None:
  37. """Initializes TSClsPredictor.
  38. Args:
  39. *args: Arbitrary positional arguments passed to the superclass.
  40. **kwargs: Arbitrary keyword arguments passed to the superclass.
  41. """
  42. super().__init__(*args, **kwargs)
  43. self.preprocessors, self.infer, self.postprocessors = self._build()
  44. def _build_batch_sampler(self) -> TSBatchSampler:
  45. """Builds and returns an TSBatchSampler instance.
  46. Returns:
  47. TSBatchSampler: An instance of TSBatchSampler.
  48. """
  49. return TSBatchSampler()
  50. def _get_result_class(self) -> type:
  51. """Returns the result class.
  52. Returns:
  53. type: The Result class.
  54. """
  55. return TSClsResult
  56. def _build(self) -> Tuple:
  57. """Build the preprocessors, inference engine, and postprocessors based on the configuration.
  58. Returns:
  59. tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
  60. """
  61. preprocessors = {
  62. "ReadTS": ReadTS(),
  63. "TSCutOff": TSCutOff(self.config["size"]),
  64. }
  65. if self.config.get("scale", None):
  66. scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
  67. if not os.path.exists(scaler_file_path):
  68. raise Exception(f"Cannot find scaler file: {scaler_file_path}")
  69. preprocessors["TSNormalize"] = TSNormalize(
  70. scaler_file_path, self.config["info_params"]
  71. )
  72. preprocessors["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
  73. preprocessors["BuildPadMask"] = BuildPadMask(self.config["input_data"])
  74. preprocessors["TStoArray"] = TStoArray(self.config["input_data"])
  75. preprocessors["TStoBatch"] = TStoBatch()
  76. infer = self.create_static_infer()
  77. postprocessors = {}
  78. postprocessors["GetCls"] = GetCls()
  79. return preprocessors, infer, postprocessors
  80. def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
  81. """
  82. Processes a batch of time series data through a series of preprocessing, inference, and postprocessing steps.
  83. Args:
  84. batch_data (List[Union[str, pd.DataFrame]]): A list of paths or identifiers for the batch of time series data to be processed.
  85. Returns:
  86. Dict[str, Any]: A dictionary containing the paths to the input data, the raw input time series, and the classification results.
  87. """
  88. batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data.instances)
  89. batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
  90. if "TSNormalize" in self.preprocessors:
  91. batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_raw_ts)
  92. batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_ts)
  93. else:
  94. batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=batch_raw_ts)
  95. batch_input_ts = self.preprocessors["BuildPadMask"](ts_list=batch_input_ts)
  96. batch_ts = self.preprocessors["TStoArray"](ts_list=batch_input_ts)
  97. x = self.preprocessors["TStoBatch"](ts_list=batch_ts)
  98. batch_preds = self.infer(x=x)
  99. batch_ts_preds = self.postprocessors["GetCls"](pred_list=batch_preds)
  100. return {
  101. "input_path": batch_data.input_paths,
  102. "input_ts": batch_raw_ts,
  103. "input_ts_data": batch_raw_ts_ori,
  104. "classification": batch_ts_preds,
  105. "target_cols": [self.config["info_params"]["target_cols"]],
  106. }