# topic.py (2.8 KB)
  1. from __future__ import annotations
  2. from collections.abc import Iterator, Sequence
  3. from typing import Any, Generic
  4. from typing_extensions import Self
  5. from langgraph._internal._typing import MISSING
  6. from langgraph.channels.base import BaseChannel, Value
  7. from langgraph.errors import EmptyChannelError
# Public API of this module: only the Topic channel is exported.
__all__ = ("Topic",)
  9. def _flatten(values: Sequence[Value | list[Value]]) -> Iterator[Value]:
  10. for value in values:
  11. if isinstance(value, list):
  12. yield from value
  13. else:
  14. yield value
  15. class Topic(
  16. Generic[Value],
  17. BaseChannel[Sequence[Value], Value | list[Value], list[Value]],
  18. ):
  19. """A configurable PubSub Topic.
  20. Args:
  21. typ: The type of the value stored in the channel.
  22. accumulate: Whether to accumulate values across steps. If `False`, the channel will be emptied after each step.
  23. """
  24. __slots__ = ("values", "accumulate")
  25. def __init__(self, typ: type[Value], accumulate: bool = False) -> None:
  26. super().__init__(typ)
  27. # attrs
  28. self.accumulate = accumulate
  29. # state
  30. self.values = list[Value]()
  31. def __eq__(self, value: object) -> bool:
  32. return isinstance(value, Topic) and value.accumulate == self.accumulate
  33. @property
  34. def ValueType(self) -> Any:
  35. """The type of the value stored in the channel."""
  36. return Sequence[self.typ] # type: ignore[name-defined]
  37. @property
  38. def UpdateType(self) -> Any:
  39. """The type of the update received by the channel."""
  40. return self.typ | list[self.typ] # type: ignore[name-defined]
  41. def copy(self) -> Self:
  42. """Return a copy of the channel."""
  43. empty = self.__class__(self.typ, self.accumulate)
  44. empty.key = self.key
  45. empty.values = self.values.copy()
  46. return empty
  47. def checkpoint(self) -> list[Value]:
  48. return self.values
  49. def from_checkpoint(self, checkpoint: list[Value]) -> Self:
  50. empty = self.__class__(self.typ, self.accumulate)
  51. empty.key = self.key
  52. if checkpoint is not MISSING:
  53. if isinstance(checkpoint, tuple):
  54. # backwards compatibility
  55. empty.values = checkpoint[1]
  56. else:
  57. empty.values = checkpoint
  58. return empty
  59. def update(self, values: Sequence[Value | list[Value]]) -> bool:
  60. updated = False
  61. if not self.accumulate:
  62. updated = bool(self.values)
  63. self.values = list[Value]()
  64. if flat_values := tuple(_flatten(values)):
  65. updated = True
  66. self.values.extend(flat_values)
  67. return updated
  68. def get(self) -> Sequence[Value]:
  69. if self.values:
  70. return list(self.values)
  71. else:
  72. raise EmptyChannelError
  73. def is_available(self) -> bool:
  74. return bool(self.values)