| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- from __future__ import annotations
- from collections import defaultdict
- from collections.abc import Mapping, Sequence
- from typing import Any, NamedTuple, cast
- from langchain_core.runnables.config import RunnableConfig
- from langchain_core.runnables.graph import Graph, Node
- from langgraph.checkpoint.base import BaseCheckpointSaver
- from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, INPUT
- from langgraph.channels.base import BaseChannel
- from langgraph.channels.last_value import LastValueAfterFinish
- from langgraph.constants import END, START
- from langgraph.managed.base import ManagedValueSpec
- from langgraph.pregel._algo import (
- PregelTaskWrites,
- apply_writes,
- increment,
- prepare_next_tasks,
- )
- from langgraph.pregel._checkpoint import channels_from_checkpoint, empty_checkpoint
- from langgraph.pregel._io import map_input
- from langgraph.pregel._read import PregelNode
- from langgraph.pregel._write import ChannelWrite
- from langgraph.types import All, Checkpointer
- class Edge(NamedTuple):
- source: str
- target: str
- conditional: bool
- data: str | None
- class TriggerEdge(NamedTuple):
- source: str
- conditional: bool
- data: str | None
def draw_graph(
    config: RunnableConfig,
    *,
    nodes: dict[str, PregelNode],
    specs: dict[str, BaseChannel | ManagedValueSpec],
    input_channels: str | Sequence[str],
    interrupt_after_nodes: All | Sequence[str],
    interrupt_before_nodes: All | Sequence[str],
    trigger_to_nodes: Mapping[str, Sequence[str]],
    checkpointer: Checkpointer,
    subgraphs: dict[str, Graph],
    limit: int = 250,
) -> Graph:
    """Get the graph for this Pregel instance.

    The topology is discovered by dry-running the Pregel loop on an empty
    checkpoint. Each step invokes every task's ``ChannelWrite`` writers with
    an empty input, records the statically declared writes of each writer
    (once per writer) as conditional edges, applies the writes, and records
    an edge from each writer to every task triggered in the next step.

    Args:
        config: The configuration to use for the graph.
        nodes: The Pregel nodes, keyed by name. Node mappers are stripped
            before the dry run.
        specs: Channel and managed-value specs used to build the channels.
        input_channels: The channel(s) the (empty) graph input is mapped to.
        interrupt_after_nodes: Node names, or ``"*"`` for all nodes, to mark
            with an ``"after"`` interrupt in node metadata.
        interrupt_before_nodes: Node names, or ``"*"`` for all nodes, to mark
            with a ``"before"`` interrupt in node metadata.
        trigger_to_nodes: Mapping from channel name to the nodes it triggers.
        checkpointer: The checkpointer to use for the graph. Only its
            ``get_next_version`` is used, and only when it is a
            ``BaseCheckpointSaver``; otherwise ``increment`` is used.
        subgraphs: The subgraphs to include in the graph, keyed by the name of
            the node each one replaces.
        limit: Exclusive upper bound on the dry-run step counter, which
            starts at -1.

    Returns:
        The graph for this Pregel instance.
    """
    # (src, dest, is_conditional, label)
    edges: set[Edge] = set()

    step = -1
    checkpoint = empty_checkpoint()
    get_next_version = (
        checkpointer.get_next_version
        if isinstance(checkpointer, BaseCheckpointSaver)
        else increment
    )
    channels, managed = channels_from_checkpoint(
        specs,
        checkpoint,
    )
    static_seen: set[Any] = set()
    sources: dict[str, set[TriggerEdge]] = {}
    step_sources: dict[str, set[TriggerEdge]] = {}
    static_declared_writes: dict[str, set[TriggerEdge]] = defaultdict(set)
    # remove node mappers
    nodes = {
        k: v.copy(update={"mapper": None}) if v.mapper is not None else v
        for k, v in nodes.items()
    }
    # Collect deferred nodes: the node name is the suffix of each
    # LastValueAfterFinish channel name. The channel set is fixed for the
    # whole dry run, so this is computed once here. That also guarantees the
    # set exists even if the loop below runs zero steps; previously it was
    # only assigned inside the loop, and the node metadata pass later raised
    # UnboundLocalError when no tasks were triggered by the input.
    deferred_nodes: set[str] = {
        channel.split(":", 2)[-1]
        for channel, item in channels.items()
        if isinstance(item, LastValueAfterFinish)
    }
    # apply input writes
    input_writes = list(map_input(input_channels, {}))
    updated_channels = apply_writes(
        checkpoint,
        channels,
        [
            PregelTaskWrites((), INPUT, input_writes, []),
        ],
        get_next_version,
        trigger_to_nodes,
    )
    # prepare first tasks
    tasks = prepare_next_tasks(
        checkpoint,
        [],
        nodes,
        channels,
        managed,
        config,
        step,
        -1,
        for_execution=True,
        store=None,
        checkpointer=None,
        manager=None,
        trigger_to_nodes=trigger_to_nodes,
        updated_channels=updated_channels,
    )
    start_tasks = tasks
    # run the pregel loop
    for step in range(step, limit):
        if not tasks:
            break
        conditionals: dict[tuple[str, str, Any], str | None] = {}
        # run task writers
        for task in tasks.values():
            for w in task.writers:
                # apply regular writes
                if isinstance(w, ChannelWrite):
                    empty_input = (
                        cast(BaseChannel, specs["__root__"]).ValueType()
                        if "__root__" in specs
                        else None
                    )
                    w.invoke(empty_input, task.config)
                # apply conditional writes declared for static analysis, only once
                if w not in static_seen:
                    static_seen.add(w)
                    # apply static writes
                    if writes := ChannelWrite.get_static_writes(w):
                        # END writes are not written, but become edges directly
                        for t in writes:
                            if t[0] == END:
                                edges.add(Edge(task.name, t[0], True, t[2]))
                        writes = [t for t in writes if t[0] != END]
                        conditionals.update(
                            {(task.name, t[0], t[1] or None): t[2] for t in writes}
                        )
                        # record static writes for edge creation
                        for t in writes:
                            static_declared_writes[task.name].add(
                                TriggerEdge(t[0], True, t[2])
                            )
                        task.config[CONF][CONFIG_KEY_SEND]([t[:2] for t in writes])
        # collect sources: the channels each task of this step wrote to
        step_sources = {}
        for task in tasks.values():
            task_edges = {
                TriggerEdge(
                    w[0],
                    (task.name, w[0], w[1] or None) in conditionals,
                    conditionals.get((task.name, w[0], w[1] or None)),
                )
                for w in task.writes
            }
            task_edges |= static_declared_writes.get(task.name, set())
            step_sources[task.name] = task_edges
        sources.update(step_sources)
        # invert triggers: channel -> nodes that wrote it
        trigger_to_sources: dict[str, set[TriggerEdge]] = defaultdict(set)
        for src, triggers in sources.items():
            for trigger, cond, label in triggers:
                trigger_to_sources[trigger].add(TriggerEdge(src, cond, label))
        # apply writes
        updated_channels = apply_writes(
            checkpoint, channels, tasks.values(), get_next_version, trigger_to_nodes
        )
        # prepare next tasks
        tasks = prepare_next_tasks(
            checkpoint,
            [],
            nodes,
            channels,
            managed,
            config,
            step,
            limit,
            for_execution=True,
            store=None,
            checkpointer=None,
            manager=None,
            trigger_to_nodes=trigger_to_nodes,
            updated_channels=updated_channels,
        )
        # collect edges
        for task in tasks.values():
            added = False
            for trigger in task.triggers:
                for src, cond, label in sorted(trigger_to_sources[trigger]):
                    edges.add(Edge(src, task.name, cond, label))
                    # if the edge is from this step, skip adding the implicit edges
                    if (trigger, cond, label) in step_sources.get(src, set()):
                        added = True
                    else:
                        sources[src].discard(TriggerEdge(trigger, cond, label))
            # if no edges from this step, add implicit edges from all previous tasks
            if not added:
                for src in step_sources:
                    edges.add(Edge(src, task.name, True, None))
    # assemble the graph
    graph = Graph()
    # add nodes
    for name, node in nodes.items():
        metadata = dict(node.metadata or {})
        if name in deferred_nodes:
            metadata["defer"] = True
        before = _interrupt_applies(name, interrupt_before_nodes)
        after = _interrupt_applies(name, interrupt_after_nodes)
        if before and after:
            metadata["__interrupt"] = "before,after"
        elif before:
            metadata["__interrupt"] = "before"
        elif after:
            metadata["__interrupt"] = "after"
        graph.add_node(node.bound, name, metadata=metadata or None)
    # add start node (only when START is not itself a Pregel node; otherwise
    # its outgoing edges were discovered by the dry run)
    if START not in nodes:
        graph.add_node(None, START)
        for task in start_tasks.values():
            add_edge(graph, START, task.name)
    # add discovered edges
    for src, dest, is_conditional, label in sorted(edges):
        add_edge(
            graph,
            src,
            dest,
            data=label if label != dest else None,
            conditional=is_conditional,
        )
    # add end edges: nodes that are edge targets but never edge sources
    termini = {d for _, d, _, _ in edges if d != END}.difference(
        s for s, _, _, _ in edges
    )
    end_edge_exists = any(d == END for _, d, _, _ in edges)
    if termini:
        for src in sorted(termini):
            add_edge(graph, src, END)
    elif len(step_sources) == 1 and not end_edge_exists:
        for src in sorted(step_sources):
            add_edge(graph, src, END, conditional=True)
    # replace subgraphs
    for name, subgraph in subgraphs.items():
        if (
            len(subgraph.nodes) > 1
            and name in graph.nodes
            and subgraph.first_node()
            and subgraph.last_node()
        ):
            subgraph.trim_first_node()
            subgraph.trim_last_node()
            # replace the node with the subgraph
            graph.nodes.pop(name)
            first, last = graph.extend(subgraph, prefix=name)
            # rewire edges that pointed at/from the replaced node
            for idx, edge in enumerate(graph.edges):
                if edge.source == name:
                    edge = edge.copy(source=cast(Node, last).id)
                if edge.target == name:
                    edge = edge.copy(target=cast(Node, first).id)
                graph.edges[idx] = edge
    return graph


