# multi_bucket_s3.py
  1. from ..utils.exceptions import InvalidConfig, InvalidParams
  2. from .base import DataReader, DataWriter
  3. from ..io.s3 import S3Reader, S3Writer
  4. from ..utils.schemas import S3Config
  5. from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args
  6. class MultiS3Mixin:
  7. def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
  8. """Initialized with multiple s3 configs.
  9. Args:
  10. default_prefix (str): the default prefix of the relative path. for example, {some_bucket}/{some_prefix} or {some_bucket}
  11. s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
  12. Raises:
  13. InvalidConfig: default bucket config not in s3_configs.
  14. InvalidConfig: bucket name not unique in s3_configs.
  15. InvalidConfig: default bucket must be provided.
  16. """
  17. if len(default_prefix) == 0:
  18. raise InvalidConfig('default_prefix must be provided')
  19. arr = default_prefix.strip('/').split('/')
  20. self.default_bucket = arr[0]
  21. self.default_prefix = '/'.join(arr[1:])
  22. found_default_bucket_config = False
  23. for conf in s3_configs:
  24. if conf.bucket_name == self.default_bucket:
  25. found_default_bucket_config = True
  26. break
  27. if not found_default_bucket_config:
  28. raise InvalidConfig(
  29. f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
  30. )
  31. uniq_bucket = set([conf.bucket_name for conf in s3_configs])
  32. if len(uniq_bucket) != len(s3_configs):
  33. raise InvalidConfig(
  34. f'the bucket_name in s3_configs: {s3_configs} must be unique'
  35. )
  36. self.s3_configs = s3_configs
  37. self._s3_clients_h: dict = {}
  38. class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
  39. def read(self, path: str) -> bytes:
  40. """Read the path from s3, select diffect bucket client for each request
  41. based on the bucket, also support range read.
  42. Args:
  43. path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit.
  44. for example: s3://bucket_name/path?0,100.
  45. Returns:
  46. bytes: the content of s3 file.
  47. """
  48. may_range_params = parse_s3_range_params(path)
  49. if may_range_params is None or 2 != len(may_range_params):
  50. byte_start, byte_len = 0, -1
  51. else:
  52. byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
  53. path = remove_non_official_s3_args(path)
  54. return self.read_at(path, byte_start, byte_len)
  55. def __get_s3_client(self, bucket_name: str):
  56. if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
  57. raise InvalidParams(
  58. f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
  59. )
  60. if bucket_name not in self._s3_clients_h:
  61. conf = next(
  62. filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
  63. )
  64. self._s3_clients_h[bucket_name] = S3Reader(
  65. bucket_name,
  66. conf.access_key,
  67. conf.secret_key,
  68. conf.endpoint_url,
  69. conf.addressing_style,
  70. )
  71. return self._s3_clients_h[bucket_name]
  72. def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
  73. """Read the file with offset and limit, select diffect bucket client
  74. for each request based on the bucket.
  75. Args:
  76. path (str): the file path.
  77. offset (int, optional): the number of bytes skipped. Defaults to 0.
  78. limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
  79. Returns:
  80. bytes: the file content.
  81. """
  82. if path.startswith('s3://'):
  83. bucket_name, path = parse_s3path(path)
  84. s3_reader = self.__get_s3_client(bucket_name)
  85. else:
  86. s3_reader = self.__get_s3_client(self.default_bucket)
  87. if self.default_prefix:
  88. path = self.default_prefix + '/' + path
  89. return s3_reader.read_at(path, offset, limit)
  90. class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
  91. def __get_s3_client(self, bucket_name: str):
  92. if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
  93. raise InvalidParams(
  94. f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
  95. )
  96. if bucket_name not in self._s3_clients_h:
  97. conf = next(
  98. filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
  99. )
  100. self._s3_clients_h[bucket_name] = S3Writer(
  101. bucket_name,
  102. conf.access_key,
  103. conf.secret_key,
  104. conf.endpoint_url,
  105. conf.addressing_style,
  106. )
  107. return self._s3_clients_h[bucket_name]
  108. def write(self, path: str, data: bytes) -> None:
  109. """Write file with data, also select diffect bucket client for each
  110. request based on the bucket.
  111. Args:
  112. path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
  113. data (bytes): the data want to write.
  114. """
  115. if path.startswith('s3://'):
  116. bucket_name, path = parse_s3path(path)
  117. s3_writer = self.__get_s3_client(bucket_name)
  118. else:
  119. s3_writer = self.__get_s3_client(self.default_bucket)
  120. if self.default_prefix:
  121. path = self.default_prefix + '/' + path
  122. return s3_writer.write(path, data)