storage.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. from os import PathLike
  16. from pathlib import Path
  17. from typing import Any, Dict, Optional, Protocol, Union, runtime_checkable
  18. from baidubce.auth.bce_credentials import BceCredentials
  19. from baidubce.bce_client_configuration import BceClientConfiguration
  20. from baidubce.services.bos.bos_client import BosClient
  21. from pydantic import BaseModel, Discriminator, SecretStr, TypeAdapter
  22. from typing_extensions import Annotated, Literal, assert_never
  23. __all__ = [
  24. "InMemoryStorageConfig",
  25. "FileSystemStorageConfig",
  26. "BOSConfig",
  27. "FileStorageConfig",
  28. "SupportsGetURL",
  29. "Storage",
  30. "InMemoryStorage",
  31. "FileSystemStorage",
  32. "BOS",
  33. "create_storage",
  34. ]
  35. class InMemoryStorageConfig(BaseModel):
  36. type: Literal["memory"] = "memory"
  37. class FileSystemStorageConfig(BaseModel):
  38. directory: Union[str, PathLike]
  39. type: Literal["file_system"] = "file_system"
  40. class BOSConfig(BaseModel):
  41. endpoint: str
  42. ak: SecretStr
  43. sk: SecretStr
  44. bucket_name: str
  45. key_prefix: Optional[str] = None
  46. connection_timeout_in_mills: Optional[int] = None
  47. type: Literal["bos"] = "bos"
  48. FileStorageConfig = Annotated[
  49. Union[InMemoryStorageConfig, FileSystemStorageConfig, BOSConfig],
  50. Discriminator("type"),
  51. ]
  52. @runtime_checkable
  53. class SupportsGetURL(Protocol):
  54. def get_url(self, key: str) -> str: ...
  55. class Storage(metaclass=abc.ABCMeta):
  56. @abc.abstractmethod
  57. def get(self, key: str) -> bytes:
  58. raise NotImplementedError
  59. @abc.abstractmethod
  60. def set(self, key: str, value: bytes) -> None:
  61. raise NotImplementedError
  62. @abc.abstractmethod
  63. def delete(self, key: str) -> None:
  64. raise NotImplementedError
  65. class InMemoryStorage(Storage):
  66. def __init__(self, config: InMemoryStorageConfig) -> None:
  67. super().__init__()
  68. self._data: Dict[str, bytes] = {}
  69. def get(self, key: str) -> bytes:
  70. return self._data[key]
  71. def set(self, key: str, value: bytes) -> None:
  72. self._data[key] = value
  73. def delete(self, key: str) -> None:
  74. del self._data[key]
  75. class FileSystemStorage(Storage):
  76. def __init__(self, config: FileSystemStorageConfig) -> None:
  77. super().__init__()
  78. self._directory = Path(config.directory)
  79. self._directory.mkdir(exist_ok=True)
  80. def get(self, key: str) -> bytes:
  81. with open(self._get_file_path(key), "rb") as f:
  82. contents = f.read()
  83. return contents
  84. def set(self, key: str, value: bytes) -> None:
  85. path = self._get_file_path(key)
  86. path.parent.mkdir(exist_ok=True)
  87. with open(path, "wb") as f:
  88. f.write(value)
  89. def delete(self, key: str) -> None:
  90. file_path = self._get_file_path(key)
  91. file_path.unlink(missing_ok=True)
  92. def _get_file_path(self, key: str) -> Path:
  93. return self._directory / key
  94. class BOS(Storage):
  95. def __init__(self, config: BOSConfig) -> None:
  96. super().__init__()
  97. bos_cfg = BceClientConfiguration(
  98. credentials=BceCredentials(
  99. config.ak.get_secret_value(), config.sk.get_secret_value()
  100. ),
  101. endpoint=config.endpoint,
  102. connection_timeout_in_mills=config.connection_timeout_in_mills,
  103. )
  104. self._client = BosClient(bos_cfg)
  105. self._bucket_name = config.bucket_name
  106. self._key_prefix = config.key_prefix
  107. def get(self, key: str) -> bytes:
  108. key = self._get_full_key(key)
  109. return self._client.get_object_as_string(bucket_name=self._bucket_name, key=key)
  110. def set(self, key: str, value: bytes) -> None:
  111. key = self._get_full_key(key)
  112. self._client.put_object_from_string(
  113. bucket=self._bucket_name, key=key, data=value
  114. )
  115. def delete(self, key: str) -> None:
  116. key = self._get_full_key(key)
  117. self._client.delete_object(bucket_name=self._bucket_name, key=key)
  118. def get_url(self, key: str) -> str:
  119. key = self._get_full_key(key)
  120. return self._client.generate_pre_signed_url(
  121. self._bucket_name, key, expiration_in_seconds=-1
  122. ).decode("ascii")
  123. def _get_full_key(self, key: str) -> str:
  124. if self._key_prefix:
  125. return f"{self._key_prefix}/{key}"
  126. return key
  127. def create_storage(dic: Dict[str, Any], /) -> Storage:
  128. config = TypeAdapter(FileStorageConfig).validate_python(dic)
  129. if config.type == "memory":
  130. return InMemoryStorage(config)
  131. elif config.type == "file_system":
  132. return FileSystemStorage(config)
  133. elif config.type == "bos":
  134. return BOS(config)
  135. else:
  136. assert_never(config)