_write.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from __future__ import annotations
  2. from collections.abc import Callable, Sequence
  3. from typing import (
  4. Any,
  5. NamedTuple,
  6. TypeVar,
  7. cast,
  8. )
  9. from langchain_core.runnables import Runnable, RunnableConfig
  10. from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, TASKS
  11. from langgraph._internal._runnable import RunnableCallable
  12. from langgraph._internal._typing import MISSING
  13. from langgraph.errors import InvalidUpdateError
  14. from langgraph.types import Send
# Signature of the send callback read from config[CONF][CONFIG_KEY_SEND]:
# it takes a sequence of (channel, value) tuples and returns nothing.
TYPE_SEND = Callable[[Sequence[tuple[str, Any]]], None]
# Type variable bounded by Runnable; lets register_writer return the same
# type it was given.
R = TypeVar("R", bound=Runnable)
# Sentinel: when the (possibly mapped) value of a ChannelWriteEntry is this
# exact object, _assemble_writes drops the write.
SKIP_WRITE = object()
# Sentinel default for an entry's value: ChannelWrite substitutes the
# runnable's input for it before writing.
PASSTHROUGH = object()
class ChannelWriteEntry(NamedTuple):
    """Describes a single write of one value to one named channel.

    When assembled, the value is first passed through `mapper` (if given);
    the write is dropped if the result is the SKIP_WRITE sentinel, or if it
    is None and `skip_none` is set.
    """

    channel: str
    """Channel name to write to."""
    # PASSTHROUGH is a placeholder; ChannelWrite swaps in the runnable's input.
    value: Any = PASSTHROUGH
    """Value to write, or PASSTHROUGH to use the input."""
    # Checked after the mapper has been applied, so it tests the mapped value.
    skip_none: bool = False
    """Whether to skip writing if the value is None."""
    # Called with the value as its only argument; may return SKIP_WRITE.
    mapper: Callable | None = None
    """Function to transform the value before writing."""