def _interrupt_applies(name: str, interrupt_nodes: All | Sequence[str]) -> bool:
    """Return whether node ``name`` is selected by an interrupt setting.

    ``"*"`` selects every node. A plain ``name in "*"`` would be a substring
    test against the string ``"*"`` and never match a real node name. The
    synthetic START node is excluded from ``"*"``.
    NOTE(review): excluding START assumes ``"*"`` at runtime skips the hidden
    START node — confirm against the runtime interrupt logic.
    """
    if interrupt_nodes == "*":
        return name != START
    return name in interrupt_nodes
def add_edge(
    graph: Graph,
    source: str,
    target: str,
    *,
    data: Any | None = None,
    conditional: bool = False,
) -> None:
    """Add a ``source -> target`` edge to ``graph`` unless one already exists.

    Only the first edge between a given pair of node ids is kept; later calls
    for the same pair are ignored regardless of ``data``/``conditional``.
    The END node is created on demand when it is the target and not yet in
    the graph. ``source`` (and any non-END ``target``) must already be nodes
    of ``graph``.

    Args:
        graph: The graph to mutate.
        source: Id of the node the edge starts from.
        target: Id of the node the edge points to.
        data: Optional edge data (used as the edge label).
        conditional: Whether the edge is conditional.
    """
    duplicate = any(
        existing.source == source and existing.target == target
        for existing in graph.edges
    )
    if duplicate:
        return
    # END is added lazily, only once something actually points at it.
    if target == END and target not in graph.nodes:
        graph.add_node(None, END)
    graph.add_edge(graph.nodes[source], graph.nodes[target], data, conditional)
|