| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- from __future__ import annotations
- from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
- from functools import cached_property
- from typing import (
- Any,
- )
- from langchain_core.runnables import Runnable, RunnableConfig
- from langgraph._internal._config import merge_configs
- from langgraph._internal._constants import CONF, CONFIG_KEY_READ
- from langgraph._internal._runnable import RunnableCallable, RunnableSeq
- from langgraph.pregel._utils import find_subgraph_pregel
- from langgraph.pregel._write import ChannelWrite
- from langgraph.pregel.protocol import PregelProtocol
- from langgraph.types import CachePolicy, RetryPolicy
- READ_TYPE = Callable[[str | Sequence[str], bool], Any | dict[str, Any]]
- INPUT_CACHE_KEY_TYPE = tuple[Callable[..., Any], tuple[str, ...]]
class ChannelRead(RunnableCallable):
    """Implements the logic for reading state from CONFIG_KEY_READ.

    Usable both as a runnable (the read happens when it is invoked) and
    imperatively through the static `do_read` method.
    """

    # Channel name, or list of channel names, passed to the reader.
    channel: str | list[str]
    # Passed to the reader as its second (bool) argument.
    fresh: bool = False
    # Optional callable applied to the reader's result before returning it.
    mapper: Callable[[Any], Any] | None = None

    def __init__(
        self,
        channel: str | list[str],
        *,
        fresh: bool = False,
        mapper: Callable[[Any], Any] | None = None,
        tags: list[str] | None = None,
    ) -> None:
        super().__init__(
            func=self._read,
            afunc=self._aread,
            tags=tags,
            name=None,
            trace=False,
        )
        self.fresh = fresh
        self.mapper = mapper
        self.channel = channel

    def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
        """Return the run name, defaulting to ``ChannelRead<channel[,channel...]>``."""
        if not name:
            label = (
                self.channel
                if isinstance(self.channel, str)
                else ",".join(self.channel)
            )
            name = f"ChannelRead<{label}>"
        return super().get_name(suffix, name=name)

    def _read(self, _: Any, config: RunnableConfig) -> Any:
        # The runnable input is ignored; everything comes from the config reader.
        return self.do_read(
            config, select=self.channel, fresh=self.fresh, mapper=self.mapper
        )

    async def _aread(self, _: Any, config: RunnableConfig) -> Any:
        # The reader is called synchronously; there is nothing to await here.
        return self.do_read(
            config, select=self.channel, fresh=self.fresh, mapper=self.mapper
        )

    @staticmethod
    def do_read(
        config: RunnableConfig,
        *,
        select: str | list[str],
        fresh: bool = False,
        mapper: Callable[[Any], Any] | None = None,
    ) -> Any:
        """Read `select` through the reader stored in ``config[CONF][CONFIG_KEY_READ]``.

        Args:
            config: Runnable config that must carry the reader under
                ``config[CONF][CONFIG_KEY_READ]``.
            select: Channel name or list of channel names passed to the reader.
            fresh: Passed to the reader as its second argument.
            mapper: Optional callable applied to the reader's result.

        Returns:
            The reader's result, transformed by `mapper` when one is given.

        Raises:
            RuntimeError: If the config has no reader installed.
        """
        try:
            read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
        except KeyError as err:
            raise RuntimeError(
                "Not configured with a read function. "
                "Make sure to call in the context of a Pregel process"
            ) from err
        # Only the lookup is guarded: a KeyError raised by the reader itself
        # must propagate unchanged instead of being reported as misconfiguration.
        value = read(select, fresh)
        return value if mapper is None else mapper(value)
