| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import os
- from typing import Any
- from langgraph.checkpoint.serde.base import CipherProtocol, SerializerProtocol
- from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
class EncryptedSerializer(SerializerProtocol):
    """Serializer that encrypts and decrypts data using an encryption protocol.

    Objects are first serialized by the wrapped ``serde`` and the resulting
    bytes are then passed through ``cipher``. The cipher name is appended to
    the type tag as ``"<type>+<cipher>"`` so that encrypted payloads can be
    distinguished from unencrypted ones when loading.
    """

    def __init__(
        self, cipher: CipherProtocol, serde: SerializerProtocol = JsonPlusSerializer()
    ) -> None:
        """Store the cipher and the inner serializer.

        Args:
            cipher: Object used to encrypt and decrypt the serialized bytes.
            serde: Serializer applied before encryption and after decryption.
                The default instance is created once, when this signature is
                evaluated, and is shared by every caller that omits ``serde``.
        """
        self.cipher = cipher
        self.serde = serde

    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
        """Serialize an object to a tuple `(type, bytes)` and encrypt the bytes.

        Args:
            obj: The object to serialize.

        Returns:
            A tuple of the type tag, in the form ``"<type>+<cipher name>"``,
            and the encrypted payload.
        """
        # serialize data
        typ, data = self.serde.dumps_typed(obj)
        # encrypt data
        ciphername, ciphertext = self.cipher.encrypt(data)
        # add cipher name to type
        return f"{typ}+{ciphername}", ciphertext

    def loads_typed(self, data: tuple[str, bytes]) -> Any:
        """Decrypt (when needed) and deserialize a `(type, bytes)` tuple.

        Args:
            data: A tuple of the type tag and the payload. A tag without a
                ``"+"`` is treated as unencrypted and handed to the inner
                serializer unchanged.

        Returns:
            The object produced by the inner serializer.
        """
        enc_cipher, ciphertext = data
        # unencrypted data
        if "+" not in enc_cipher:
            return self.serde.loads_typed(data)
        # extract cipher name (split on the first "+" only)
        typ, ciphername = enc_cipher.split("+", 1)
        # decrypt data
        decrypted_data = self.cipher.decrypt(ciphername, ciphertext)
        # deserialize data
        return self.serde.loads_typed((typ, decrypted_data))

    @classmethod
    def from_pycryptodome_aes(
        cls, serde: SerializerProtocol = JsonPlusSerializer(), **kwargs: Any
    ) -> "EncryptedSerializer":
        """Create an `EncryptedSerializer` using AES encryption.

        Args:
            serde: Serializer applied before encryption and after decryption.
            **kwargs: Forwarded to ``AES.new``. ``key`` is removed from the
                mapping and used as the AES key; when it is absent the key is
                read from the ``LANGGRAPH_AES_KEY`` environment variable.
                ``mode`` is set to ``AES.MODE_EAX`` when missing or ``None``.

        Returns:
            An `EncryptedSerializer` backed by a pycryptodome AES cipher.

        Raises:
            ImportError: If pycryptodome cannot be imported.
            ValueError: If no key is passed and ``LANGGRAPH_AES_KEY`` is unset,
                or if the key read from the environment is not 16, 24, or 32
                bytes long.
        """
        try:
            from Crypto.Cipher import AES  # type: ignore
        except ImportError:
            raise ImportError(
                "Pycryptodome is not installed. Please install it with `pip install pycryptodome`."
            ) from None

        # check if AES key is provided
        if "key" in kwargs:
            # An explicitly passed key is not length-checked here; presumably
            # `AES.new` rejects unsupported sizes -- confirm against pycryptodome.
            key: bytes = kwargs.pop("key")
        else:
            key_str = os.getenv("LANGGRAPH_AES_KEY")
            if key_str is None:
                raise ValueError("LANGGRAPH_AES_KEY environment variable is not set.")
            key = key_str.encode()
            if len(key) not in (16, 24, 32):
                raise ValueError("LANGGRAPH_AES_KEY must be 16, 24, or 32 bytes long.")

        # set default mode to EAX if not provided
        if kwargs.get("mode") is None:
            kwargs["mode"] = AES.MODE_EAX

        class PycryptodomeAesCipher(CipherProtocol):
            """AES cipher that stores nonce and tag in front of the ciphertext."""

            def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
                """Encrypt `plaintext`; return `("aes", nonce + tag + ciphertext)`."""
                # A fresh cipher object is built for every message.
                cipher = AES.new(key, **kwargs)
                ciphertext, tag = cipher.encrypt_and_digest(plaintext)
                return "aes", cipher.nonce + tag + ciphertext

            def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
                """Verify and decrypt a payload produced by `encrypt`.

                Raises:
                    ValueError: If `ciphername` is not ``"aes"``.
                """
                # Raise explicitly instead of using `assert`: assertions are
                # removed under `python -O`, which would let an unknown cipher
                # name fall through to AES decryption.
                if ciphername != "aes":
                    raise ValueError(f"Unsupported cipher: {ciphername}")
                # Fixed layout: bytes 0-15 are read as the nonce, bytes 16-31
                # as the tag, and the remainder as the ciphertext.
                # NOTE(review): assumes the configured mode produces a 16-byte
                # nonce and a 16-byte tag -- confirm for non-default kwargs.
                nonce = ciphertext[:16]
                tag = ciphertext[16:32]
                actual_ciphertext = ciphertext[32:]
                cipher = AES.new(key, **kwargs, nonce=nonce)
                return cipher.decrypt_and_verify(actual_ciphertext, tag)

        return cls(PycryptodomeAesCipher(), serde)
|