
feat: implement S3 data reader and writer with multi-bucket support

myhloli 5 months ago
parent
commit
6f8a961087

+ 17 - 0
mineru/data/data_reader_writer/__init__.py

@@ -0,0 +1,17 @@
+from .base import DataReader, DataWriter
+from .dummy import DummyDataWriter
+from .filebase import FileBasedDataReader, FileBasedDataWriter
+from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
+from .s3 import S3DataReader, S3DataWriter
+
+__all__ = [
+    "DataReader",
+    "DataWriter",
+    "FileBasedDataReader",
+    "FileBasedDataWriter",
+    "S3DataReader",
+    "S3DataWriter",
+    "MultiBucketS3DataReader",
+    "MultiBucketS3DataWriter",
+    "DummyDataWriter",
+]
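
All reader/writer classes are re-exported at the package root, so one import is enough (a minimal sketch, assuming the repo root is on the Python path; the module path follows this commit's layout):

    from mineru.data.data_reader_writer import FileBasedDataReader, S3DataReader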

+ 63 - 0
mineru/data/data_reader_writer/base.py

@@ -0,0 +1,63 @@
+
+from abc import ABC, abstractmethod
+
+
+class DataReader(ABC):
+
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(path)
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file at offset and limit.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the maximum number of bytes to read. Defaults to -1, which means read to the end.
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+
+class DataWriter(ABC):
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write the data to the file.
+
+        Args:
+            path (str): the target file where to write
+            data (bytes): the data to write
+        """
+        pass
+
+    def write_string(self, path: str, data: str) -> None:
+        """Write string data to the file; the string is encoded to bytes first.
+
+        Args:
+            path (str): the target file to write to
+            data (str): the string to write
+        """
+
+        def safe_encode(data: str, method: str):
+            try:
+                bit_data = data.encode(encoding=method, errors='replace')
+                return bit_data, True
+            except Exception:  # noqa
+                return None, False
+
+        # utf-8 with errors='replace' should always succeed; ascii is kept as a fallback.
+        for method in ['utf-8', 'ascii']:
+            bit_data, flag = safe_encode(data, method)
+            if flag:
+                self.write(path, bit_data)
+                break
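
The base classes split the read contract so that subclasses implement only read_at; read delegates to it with offset=0, limit=-1. A minimal sketch of that contract, using a hypothetical in-memory reader (BytesDataReader is not part of this commit):

    from mineru.data.data_reader_writer import DataReader

    class BytesDataReader(DataReader):
        def __init__(self, blobs: dict):
            self._blobs = blobs

        def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
            data = self._blobs[path]
            # limit == -1 mirrors the base-class convention: read to the end
            return data[offset:] if limit == -1 else data[offset:offset + limit]

    reader = BytesDataReader({'a.txt': b'hello world'})
    assert reader.read('a.txt') == b'hello world'
    assert reader.read_at('a.txt', 6, 5) == b'world'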

+ 11 - 0
mineru/data/data_reader_writer/dummy.py

@@ -0,0 +1,11 @@
+from .base import DataWriter
+
+
+class DummyDataWriter(DataWriter):
+    def write(self, path: str, data: bytes) -> None:
+        """Dummy write method that does nothing."""
+        pass
+
+    def write_string(self, path: str, data: str) -> None:
+        """Dummy write_string method that does nothing."""
+        pass
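
DummyDataWriter is a drop-in no-op; a likely use (an assumption, not stated in the commit) is suppressing output during dry runs:

    writer = DummyDataWriter()
    writer.write_string('ignored.txt', 'nothing is persisted')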

+ 62 - 0
mineru/data/data_reader_writer/filebase.py

@@ -0,0 +1,62 @@
+import os
+
+from .base import DataReader, DataWriter
+
+
+class FileBasedDataReader(DataReader):
+    def __init__(self, parent_dir: str = ''):
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that relative paths are joined with. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the file path; a relative path is joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the maximum number of bytes to read. Defaults to -1, which means read to the end.
+
+        Returns:
+            bytes: the content of file
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        with open(fn_path, 'rb') as f:
+            f.seek(offset)
+            if limit == -1:
+                return f.read()
+            else:
+                return f.read(limit)
+
+
+class FileBasedDataWriter(DataWriter):
+    def __init__(self, parent_dir: str = '') -> None:
+        """Initialized with parent_dir.
+
+        Args:
+            parent_dir (str, optional): the parent directory that relative paths are joined with. Defaults to ''.
+        """
+        self._parent_dir = parent_dir
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the file path; a relative path is joined with parent_dir.
+            data (bytes): the data to write
+        """
+        fn_path = path
+        if not os.path.isabs(fn_path) and len(self._parent_dir) > 0:
+            fn_path = os.path.join(self._parent_dir, path)
+
+        if not os.path.exists(os.path.dirname(fn_path)) and os.path.dirname(fn_path) != "":
+            os.makedirs(os.path.dirname(fn_path), exist_ok=True)
+
+        with open(fn_path, 'wb') as f:
+            f.write(data)
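
A usage sketch for the file-based pair; the directory is a placeholder. Note that the writer creates missing parent directories before writing:

    import tempfile

    base_dir = tempfile.mkdtemp()  # placeholder parent_dir
    writer = FileBasedDataWriter(base_dir)
    reader = FileBasedDataReader(base_dir)

    writer.write('out/result.txt', b'hello world')   # creates <base_dir>/out/ as needed
    assert reader.read_at('out/result.txt', 6, 5) == b'world'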

+ 144 - 0
mineru/data/data_reader_writer/multi_bucket_s3.py

@@ -0,0 +1,144 @@
+
+from ..utils.exceptions import InvalidConfig, InvalidParams
+from .base import DataReader, DataWriter
+from ..io.s3 import S3Reader, S3Writer
+from ..utils.schemas import S3Config
+from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args
+
+
+class MultiS3Mixin:
+    def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
+        """Initialized with multiple s3 configs.
+
+        Args:
+            default_prefix (str): the default prefix for relative paths, e.g. {some_bucket}/{some_prefix} or {some_bucket}
+            s3_configs (list[S3Config]): list of s3 configs, the bucket_name must be unique in the list.
+
+        Raises:
+            InvalidConfig: default_prefix is empty.
+            InvalidConfig: the default bucket's config is not in s3_configs.
+            InvalidConfig: bucket_name is not unique in s3_configs.
+        """
+        if len(default_prefix) == 0:
+            raise InvalidConfig('default_prefix must be provided')
+
+        arr = default_prefix.strip('/').split('/')
+        self.default_bucket = arr[0]
+        self.default_prefix = '/'.join(arr[1:])
+
+        found_default_bucket_config = False
+        for conf in s3_configs:
+            if conf.bucket_name == self.default_bucket:
+                found_default_bucket_config = True
+                break
+
+        if not found_default_bucket_config:
+            raise InvalidConfig(
+                f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
+            )
+
+        uniq_bucket = set([conf.bucket_name for conf in s3_configs])
+        if len(uniq_bucket) != len(s3_configs):
+            raise InvalidConfig(
+                f'the bucket_name in s3_configs: {s3_configs} must be unique'
+            )
+
+        self.s3_configs = s3_configs
+        self._s3_clients_h: dict = {}
+
+
+class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
+    def read(self, path: str) -> bytes:
+        """Read the path from s3, select diffect bucket client for each request
+        based on the bucket, also support range read.
+
+        Args:
+            path (str): the s3 path of file, the path must be in the format of s3://bucket_name/path?offset,limit.
+            for example: s3://bucket_name/path?0,100.
+
+        Returns:
+            bytes: the content of s3 file.
+        """
+        may_range_params = parse_s3_range_params(path)
+        if may_range_params is None or len(may_range_params) != 2:
+            byte_start, byte_len = 0, -1
+        else:
+            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
+        path = remove_non_official_s3_args(path)
+        return self.read_at(path, byte_start, byte_len)
+
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Reader(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read the file with offset and limit, select diffect bucket client
+        for each request based on the bucket.
+
+        Args:
+            path (str): the file path.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the maximum number of bytes to read. Defaults to -1, which means read to the end.
+
+        Returns:
+            bytes: the file content.
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_reader = self.__get_s3_client(bucket_name)
+        else:
+            s3_reader = self.__get_s3_client(self.default_bucket)
+            if self.default_prefix:
+                path = self.default_prefix + '/' + path
+        return s3_reader.read_at(path, offset, limit)
+
+
+class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
+    def __get_s3_client(self, bucket_name: str):
+        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
+            raise InvalidParams(
+                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
+            )
+        if bucket_name not in self._s3_clients_h:
+            conf = next(
+                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
+            )
+            self._s3_clients_h[bucket_name] = S3Writer(
+                bucket_name,
+                conf.access_key,
+                conf.secret_key,
+                conf.endpoint_url,
+                conf.addressing_style,
+            )
+        return self._s3_clients_h[bucket_name]
+
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data, also select diffect bucket client for each
+        request based on the bucket.
+
+        Args:
+            path (str): the file path; a path without the s3:// scheme is written under the default bucket and prefix.
+            data (bytes): the data want to write.
+        """
+        if path.startswith('s3://'):
+            bucket_name, path = parse_s3path(path)
+            s3_writer = self.__get_s3_client(bucket_name)
+        else:
+            s3_writer = self.__get_s3_client(self.default_bucket)
+            if self.default_prefix:
+                path = self.default_prefix + '/' + path
+        return s3_writer.write(path, data)
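
How the path forms route to clients, in a sketch (bucket names, credentials, and the endpoint are placeholders):

    from mineru.data.data_reader_writer import MultiBucketS3DataReader
    from mineru.data.utils.schemas import S3Config

    configs = [
        S3Config(bucket_name='bucket-a', access_key='AK', secret_key='SK',
                 endpoint_url='http://localhost:9000'),
        S3Config(bucket_name='bucket-b', access_key='AK', secret_key='SK',
                 endpoint_url='http://localhost:9000'),
    ]
    reader = MultiBucketS3DataReader('bucket-a/some/prefix', configs)

    reader.read('doc.pdf')                       # default bucket: key some/prefix/doc.pdf in bucket-a
    reader.read('s3://bucket-b/other/doc.pdf')   # explicit bucket overrides the default
    reader.read('s3://bucket-b/other/doc.pdf?bytes=0,1024')  # range read: first 1024 bytes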

+ 72 - 0
mineru/data/data_reader_writer/s3.py

@@ -0,0 +1,72 @@
+from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
+from ..utils.schemas import S3Config
+
+
+class S3DataReader(MultiBucketS3DataReader):
+    def __init__(
+        self,
+        default_prefix_without_bucket: str,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            default_prefix_without_bucket: the default prefix, without the bucket name
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+            see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        super().__init__(
+            f'{bucket}/{default_prefix_without_bucket}',
+            [
+                S3Config(
+                    bucket_name=bucket,
+                    access_key=ak,
+                    secret_key=sk,
+                    endpoint_url=endpoint_url,
+                    addressing_style=addressing_style,
+                )
+            ],
+        )
+
+
+class S3DataWriter(MultiBucketS3DataWriter):
+    def __init__(
+        self,
+        default_prefix_without_bucket: str,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 writer client.
+
+        Args:
+            default_prefix_without_bucket: the default prefix, without the bucket name
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+            see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        super().__init__(
+            f'{bucket}/{default_prefix_without_bucket}',
+            [
+                S3Config(
+                    bucket_name=bucket,
+                    access_key=ak,
+                    secret_key=sk,
+                    endpoint_url=endpoint_url,
+                    addressing_style=addressing_style,
+                )
+            ],
+        )
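
The single-bucket classes are thin wrappers that build a one-element S3Config list; a usage sketch with placeholder credentials:

    from mineru.data.data_reader_writer import S3DataWriter

    writer = S3DataWriter('outputs', bucket='my-bucket', ak='AK', sk='SK',
                          endpoint_url='http://localhost:9000')
    writer.write('result.md', b'# done')   # lands at s3://my-bucket/outputs/result.md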

+ 6 - 0
mineru/data/io/__init__.py

@@ -0,0 +1,6 @@
+
+from .base import IOReader, IOWriter
+from .http import HttpReader, HttpWriter
+from .s3 import S3Reader, S3Writer
+
+__all__ = ['IOReader', 'IOWriter', 'HttpReader', 'HttpWriter', 'S3Reader', 'S3Writer']

+ 42 - 0
mineru/data/io/base.py

@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+
+
+class IOReader(ABC):
+    @abstractmethod
+    def read(self, path: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        pass
+
+    @abstractmethod
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the file path
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the maximum number of bytes to read. Defaults to -1, which means read to the end.
+
+        Returns:
+            bytes: the content of file
+        """
+        pass
+
+
+class IOWriter(ABC):
+
+    @abstractmethod
+    def write(self, path: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the target file path
+            data (bytes): the data to write
+        """
+        pass

+ 37 - 0
mineru/data/io/http.py

@@ -0,0 +1,37 @@
+
+import io
+
+import requests
+
+from .base import IOReader, IOWriter
+
+
+class HttpReader(IOReader):
+
+    def read(self, url: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return requests.get(url).content
+
+    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Not Implemented."""
+        raise NotImplementedError
+
+
+class HttpWriter(IOWriter):
+    def write(self, url: str, data: bytes) -> None:
+        """Write file with data.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            data (bytes): the data want to write
+        """
+        files = {'file': io.BytesIO(data)}
+        response = requests.post(url, files=files)
+        assert 300 > response.status_code and response.status_code > 199
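
A usage sketch (the URL is a placeholder); note that HttpReader does not support range reads, so read_at raises NotImplementedError:

    reader = HttpReader()
    pdf_bytes = reader.read('https://example.com/sample.pdf')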

+ 114 - 0
mineru/data/io/s3.py

@@ -0,0 +1,114 @@
+import boto3
+from botocore.config import Config
+
+from ..io.base import IOReader, IOWriter
+
+
+class S3Reader(IOReader):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+            see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def read(self, key: str) -> bytes:
+        """Read the file.
+
+        Args:
+            path (str): file path to read
+
+        Returns:
+            bytes: the content of the file
+        """
+        return self.read_at(key)
+
+    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
+        """Read at offset and limit.
+
+        Args:
+            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
+            offset (int, optional): the number of bytes skipped. Defaults to 0.
+            limit (int, optional): the length of bytes want to read. Defaults to -1.
+
+        Returns:
+            bytes: the content of file
+        """
+        if limit > -1:
+            range_header = f'bytes={offset}-{offset+limit-1}'
+            res = self._s3_client.get_object(
+                Bucket=self._bucket, Key=key, Range=range_header
+            )
+        else:
+            res = self._s3_client.get_object(
+                Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
+            )
+        return res['Body'].read()
+
+
+class S3Writer(IOWriter):
+    def __init__(
+        self,
+        bucket: str,
+        ak: str,
+        sk: str,
+        endpoint_url: str,
+        addressing_style: str = 'auto',
+    ):
+        """s3 reader client.
+
+        Args:
+            bucket (str): bucket name
+            ak (str): access key
+            sk (str): secret key
+            endpoint_url (str): endpoint url of s3
+            addressing_style (str, optional): Defaults to 'auto'. Other valid options are 'path' and 'virtual';
+            see https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
+        """
+        self._bucket = bucket
+        self._ak = ak
+        self._sk = sk
+        self._s3_client = boto3.client(
+            service_name='s3',
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=endpoint_url,
+            config=Config(
+                s3={'addressing_style': addressing_style},
+                retries={'max_attempts': 5, 'mode': 'standard'},
+            ),
+        )
+
+    def write(self, key: str, data: bytes) -> None:
+        """Write the object with data.
+
+        Args:
+            key (str): the object key
+            data (bytes): the data to write
+        """
+        self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
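
The offset/limit pair maps onto an HTTP Range header; a sketch of the mapping (bucket, credentials, endpoint, and key are placeholders):

    reader = S3Reader('my-bucket', 'AK', 'SK', 'http://localhost:9000')

    # offset=100, limit=50  -> Range: bytes=100-149  (both ends inclusive)
    chunk = reader.read_at('docs/sample.pdf', offset=100, limit=50)

    # offset=100, limit=-1  -> Range: bytes=100-     (read to the end of the object)
    tail = reader.read_at('docs/sample.pdf', offset=100)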

+ 1 - 0
mineru/data/utils/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 40 - 0
mineru/data/utils/exceptions.py

@@ -0,0 +1,40 @@
+# Copyright (c) Opendatalab. All rights reserved.
+
+class FileNotExisted(Exception):
+
+    def __init__(self, path):
+        self.path = path
+
+    def __str__(self):
+        return f'File {self.path} does not exist.'
+
+
+class InvalidConfig(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid config: {self.msg}'
+
+
+class InvalidParams(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Invalid params: {self.msg}'
+
+
+class EmptyData(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'Empty data: {self.msg}'
+
+
+class CUDA_NOT_AVAILABLE(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'CUDA not available: {self.msg}'
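
Each exception carries its message and formats it in __str__; for example:

    try:
        raise InvalidConfig('default_prefix must be provided')
    except InvalidConfig as e:
        print(e)   # Invalid config: default_prefix must be provided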

+ 33 - 0
mineru/data/utils/path_utils.py

@@ -0,0 +1,33 @@
+# Copyright (c) Opendatalab. All rights reserved.
+
+
+def remove_non_official_s3_args(s3path):
+    """
+    example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
+    """
+    arr = s3path.split("?")
+    return arr[0]
+
+
+def parse_s3path(s3path: str):
+    s3path = remove_non_official_s3_args(s3path).strip()
+    if s3path.startswith(('s3://', 's3a://')):
+        prefix, path = s3path.split('://', 1)
+        bucket_name, key = path.split('/', 1)
+        return bucket_name, key
+    elif s3path.startswith('/'):
+        raise ValueError("The provided path starts with '/'. This does not conform to a valid S3 path format.")
+    else:
+        raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")
+
+
+def parse_s3_range_params(s3path: str):
+    """
+    example: s3://abc/xxxx.json?bytes=0,81350 ==> ['0', '81350']
+    """
+    arr = s3path.split("?bytes=")
+    if len(arr) == 1:
+        return None
+    return arr[1].split(",")

+ 20 - 0
mineru/data/utils/schemas.py

@@ -0,0 +1,20 @@
+# Copyright (c) Opendatalab. All rights reserved.
+
+from pydantic import BaseModel, Field
+
+
+class S3Config(BaseModel):
+    """S3 config
+    """
+    bucket_name: str = Field(description='s3 bucket name', min_length=1)
+    access_key: str = Field(description='s3 access key', min_length=1)
+    secret_key: str = Field(description='s3 secret key', min_length=1)
+    endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
+    addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
+
+
+class PageInfo(BaseModel):
+    """The width and height of page
+    """
+    w: float = Field(description='the width of page')
+    h: float = Field(description='the height of page')
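
S3Config is a pydantic model, so construction validates the fields; a sketch (values are placeholders):

    conf = S3Config(bucket_name='my-bucket', access_key='AK', secret_key='SK',
                    endpoint_url='http://localhost:9000')
    conf.addressing_style   # 'auto' (the default)

    # An empty bucket_name violates min_length=1 and raises pydantic.ValidationError.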