storage.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. from os import PathLike
  16. from pathlib import Path
  17. from typing import Any, Dict, Optional, Protocol, Union, runtime_checkable
  18. from pydantic import BaseModel, Discriminator, SecretStr, TypeAdapter
  19. from typing_extensions import Annotated, Literal, assert_never
  20. from ....utils.deps import class_requires_deps
# Public API of this module: the configuration models first, then the
# storage interfaces and backends, and finally the factory function.
__all__ = [
    "InMemoryStorageConfig",
    "FileSystemStorageConfig",
    "BOSConfig",
    "FileStorageConfig",
    "SupportsGetURL",
    "Storage",
    "InMemoryStorage",
    "FileSystemStorage",
    "BOS",
    "create_storage",
]
class InMemoryStorageConfig(BaseModel):
    """Configuration model for the in-memory storage backend."""

    # Discriminator tag; the only accepted value is "memory".
    type: Literal["memory"] = "memory"
class FileSystemStorageConfig(BaseModel):
    """Configuration model for the file-system storage backend."""

    # Base directory under which stored values are written as files.
    directory: Union[str, PathLike]
    # Discriminator tag; the only accepted value is "file_system".
    type: Literal["file_system"] = "file_system"
class BOSConfig(BaseModel):
    """Configuration model for the BOS object-storage backend."""

    # Service endpoint handed to the BOS client configuration.
    endpoint: str
    # Credential pair passed to ``BceCredentials`` (presumably access key and
    # secret key -- confirm against the SDK). Kept as ``SecretStr`` so the raw
    # values are only exposed through ``get_secret_value()``.
    ak: SecretStr
    sk: SecretStr
    # Bucket that all objects are read from and written to.
    bucket_name: str
    # Optional prefix; when set (and non-empty) it is joined to every key
    # with a "/" separator.
    key_prefix: Optional[str] = None
    # Optional connection timeout forwarded unchanged to the client
    # configuration (milliseconds, judging by the name -- confirm).
    connection_timeout_in_mills: Optional[int] = None
    # Discriminator tag; the only accepted value is "bos".
    type: Literal["bos"] = "bos"
# Tagged union of all supported storage configurations. The ``type`` field
# of the input is used as the discriminator to pick the concrete model.
FileStorageConfig = Annotated[
    Union[InMemoryStorageConfig, FileSystemStorageConfig, BOSConfig],
    Discriminator("type"),
]
  50. @runtime_checkable
  51. class SupportsGetURL(Protocol):
  52. def get_url(self, key: str) -> str: ...
  53. class Storage(metaclass=abc.ABCMeta):
  54. @abc.abstractmethod
  55. def get(self, key: str) -> bytes:
  56. raise NotImplementedError
  57. @abc.abstractmethod
  58. def set(self, key: str, value: bytes) -> None:
  59. raise NotImplementedError
  60. @abc.abstractmethod
  61. def delete(self, key: str) -> None:
  62. raise NotImplementedError
  63. class InMemoryStorage(Storage):
  64. def __init__(self, config: InMemoryStorageConfig) -> None:
  65. super().__init__()
  66. self._data: Dict[str, bytes] = {}
  67. def get(self, key: str) -> bytes:
  68. return self._data[key]
  69. def set(self, key: str, value: bytes) -> None:
  70. self._data[key] = value
  71. def delete(self, key: str) -> None:
  72. del self._data[key]
  73. class FileSystemStorage(Storage):
  74. def __init__(self, config: FileSystemStorageConfig) -> None:
  75. super().__init__()
  76. self._directory = Path(config.directory)
  77. self._directory.mkdir(exist_ok=True)
  78. def get(self, key: str) -> bytes:
  79. with open(self._get_file_path(key), "rb") as f:
  80. contents = f.read()
  81. return contents
  82. def set(self, key: str, value: bytes) -> None:
  83. path = self._get_file_path(key)
  84. path.parent.mkdir(exist_ok=True)
  85. with open(path, "wb") as f:
  86. f.write(value)
  87. def delete(self, key: str) -> None:
  88. file_path = self._get_file_path(key)
  89. file_path.unlink(missing_ok=True)
  90. def _get_file_path(self, key: str) -> Path:
  91. return self._directory / key
  92. @class_requires_deps("bce-python-sdk")
  93. class BOS(Storage):
  94. def __init__(self, config: BOSConfig) -> None:
  95. from baidubce.auth.bce_credentials import BceCredentials
  96. from baidubce.bce_client_configuration import BceClientConfiguration
  97. from baidubce.services.bos.bos_client import BosClient
  98. super().__init__()
  99. bos_cfg = BceClientConfiguration(
  100. credentials=BceCredentials(
  101. config.ak.get_secret_value(), config.sk.get_secret_value()
  102. ),
  103. endpoint=config.endpoint,
  104. connection_timeout_in_mills=config.connection_timeout_in_mills,
  105. )
  106. self._client = BosClient(bos_cfg)
  107. self._bucket_name = config.bucket_name
  108. self._key_prefix = config.key_prefix
  109. def get(self, key: str) -> bytes:
  110. key = self._get_full_key(key)
  111. return self._client.get_object_as_string(bucket_name=self._bucket_name, key=key)
  112. def set(self, key: str, value: bytes) -> None:
  113. key = self._get_full_key(key)
  114. self._client.put_object_from_string(
  115. bucket=self._bucket_name, key=key, data=value
  116. )
  117. def delete(self, key: str) -> None:
  118. key = self._get_full_key(key)
  119. self._client.delete_object(bucket_name=self._bucket_name, key=key)
  120. def get_url(self, key: str) -> str:
  121. key = self._get_full_key(key)
  122. return self._client.generate_pre_signed_url(
  123. self._bucket_name, key, expiration_in_seconds=-1
  124. ).decode("ascii")
  125. def _get_full_key(self, key: str) -> str:
  126. if self._key_prefix:
  127. return f"{self._key_prefix}/{key}"
  128. return key
  129. def create_storage(dic: Dict[str, Any], /) -> Storage:
  130. config = TypeAdapter(FileStorageConfig).validate_python(dic)
  131. if config.type == "memory":
  132. return InMemoryStorage(config)
  133. elif config.type == "file_system":
  134. return FileSystemStorage(config)
  135. elif config.type == "bos":
  136. return BOS(config)
  137. else:
  138. assert_never(config)