encrypted.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. from typing import Any
  3. from langgraph.checkpoint.serde.base import CipherProtocol, SerializerProtocol
  4. from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
  5. class EncryptedSerializer(SerializerProtocol):
  6. """Serializer that encrypts and decrypts data using an encryption protocol."""
  7. def __init__(
  8. self, cipher: CipherProtocol, serde: SerializerProtocol = JsonPlusSerializer()
  9. ) -> None:
  10. self.cipher = cipher
  11. self.serde = serde
  12. def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
  13. """Serialize an object to a tuple `(type, bytes)` and encrypt the bytes."""
  14. # serialize data
  15. typ, data = self.serde.dumps_typed(obj)
  16. # encrypt data
  17. ciphername, ciphertext = self.cipher.encrypt(data)
  18. # add cipher name to type
  19. return f"{typ}+{ciphername}", ciphertext
  20. def loads_typed(self, data: tuple[str, bytes]) -> Any:
  21. enc_cipher, ciphertext = data
  22. # unencrypted data
  23. if "+" not in enc_cipher:
  24. return self.serde.loads_typed(data)
  25. # extract cipher name
  26. typ, ciphername = enc_cipher.split("+", 1)
  27. # decrypt data
  28. decrypted_data = self.cipher.decrypt(ciphername, ciphertext)
  29. # deserialize data
  30. return self.serde.loads_typed((typ, decrypted_data))
  31. @classmethod
  32. def from_pycryptodome_aes(
  33. cls, serde: SerializerProtocol = JsonPlusSerializer(), **kwargs: Any
  34. ) -> "EncryptedSerializer":
  35. """Create an `EncryptedSerializer` using AES encryption."""
  36. try:
  37. from Crypto.Cipher import AES # type: ignore
  38. except ImportError:
  39. raise ImportError(
  40. "Pycryptodome is not installed. Please install it with `pip install pycryptodome`."
  41. ) from None
  42. # check if AES key is provided
  43. if "key" in kwargs:
  44. key: bytes = kwargs.pop("key")
  45. else:
  46. key_str = os.getenv("LANGGRAPH_AES_KEY")
  47. if key_str is None:
  48. raise ValueError("LANGGRAPH_AES_KEY environment variable is not set.")
  49. key = key_str.encode()
  50. if len(key) not in (16, 24, 32):
  51. raise ValueError("LANGGRAPH_AES_KEY must be 16, 24, or 32 bytes long.")
  52. # set default mode to EAX if not provided
  53. if kwargs.get("mode") is None:
  54. kwargs["mode"] = AES.MODE_EAX
  55. class PycryptodomeAesCipher(CipherProtocol):
  56. def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
  57. cipher = AES.new(key, **kwargs)
  58. ciphertext, tag = cipher.encrypt_and_digest(plaintext)
  59. return "aes", cipher.nonce + tag + ciphertext
  60. def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
  61. assert ciphername == "aes", f"Unsupported cipher: {ciphername}"
  62. nonce = ciphertext[:16]
  63. tag = ciphertext[16:32]
  64. actual_ciphertext = ciphertext[32:]
  65. cipher = AES.new(key, **kwargs, nonce=nonce)
  66. return cipher.decrypt_and_verify(actual_ciphertext, tag)
  67. return cls(PycryptodomeAesCipher(), serde)