predictor.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import os
  16. from typing import Any, Dict, List, Tuple, Union
  17. import pandas as pd
  18. from ....modules.ts_classification.model_list import MODELS
  19. from ...common.batch_sampler import TSBatchSampler
  20. from ...common.reader import ReadTS
  21. from ..base import BasePredictor
  22. from ..common import BuildTSDataset, TSCutOff, TSNormalize, TStoArray, TStoBatch
  23. from .processors import BuildPadMask, GetCls
  24. from .result import TSClsResult
  class TSClsPredictor(BasePredictor):
      """Time-series classification predictor built on ``BasePredictor``.

      Chains the TS preprocessors, a static inference engine obtained from
      ``create_static_infer`` and the ``GetCls`` postprocessor.
      """

      # Model names handled by this predictor, taken from the
      # ts_classification model list (presumably consumed by BasePredictor's
      # model registry — confirm).
      entities = MODELS

      def __init__(self, *args: Any, **kwargs: Any) -> None:
          """Initialize the predictor and build its processing stages.

          Args:
              *args: Positional arguments forwarded to ``BasePredictor``.
              **kwargs: Keyword arguments forwarded to ``BasePredictor``.
          """
          super().__init__(*args, **kwargs)
          self.preprocessors, self.infer, self.postprocessors = self._build()

      def _build_batch_sampler(self) -> TSBatchSampler:
          """Return a new ``TSBatchSampler`` constructed with default arguments.

          Returns:
              TSBatchSampler: The batch sampler used to group inputs.
          """
          return TSBatchSampler()

      def _get_result_class(self) -> type:
          """Return the class used to wrap each prediction.

          Returns:
              type: ``TSClsResult``.
          """
          return TSClsResult

      def _build(self) -> Tuple:
          """Build the preprocessors, inference engine and postprocessors.

          Reads the ``size``, ``scale``, ``info_params`` and ``input_data``
          entries of ``self.config``; a missing required key raises ``KeyError``.

          Returns:
              tuple: ``(preprocessors, infer, postprocessors)``; both processor
              collections are dicts keyed by processor name.

          Raises:
              FileNotFoundError: If ``scale`` is truthy in the config but
                  ``scaler.pkl`` does not exist in ``self.model_dir``.
          """
          preprocessors = {
              "ReadTS": ReadTS(),
              # NOTE(review): TSCutOff is built here but never called by
              # `process`; kept so the preprocessor set and the `size` config
              # requirement stay unchanged — confirm whether it should run.
              "TSCutOff": TSCutOff(self.config["size"]),
          }
          if self.config.get("scale"):
              scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
              if not os.path.exists(scaler_file_path):
                  # FileNotFoundError subclasses Exception, so callers that
                  # caught the previous bare Exception still catch this.
                  raise FileNotFoundError(
                      f"Cannot find scaler file: {scaler_file_path}"
                  )
              preprocessors["TSNormalize"] = TSNormalize(
                  scaler_file_path, self.config["info_params"]
              )
          preprocessors["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
          preprocessors["BuildPadMask"] = BuildPadMask(self.config["input_data"])
          preprocessors["TStoArray"] = TStoArray(self.config["input_data"])
          preprocessors["TStoBatch"] = TStoBatch()

          infer = self.create_static_infer()
          postprocessors = {"GetCls": GetCls()}
          return preprocessors, infer, postprocessors

      def process(self, batch_data: Any) -> Dict[str, Any]:
          """Run one batch through preprocessing, inference and postprocessing.

          Args:
              batch_data: A batch produced by the batch sampler. Only its
                  ``instances`` attribute (the raw inputs handed to ``ReadTS``;
                  presumably file paths or ``pd.DataFrame`` objects — confirm)
                  and its ``input_paths`` attribute are used.

          Returns:
              Dict[str, Any]: Lists keyed by ``input_path``, ``input_ts`` (the
              series returned by ``ReadTS``), ``input_ts_data`` (a deep copy of
              those series taken before normalization), ``classification``
              (output of ``GetCls``) and ``target_cols`` (the configured target
              columns, repeated once per sample).
          """
          raw_series = self.preprocessors["ReadTS"](ts_list=batch_data.instances)
          # Snapshot before later stages get a chance to touch the series.
          raw_series_copy = copy.deepcopy(raw_series)

          dataset_source = raw_series
          if "TSNormalize" in self.preprocessors:
              dataset_source = self.preprocessors["TSNormalize"](ts_list=raw_series)
          batch_input_ts = self.preprocessors["BuildTSDataset"](ts_list=dataset_source)
          batch_input_ts = self.preprocessors["BuildPadMask"](ts_list=batch_input_ts)
          batch_arrays = self.preprocessors["TStoArray"](ts_list=batch_input_ts)
          x = self.preprocessors["TStoBatch"](ts_list=batch_arrays)

          batch_preds = self.infer(x=x)
          batch_ts_preds = self.postprocessors["GetCls"](pred_list=batch_preds)

          return {
              "input_path": batch_data.input_paths,
              "input_ts": raw_series,
              "input_ts_data": raw_series_copy,
              "classification": batch_ts_preds,
              # Repeat once per sample so this list has the same length as the
              # other per-batch values; identical to the previous single-element
              # list when the batch holds one series.
              "target_cols": [self.config["info_params"]["target_cols"]]
              * len(raw_series),
          }