# Identity runnable used as the placeholder `bound` for PregelNodes built without
# one. PregelNode checks for it by identity (`is DEFAULT_BOUND`) to decide whether
# the node has real logic, so this must remain a single shared module-level instance.
DEFAULT_BOUND = RunnableCallable(lambda input: input)
class PregelNode:
    """Holds the pieces the graph needs to build a PregelExecutableTask for a node.

    The graph does not run a PregelNode as a runnable itself; it reads the input
    channels, triggers, writers and policies stored here when preparing the
    node's task. The invoke/stream helpers forward to `bound` with this node's
    tags and metadata layered under the caller's config.
    """

    channels: str | list[str]
    """Channels whose values are fed to `bound` as input.
    A single str means the node receives that channel's value (when not empty);
    a list means the node receives a dict of those channels' values."""

    triggers: list[str]
    """Writing to any of these channels triggers this node in the next step."""

    mapper: Callable[[Any], Any] | None
    """Optional transform applied to the input before it reaches `bound`."""

    writers: list[Runnable]
    """Runnables executed after `bound` that write its output to channels."""

    bound: Runnable[Any, Any]
    """The node's main logic, invoked with the input read from `channels`."""

    retry_policy: Sequence[RetryPolicy] | None
    """Retry policies applied when invoking the node."""

    cache_policy: CachePolicy | None
    """Cache policy applied when invoking the node."""

    tags: Sequence[str] | None
    """Tracing tags attached to the node."""

    metadata: Mapping[str, Any] | None
    """Tracing metadata attached to the node."""

    subgraphs: Sequence[PregelProtocol]
    """Subgraphs used by the node."""

    def __init__(
        self,
        *,
        channels: str | list[str],
        triggers: Sequence[str],
        mapper: Callable[[Any], Any] | None = None,
        writers: list[Runnable] | None = None,
        tags: list[str] | None = None,
        metadata: Mapping[str, Any] | None = None,
        bound: Runnable[Any, Any] | None = None,
        retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
        cache_policy: CachePolicy | None = None,
        subgraphs: Sequence[PregelProtocol] | None = None,
    ) -> None:
        self.channels = channels
        self.triggers = list(triggers)
        self.mapper = mapper
        self.writers = writers or []
        self.bound = DEFAULT_BOUND if bound is None else bound
        self.cache_policy = cache_policy
        # A lone RetryPolicy is wrapped in a 1-tuple so the stored value is
        # always a sequence (or None).
        self.retry_policy = (
            (retry_policy,) if isinstance(retry_policy, RetryPolicy) else retry_policy
        )
        self.tags = tags
        self.metadata = metadata
        # Explicit subgraphs win; otherwise look inside `bound` for one.
        self.subgraphs = (
            subgraphs if subgraphs is not None else self._discover_subgraphs()
        )

    def _discover_subgraphs(self) -> list[PregelProtocol]:
        """Best-effort search of `bound` for a Pregel subgraph (at most one)."""
        if self.bound is DEFAULT_BOUND:
            return []
        try:
            found = find_subgraph_pregel(self.bound)
        except Exception:
            # Discovery is deliberately best-effort: if `bound` cannot be
            # inspected, the node simply records no subgraphs.
            found = None
        return [found] if found else []

    def copy(self, update: dict[str, Any]) -> PregelNode:
        """Return a new PregelNode built from this node's attributes overlaid with `update`."""
        # Values memoized by the cached properties live in __dict__ too, but
        # they are not constructor arguments and must be recomputed anyway.
        memoized = ("flat_writers", "node", "input_cache_key")
        kwargs = {
            key: value
            for key, value in {**self.__dict__, **update}.items()
            if key not in memoized
        }
        return PregelNode(**kwargs)

    @cached_property
    def flat_writers(self) -> list[Runnable]:
        """Writers list with trailing runs of consecutive ChannelWrites merged into one."""
        merged = self.writers.copy()
        while len(merged) > 1:
            tail, before_tail = merged[-1], merged[-2]
            if not (
                isinstance(tail, ChannelWrite) and isinstance(before_tail, ChannelWrite)
            ):
                break
            # Build a new ChannelWrite instead of extending either one in place,
            # so neither self.writers nor its ChannelWrite objects are mutated.
            merged[-2] = ChannelWrite(
                writes=before_tail.writes + tail.writes,
            )
            merged.pop()
        return merged

    @cached_property
    def node(self) -> Runnable[Any, Any] | None:
        """Runnable chaining `bound` and the flattened writers; None if there is nothing to run."""
        writers = self.flat_writers
        if self.bound is not DEFAULT_BOUND:
            return RunnableSeq(self.bound, *writers) if writers else self.bound
        # Placeholder bound: only the writers (if any) need to run.
        if not writers:
            return None
        if len(writers) == 1:
            return writers[0]
        return RunnableSeq(*writers)

    @cached_property
    def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
        """Key identifying this node's input computation: (mapper, channel names).

        Nodes sharing a key read the same input, so it can be computed once.
        """
        channels = self.channels
        names = tuple(channels) if isinstance(channels, list) else (channels,)
        return (self.mapper, names)

    def _merged_config(self, config: RunnableConfig | None) -> RunnableConfig:
        """Layer the caller's `config` over this node's metadata and tags."""
        own: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
        return merge_configs(own, config)

    def invoke(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        """Invoke `bound` with this node's tags/metadata merged into `config`."""
        return self.bound.invoke(input, self._merged_config(config), **kwargs)

    async def ainvoke(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        """Async counterpart of `invoke`."""
        return await self.bound.ainvoke(input, self._merged_config(config), **kwargs)

    def stream(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Any]:
        """Stream `bound`'s output with this node's tags/metadata merged into `config`."""
        yield from self.bound.stream(input, self._merged_config(config), **kwargs)

    async def astream(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> AsyncIterator[Any]:
        """Async counterpart of `stream`."""
        chunks = self.bound.astream(input, self._merged_config(config), **kwargs)
        async for chunk in chunks:
            yield chunk
|