| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- from __future__ import annotations
- from collections import Counter
- from collections.abc import Iterator, Mapping, Sequence
- from typing import Any, Literal
- from langgraph._internal._constants import (
- ERROR,
- INTERRUPT,
- NULL_TASK_ID,
- RESUME,
- RETURN,
- TASKS,
- )
- from langgraph._internal._typing import EMPTY_SEQ, MISSING
- from langgraph.channels.base import BaseChannel, EmptyChannelError
- from langgraph.constants import START, TAG_HIDDEN
- from langgraph.errors import InvalidUpdateError
- from langgraph.pregel._log import logger
- from langgraph.types import Command, PregelExecutableTask, Send
def read_channel(
    channels: Mapping[str, BaseChannel],
    chan: str,
    *,
    catch: bool = True,
) -> Any:
    """Return the current value of the channel named ``chan``.

    If the channel is empty (its ``get()`` raises ``EmptyChannelError``),
    return ``None`` when ``catch`` is true; otherwise re-raise the error.
    """
    try:
        current = channels[chan].get()
    except EmptyChannelError:
        if not catch:
            raise
        return None
    return current
def read_channels(
    channels: Mapping[str, BaseChannel],
    select: Sequence[str] | str,
    *,
    skip_empty: bool = True,
) -> dict[str, Any] | Any:
    """Read one channel, or several channels into a dict.

    When ``select`` is a single channel name, return that channel's value
    (``None`` if it is empty). When ``select`` is a sequence of names, return
    a ``{name: value}`` dict. Empty channels are omitted from the dict if
    ``skip_empty`` is true, and mapped to ``None`` otherwise.
    """
    if isinstance(select, str):
        return read_channel(channels, select)

    # With skip_empty, read_channel must raise on empty channels so that the
    # key can be dropped below; without it, read_channel yields None instead.
    catch_empty = not skip_empty
    result: dict[str, Any] = {}
    for name in select:
        try:
            result[name] = read_channel(channels, name, catch=catch_empty)
        except EmptyChannelError:
            continue
    return result
def map_command(cmd: Command) -> Iterator[tuple[str, str, Any]]:
    """Map a ``Command`` to pending writes of the form ``(task_id, channel, value)``.

    Every write is attributed to ``NULL_TASK_ID``. Writes are produced in this
    order:

    - for each ``goto`` target (a single target or a list/tuple of them): a
      ``Send`` is written to ``TASKS``; a node name ``s`` writes ``START`` to
      the ``"branch:to:<s>"`` channel;
    - ``resume``, if it is not ``None``, is written to ``RESUME``;
    - each ``(channel, value)`` pair from ``cmd._update_as_tuples()`` when
      ``update`` is set.

    Raises (when the generator is iterated):
        InvalidUpdateError: if the command targets ``Command.PARENT``; there
            is no parent graph to route it to here.
        TypeError: if a ``goto`` target is neither a ``Send`` nor a ``str``.
    """
    if cmd.graph == Command.PARENT:
        raise InvalidUpdateError("There is no parent graph")
    if cmd.goto:
        # Normalize a single goto target into a one-element sequence.
        sends = cmd.goto if isinstance(cmd.goto, (tuple, list)) else [cmd.goto]
        for send in sends:
            if isinstance(send, Send):
                yield (NULL_TASK_ID, TASKS, send)
            elif isinstance(send, str):
                yield (NULL_TASK_ID, f"branch:to:{send}", START)
            else:
                raise TypeError(
                    f"In Command.goto, expected Send/str, got {type(send).__name__}"
                )
    if cmd.resume is not None:
        yield (NULL_TASK_ID, RESUME, cmd.resume)
    if cmd.update:
        for k, v in cmd._update_as_tuples():
            yield (NULL_TASK_ID, k, v)
def map_input(
    input_channels: str | Sequence[str],
    chunk: dict[str, Any] | Any | None,
) -> Iterator[tuple[str, Any]]:
    """Map an input chunk to pending writes of the form ``(channel, value)``.

    - If ``chunk`` is ``None``, nothing is yielded.
    - If ``input_channels`` is a single channel name, the whole chunk is
      written to it.
    - Otherwise ``chunk`` must be a dict. Each key that is one of
      ``input_channels`` yields ``(key, value)``. Other keys are logged as a
      warning and skipped.

    Raises (when the generator is iterated):
        TypeError: if several input channels are given but ``chunk`` is not
            a dict.
    """
    if chunk is None:
        return
    if isinstance(input_channels, str):
        yield (input_channels, chunk)
        return
    if not isinstance(chunk, dict):
        raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")
    for key, value in chunk.items():
        if key in input_channels:
            yield (key, value)
        else:
            # Lazy %-args: the message is only formatted if WARNING is enabled.
            logger.warning("Input channel %s not found in %s", key, input_channels)
