ts_batch_sampler.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import ast
  16. from pathlib import Path
  17. import numpy as np
  18. import pandas as pd
  19. from ....utils import logging
  20. from ....utils.download import download
  21. from ....utils.cache import CACHE_DIR
  22. from .base_batch_sampler import BaseBatchSampler
  23. class TSBatchSampler(BaseBatchSampler):
  24. """Batch sampler for time series data, supporting CSV file inputs."""
  25. SUFFIX = ["csv"]
  26. def _download_from_url(self, in_path: str) -> str:
  27. """Download a file from a URL to a cache directory.
  28. Args:
  29. in_path (str): URL of the file to be downloaded.
  30. Returns:
  31. str: Path to the downloaded file.
  32. """
  33. file_name = Path(in_path).name
  34. save_path = Path(CACHE_DIR) / "predict_input" / file_name
  35. download(in_path, save_path, overwrite=True)
  36. return save_path.as_posix()
  37. def _get_files_list(self, fp: str) -> list:
  38. """Get a list of CSV files from a directory or a single file path.
  39. Args:
  40. fp (str): Path to a directory or a single CSV file.
  41. Returns:
  42. list: Sorted list of CSV file paths.
  43. Raises:
  44. Exception: If no CSV file is found in the path.
  45. """
  46. file_list = []
  47. if fp is None or not os.path.exists(fp):
  48. raise Exception(f"Not found any csv file in path: {fp}")
  49. if os.path.isfile(fp) and fp.split(".")[-1] in self.SUFFIX:
  50. file_list.append(fp)
  51. elif os.path.isdir(fp):
  52. for root, dirs, files in os.walk(fp):
  53. for single_file in files:
  54. if single_file.split(".")[-1] in self.SUFFIX:
  55. file_list.append(os.path.join(root, single_file))
  56. if len(file_list) == 0:
  57. raise Exception("Not found any file in {}".format(fp))
  58. file_list = sorted(file_list)
  59. return file_list
  60. def sample(self, inputs: list) -> list:
  61. """Generate batches of data from inputs, which can be DataFrames or file paths.
  62. Args:
  63. inputs (list): List of DataFrames or file paths.
  64. Yields:
  65. list: A batch of data which is either DataFrames or file paths.
  66. """
  67. if not isinstance(inputs, list):
  68. inputs = [inputs]
  69. batch = []
  70. for input in inputs:
  71. if isinstance(input, pd.DataFrame):
  72. batch.append(input)
  73. if len(batch) == self.batch_size:
  74. yield batch
  75. batch = []
  76. elif isinstance(input, str):
  77. file_path = (
  78. self._download_from_url(input)
  79. if input.startswith("http")
  80. else input
  81. )
  82. file_list = self._get_files_list(file_path)
  83. for file_path in file_list:
  84. batch.append(file_path)
  85. if len(batch) == self.batch_size:
  86. yield batch
  87. batch = []
  88. else:
  89. logging.warning(
  90. f"Not supported input data type! Only `pd.DataFrame` and `str` are supported! So has been ignored: {input}."
  91. )
  92. if len(batch) > 0:
  93. yield batch