file_storage.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import base64
  15. import uuid
  16. from typing import Any, Dict, Literal, Optional, Union
  17. from baidubce.auth.bce_credentials import BceCredentials
  18. from baidubce.bce_client_configuration import BceClientConfiguration
  19. from baidubce.services.bos.bos_client import BosClient
  20. from pydantic import BaseModel, Discriminator, SecretStr, TypeAdapter
  21. from typing_extensions import Annotated, assert_never
class InMemoryStorageConfig(BaseModel):
    # Storage config for the "memory" backend: nothing is uploaded;
    # `postprocess_file` returns the file content itself, base64-encoded.
    # `type` is the discriminator field `parse_file_storage_config` uses to
    # pick this model out of `FileStorageConfig`.
    type: Literal["memory"] = "memory"
  24. class BOSConfig(BaseModel):
  25. endpoint: str
  26. ak: SecretStr
  27. sk: SecretStr
  28. bucket_name: str
  29. key_prefix: Optional[str] = None
  30. connection_timeout_in_mills: Optional[int] = None
  31. type: Literal["bos"] = "bos"
# All supported file-storage configurations. Each member carries a literal
# `type` field ("memory" / "bos") that distinguishes it from the others.
FileStorageConfig = Union[InMemoryStorageConfig, BOSConfig]
  33. def parse_file_storage_config(dic: Dict[str, Any]) -> FileStorageConfig:
  34. # XXX: mypy deduces a wrong type
  35. return TypeAdapter(
  36. Annotated[FileStorageConfig, Discriminator("type")]
  37. ).validate_python(
  38. dic
  39. ) # type: ignore[return-value]
  40. def postprocess_file(
  41. file: bytes, config: FileStorageConfig, key: Optional[str] = None
  42. ) -> str:
  43. if config.type == "memory":
  44. return base64.b64encode(file).decode("ascii")
  45. elif config.type == "bos":
  46. # TODO: Currently BOS clients are created on the fly since they are not
  47. # thread-safe. Should we use a background thread with a queue or use a
  48. # dedicated thread?
  49. bos_cfg = BceClientConfiguration(
  50. credentials=BceCredentials(
  51. config.ak.get_secret_value(), config.sk.get_secret_value()
  52. ),
  53. endpoint=config.endpoint,
  54. connection_timeout_in_mills=config.connection_timeout_in_mills,
  55. )
  56. client = BosClient(bos_cfg)
  57. if key is None:
  58. key = str(uuid.uuid4())
  59. if config.key_prefix:
  60. key = f"{config.key_prefix}{key}"
  61. client.put_object_from_string(bucket=config.bucket_name, key=key, data=file)
  62. url = client.generate_pre_signed_url(
  63. config.bucket_name, key, expiration_in_seconds=-1
  64. ).decode("ascii")
  65. return url
  66. else:
  67. assert_never(config.type)