def map_output_values(
    output_channels: str | Sequence[str],
    pending_writes: Literal[True] | Sequence[tuple[str, Any]],
    channels: Mapping[str, BaseChannel],
) -> Iterator[dict[str, Any] | Any]:
    """Yield the current output value if any output channel was written.

    Args:
        output_channels: A single output channel name, or a sequence of them.
        pending_writes: ``(channel, value)`` writes from the last step, or
            ``True`` to emit the output unconditionally.
        channels: The channel mapping to read from.

    Yields:
        At most one item. For a single channel name, it is that channel's
        value; for a sequence, it is the dict returned by ``read_channels``.
        Nothing is yielded when no pending write touched an output channel.
    """
    if isinstance(output_channels, str):
        if pending_writes is True or any(
            chan == output_channels for chan, _ in pending_writes
        ):
            yield read_channel(channels, output_channels)
    else:
        # Short-circuit on the first matching write instead of collecting
        # every matching channel into a set just to test non-emptiness.
        if pending_writes is True or any(
            chan in output_channels for chan, _ in pending_writes
        ):
            yield read_channels(channels, output_channels)
def map_output_updates(
    output_channels: str | Sequence[str],
    tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
    cached: bool = False,
) -> Iterator[dict[str, Any | dict[str, Any]]]:
    """Map the writes of finished tasks to one update chunk keyed by task name.

    Args:
        output_channels: A single output channel name, or a sequence of them.
        tasks: Pairs of ``(task, writes)``, where ``writes`` is a sequence of
            ``(channel, value)`` tuples produced by that task.
        cached: If true, the chunk also carries
            ``"__metadata__": {"cached": True}``.

    Yields:
        At most one dict. Tasks tagged ``TAG_HIDDEN`` and tasks whose first
        write is to ``ERROR`` or ``INTERRUPT`` are skipped. If no task
        remains, nothing is yielded. For each remaining task the update is:

        - the value written to ``RETURN``, if there is one;
        - otherwise, for a single output channel, each value written to it;
        - otherwise, a ``{channel: value}`` dict of its output-channel writes,
          or one single-entry dict per write when some output channel was
          written more than once.

        Updates of tasks sharing a name are gathered into a list. A name with
        no updates maps to ``None``, and a name with exactly one update maps
        to that update itself.
    """
    output_tasks = [
        (t, ww)
        for t, ww in tasks
        if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
        # Guard against an empty writes list, which would make ww[0] raise
        # IndexError; such a task is kept and maps to None below.
        and (not ww or ww[0][0] not in (ERROR, INTERRUPT))
    ]
    if not output_tasks:
        return

    updated: list[tuple[str, Any]] = []
    for task, writes in output_tasks:
        # A write to RETURN takes precedence over any channel writes.
        rtn = next((value for chan, value in writes if chan == RETURN), MISSING)
        if rtn is not MISSING:
            updated.append((task.name, rtn))
        elif isinstance(output_channels, str):
            updated.extend(
                (task.name, value) for chan, value in writes if chan == output_channels
            )
        elif any(chan in output_channels for chan, _ in writes):
            counts = Counter(chan for chan, _ in writes)
            if any(counts[chan] > 1 for chan in output_channels):
                # Merging into one dict would keep only the last value written
                # to a repeated channel, so emit each write separately.
                updated.extend(
                    (task.name, {chan: value})
                    for chan, value in writes
                    if chan in output_channels
                )
            else:
                updated.append(
                    (
                        task.name,
                        {
                            chan: value
                            for chan, value in writes
                            if chan in output_channels
                        },
                    )
                )

    grouped: dict[str, Any] = {t.name: [] for t, _ in output_tasks}
    for node, value in updated:
        grouped[node].append(value)
    # Collapse per node: no updates -> None, one -> the update, several -> list.
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for node, values in grouped.items():
        if not values:
            grouped[node] = None
        elif len(values) == 1:
            grouped[node] = values[0]
    if cached:
        grouped["__metadata__"] = {"cached": cached}
    yield grouped
|