|
|
@@ -0,0 +1,144 @@
|
|
|
+
|
|
|
+from ..utils.exceptions import InvalidConfig, InvalidParams
|
|
|
+from .base import DataReader, DataWriter
|
|
|
+from ..io.s3 import S3Reader, S3Writer
|
|
|
+from ..utils.schemas import S3Config
|
|
|
+from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args
|
|
|
+
|
|
|
+
|
|
|
class MultiS3Mixin:
    """Mixin holding credentials and lazily-created clients for multiple s3 buckets."""

    def __init__(self, default_prefix: str, s3_configs: 'list[S3Config]'):
        """Initialized with multiple s3 configs.

        Args:
            default_prefix (str): the default prefix of the relative path,
                for example {some_bucket}/{some_prefix} or {some_bucket}.
            s3_configs (list[S3Config]): list of s3 configs; the bucket_name
                must be unique in the list.

        Raises:
            InvalidConfig: default_prefix is empty or contains no bucket name.
            InvalidConfig: default bucket config not in s3_configs.
            InvalidConfig: bucket name not unique in s3_configs.
        """
        if len(default_prefix) == 0:
            raise InvalidConfig('default_prefix must be provided')

        # First path segment is the bucket; the (possibly empty) rest is the prefix.
        arr = default_prefix.strip('/').split('/')
        self.default_bucket = arr[0]
        self.default_prefix = '/'.join(arr[1:])

        # A default_prefix of only slashes leaves no bucket name; fail early
        # with a clear message instead of a confusing "config not found" below.
        if not self.default_bucket:
            raise InvalidConfig('default_prefix must contain a bucket name')

        if not any(conf.bucket_name == self.default_bucket for conf in s3_configs):
            raise InvalidConfig(
                f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
            )

        # Duplicate bucket names would make client selection ambiguous.
        uniq_bucket = {conf.bucket_name for conf in s3_configs}
        if len(uniq_bucket) != len(s3_configs):
            raise InvalidConfig(
                f'the bucket_name in s3_configs: {s3_configs} must be unique'
            )

        self.s3_configs = s3_configs
        # Per-bucket client cache, filled lazily on first use (see subclasses).
        self._s3_clients_h: dict = {}
|
|
|
+
|
|
|
+
|
|
|
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
    """Reader that dispatches each request to a per-bucket s3 client."""

    def read(self, path: str) -> bytes:
        """Read the path from s3, selecting a different bucket client for each
        request based on the bucket; range reads are also supported.

        Args:
            path (str): the s3 path of file, the path must be in the format of
                s3://bucket_name/path?offset,limit. for example: s3://bucket_name/path?0,100.

        Returns:
            bytes: the content of s3 file.
        """
        may_range_params = parse_s3_range_params(path)
        if may_range_params is None or len(may_range_params) != 2:
            # No (or malformed) range suffix: read the whole object.
            byte_start, byte_len = 0, -1
        else:
            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
        # Strip the non-official ?offset,limit suffix before the real request.
        path = remove_non_official_s3_args(path)
        return self.read_at(path, byte_start, byte_len)

    def __get_s3_client(self, bucket_name: str):
        """Return the cached S3Reader for bucket_name, creating it on first use.

        Raises:
            InvalidParams: bucket_name has no matching entry in s3_configs.
        """
        if bucket_name not in self._s3_clients_h:
            # Single pass over the configs: find the matching one or fail.
            conf = next(
                (c for c in self.s3_configs if c.bucket_name == bucket_name), None
            )
            if conf is None:
                raise InvalidParams(
                    f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
                )
            self._s3_clients_h[bucket_name] = S3Reader(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
        return self._s3_clients_h[bucket_name]

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read the file with offset and limit, selecting a different bucket
        client for each request based on the bucket.

        Args:
            path (str): the file path.
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.

        Returns:
            bytes: the file content.
        """
        if path.startswith('s3://'):
            # Absolute s3 path: the bucket is encoded in the path itself.
            bucket_name, path = parse_s3path(path)
            s3_reader = self.__get_s3_client(bucket_name)
        else:
            # Relative path: use the default bucket and join the default prefix.
            s3_reader = self.__get_s3_client(self.default_bucket)
            if self.default_prefix:
                path = self.default_prefix + '/' + path
        return s3_reader.read_at(path, offset, limit)
|
|
|
+
|
|
|
+
|
|
|
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
    """Writer that dispatches each request to a per-bucket s3 client."""

    def __get_s3_client(self, bucket_name: str):
        """Return the cached S3Writer for bucket_name, creating it on first use.

        Raises:
            InvalidParams: bucket_name has no matching entry in s3_configs.
        """
        if bucket_name not in self._s3_clients_h:
            # Single pass over the configs: find the matching one or fail.
            conf = next(
                (c for c in self.s3_configs if c.bucket_name == bucket_name), None
            )
            if conf is None:
                raise InvalidParams(
                    f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
                )
            self._s3_clients_h[bucket_name] = S3Writer(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
        return self._s3_clients_h[bucket_name]

    def write(self, path: str, data: bytes) -> None:
        """Write file with data, selecting a different bucket client for each
        request based on the bucket.

        Args:
            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
            data (bytes): the data want to write.
        """
        if path.startswith('s3://'):
            # Absolute s3 path: the bucket is encoded in the path itself.
            bucket_name, path = parse_s3path(path)
            s3_writer = self.__get_s3_client(bucket_name)
        else:
            # Relative path: use the default bucket and join the default prefix.
            s3_writer = self.__get_s3_client(self.default_bucket)
            if self.default_prefix:
                path = self.default_prefix + '/' + path
        # Forward the underlying writer's result (annotated None; presumably
        # S3Writer.write returns None — confirm against io.s3).
        return s3_writer.write(path, data)
|