"""Multi-bucket S3 data reader/writer: per-bucket client selection with a shared config mixin."""
  1. from magic_pdf.config.exceptions import InvalidConfig, InvalidParams
  2. from magic_pdf.data.data_reader_writer.base import DataReader, DataWriter
  3. from magic_pdf.data.io.s3 import S3Reader, S3Writer
  4. from magic_pdf.data.schemas import S3Config
  5. from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
  6. remove_non_official_s3_args)
  7. class MultiS3Mixin:
  8. def __init__(self, default_bucket: str, s3_configs: list[S3Config]):
  9. """Initialized with multiple s3 configs.
  10. Args:
  11. default_bucket (str): the default bucket name of the relative path
  12. s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
  13. Raises:
  14. InvalidConfig: default bucket config not in s3_configs
  15. InvalidConfig: bucket name not unique in s3_configs
  16. InvalidConfig: default bucket must be provided
  17. """
  18. if len(default_bucket) == 0:
  19. raise InvalidConfig('default_bucket must be provided')
  20. found_default_bucket_config = False
  21. for conf in s3_configs:
  22. if conf.bucket_name == default_bucket:
  23. found_default_bucket_config = True
  24. break
  25. if not found_default_bucket_config:
  26. raise InvalidConfig(
  27. f'default_bucket: {default_bucket} config must be provided in s3_configs: {s3_configs}'
  28. )
  29. uniq_bucket = set([conf.bucket_name for conf in s3_configs])
  30. if len(uniq_bucket) != len(s3_configs):
  31. raise InvalidConfig(
  32. f'the bucket_name in s3_configs: {s3_configs} must be unique'
  33. )
  34. self.default_bucket = default_bucket
  35. self.s3_configs = s3_configs
  36. self._s3_clients_h: dict = {}
  37. class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
  38. def read(self, path: str) -> bytes:
  39. """Read the path from s3, select diffect bucket client for each request
  40. based on the path, also support range read.
  41. Args:
  42. path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit
  43. for example: s3://bucket_name/path?0,100
  44. Returns:
  45. bytes: the content of s3 file
  46. """
  47. may_range_params = parse_s3_range_params(path)
  48. if may_range_params is None or 2 != len(may_range_params):
  49. byte_start, byte_len = 0, -1
  50. else:
  51. byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
  52. path = remove_non_official_s3_args(path)
  53. return self.read_at(path, byte_start, byte_len)
  54. def __get_s3_client(self, bucket_name: str):
  55. if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
  56. raise InvalidParams(
  57. f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
  58. )
  59. if bucket_name not in self._s3_clients_h:
  60. conf = next(
  61. filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
  62. )
  63. self._s3_clients_h[bucket_name] = S3Reader(
  64. bucket_name,
  65. conf.access_key,
  66. conf.secret_key,
  67. conf.endpoint_url,
  68. conf.addressing_style,
  69. )
  70. return self._s3_clients_h[bucket_name]
  71. def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
  72. """Read the file with offset and limit, select diffect bucket client
  73. for each request based on the path.
  74. Args:
  75. path (str): the file path
  76. offset (int, optional): the number of bytes skipped. Defaults to 0.
  77. limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.
  78. Returns:
  79. bytes: the file content
  80. """
  81. if path.startswith('s3://'):
  82. bucket_name, path = parse_s3path(path)
  83. s3_reader = self.__get_s3_client(bucket_name)
  84. else:
  85. s3_reader = self.__get_s3_client(self.default_bucket)
  86. return s3_reader.read_at(path, offset, limit)
  87. class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
  88. def __get_s3_client(self, bucket_name: str):
  89. if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
  90. raise InvalidParams(
  91. f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
  92. )
  93. if bucket_name not in self._s3_clients_h:
  94. conf = next(
  95. filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
  96. )
  97. self._s3_clients_h[bucket_name] = S3Writer(
  98. bucket_name,
  99. conf.access_key,
  100. conf.secret_key,
  101. conf.endpoint_url,
  102. conf.addressing_style,
  103. )
  104. return self._s3_clients_h[bucket_name]
  105. def write(self, path: str, data: bytes) -> None:
  106. """Write file with data, also select diffect bucket client for each
  107. request based on the path.
  108. Args:
  109. path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
  110. data (bytes): the data want to write
  111. """
  112. if path.startswith('s3://'):
  113. bucket_name, path = parse_s3path(path)
  114. s3_writer = self.__get_s3_client(bucket_name)
  115. else:
  116. s3_writer = self.__get_s3_client(self.default_bucket)
  117. return s3_writer.write(path, data)