class ChannelWriteTupleEntry(NamedTuple):
    """Describes a write whose target channels are computed from the value.

    When assembled, `mapper` is called with the value and every
    (channel, value) tuple it returns is written; a falsy result (None or an
    empty sequence) produces no writes.
    """

    mapper: Callable[[Any], Sequence[tuple[str, Any]] | None]
    """Function to extract tuples from value."""
    # PASSTHROUGH is a placeholder; ChannelWrite swaps in the runnable's input.
    value: Any = PASSTHROUGH
    """Value to write, or PASSTHROUGH to use the input."""
    # Returned verbatim by ChannelWrite.get_static_writes; never used when
    # performing the actual write.
    # NOTE(review): elements are presumably (channel, value, label) - confirm.
    static: Sequence[tuple[str, Any, str | None]] | None = None
    """Optional, declared writes for static analysis."""
  35. class ChannelWrite(RunnableCallable):
  36. """Implements the logic for sending writes to CONFIG_KEY_SEND.
  37. Can be used as a runnable or as a static method to call imperatively."""
  38. writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]
  39. """Sequence of write entries or Send objects to write."""
  40. def __init__(
  41. self,
  42. writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
  43. *,
  44. tags: Sequence[str] | None = None,
  45. ):
  46. super().__init__(
  47. func=self._write,
  48. afunc=self._awrite,
  49. name=None,
  50. tags=tags,
  51. trace=False,
  52. )
  53. self.writes = cast(
  54. list[ChannelWriteEntry | ChannelWriteTupleEntry | Send], writes
  55. )
  56. def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
  57. if not name:
  58. name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else '...' if isinstance(w, ChannelWriteTupleEntry) else w.node for w in self.writes)}>"
  59. return super().get_name(suffix, name=name)
  60. def _write(self, input: Any, config: RunnableConfig) -> None:
  61. writes = [
  62. ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
  63. if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
  64. else ChannelWriteTupleEntry(write.mapper, input)
  65. if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
  66. else write
  67. for write in self.writes
  68. ]
  69. self.do_write(
  70. config,
  71. writes,
  72. )
  73. return input
  74. async def _awrite(self, input: Any, config: RunnableConfig) -> None:
  75. writes = [
  76. ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
  77. if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
  78. else ChannelWriteTupleEntry(write.mapper, input)
  79. if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
  80. else write
  81. for write in self.writes
  82. ]
  83. self.do_write(
  84. config,
  85. writes,
  86. )
  87. return input
  88. @staticmethod
  89. def do_write(
  90. config: RunnableConfig,
  91. writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
  92. allow_passthrough: bool = True,
  93. ) -> None:
  94. # validate
  95. for w in writes:
  96. if isinstance(w, ChannelWriteEntry):
  97. if w.channel == TASKS:
  98. raise InvalidUpdateError(
  99. "Cannot write to the reserved channel TASKS"
  100. )
  101. if w.value is PASSTHROUGH and not allow_passthrough:
  102. raise InvalidUpdateError("PASSTHROUGH value must be replaced")
  103. if isinstance(w, ChannelWriteTupleEntry):
  104. if w.value is PASSTHROUGH and not allow_passthrough:
  105. raise InvalidUpdateError("PASSTHROUGH value must be replaced")
  106. # if we want to persist writes found before hitting a ParentCommand
  107. # can move this to a finally block
  108. write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
  109. write(_assemble_writes(writes))
  110. @staticmethod
  111. def is_writer(runnable: Runnable) -> bool:
  112. """Used by PregelNode to distinguish between writers and other runnables."""
  113. return (
  114. isinstance(runnable, ChannelWrite)
  115. or getattr(runnable, "_is_channel_writer", MISSING) is not MISSING
  116. )
  117. @staticmethod
  118. def get_static_writes(
  119. runnable: Runnable,
  120. ) -> Sequence[tuple[str, Any, str | None]] | None:
  121. """Used to get conditional writes a writer declares for static analysis."""
  122. if isinstance(runnable, ChannelWrite):
  123. return [
  124. w
  125. for entry in runnable.writes
  126. if isinstance(entry, ChannelWriteTupleEntry) and entry.static
  127. for w in entry.static
  128. ] or None
  129. elif writes := getattr(runnable, "_is_channel_writer", MISSING):
  130. if writes is not MISSING:
  131. writes = cast(
  132. Sequence[tuple[ChannelWriteEntry | Send, str | None]],
  133. writes,
  134. )
  135. entries = [e for e, _ in writes]
  136. labels = [la for _, la in writes]
  137. return [(*t, la) for t, la in zip(_assemble_writes(entries), labels)]
  138. @staticmethod
  139. def register_writer(
  140. runnable: R,
  141. static: Sequence[tuple[ChannelWriteEntry | Send, str | None]] | None = None,
  142. ) -> R:
  143. """Used to mark a runnable as a writer, so that it can be detected by is_writer.
  144. Instances of ChannelWrite are automatically marked as writers.
  145. Optionally, a list of declared writes can be passed for static analysis."""
  146. # using object.__setattr__ to work around objects that override __setattr__
  147. # eg. pydantic models and dataclasses
  148. object.__setattr__(runnable, "_is_channel_writer", static)
  149. return runnable
  150. def _assemble_writes(
  151. writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
  152. ) -> list[tuple[str, Any]]:
  153. """Assembles the writes into a list of tuples."""
  154. tuples: list[tuple[str, Any]] = []
  155. for w in writes:
  156. if isinstance(w, Send):
  157. tuples.append((TASKS, w))
  158. elif isinstance(w, ChannelWriteTupleEntry):
  159. if ww := w.mapper(w.value):
  160. tuples.extend(ww)
  161. elif isinstance(w, ChannelWriteEntry):
  162. value = w.mapper(w.value) if w.mapper is not None else w.value
  163. if value is SKIP_WRITE:
  164. continue
  165. if w.skip_none and value is None:
  166. continue
  167. tuples.append((w.channel, value))
  168. else:
  169. raise ValueError(f"Invalid write entry: {w}")
  170. return tuples