any_value.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. from typing import Any, Generic
  4. from typing_extensions import Self
  5. from langgraph._internal._typing import MISSING
  6. from langgraph.channels.base import BaseChannel, Value
  7. from langgraph.errors import EmptyChannelError
  8. __all__ = ("AnyValue",)
  9. class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
  10. """Stores the last value received, assumes that if multiple values are
  11. received, they are all equal."""
  12. __slots__ = ("typ", "value")
  13. value: Value | Any
  14. def __init__(self, typ: Any, key: str = "") -> None:
  15. super().__init__(typ, key)
  16. self.value = MISSING
  17. def __eq__(self, value: object) -> bool:
  18. return isinstance(value, AnyValue)
  19. @property
  20. def ValueType(self) -> type[Value]:
  21. """The type of the value stored in the channel."""
  22. return self.typ
  23. @property
  24. def UpdateType(self) -> type[Value]:
  25. """The type of the update received by the channel."""
  26. return self.typ
  27. def copy(self) -> Self:
  28. """Return a copy of the channel."""
  29. empty = self.__class__(self.typ, self.key)
  30. empty.value = self.value
  31. return empty
  32. def from_checkpoint(self, checkpoint: Value) -> Self:
  33. empty = self.__class__(self.typ, self.key)
  34. if checkpoint is not MISSING:
  35. empty.value = checkpoint
  36. return empty
  37. def update(self, values: Sequence[Value]) -> bool:
  38. if len(values) == 0:
  39. if self.value is MISSING:
  40. return False
  41. else:
  42. self.value = MISSING
  43. return True
  44. self.value = values[-1]
  45. return True
  46. def get(self) -> Value:
  47. if self.value is MISSING:
  48. raise EmptyChannelError()
  49. return self.value
  50. def is_available(self) -> bool:
  51. return self.value is not MISSING
  52. def checkpoint(self) -> Value:
  53. return self.value