# _draw.py (10 KB)
  1. from __future__ import annotations
  2. from collections import defaultdict
  3. from collections.abc import Mapping, Sequence
  4. from typing import Any, NamedTuple, cast
  5. from langchain_core.runnables.config import RunnableConfig
  6. from langchain_core.runnables.graph import Graph, Node
  7. from langgraph.checkpoint.base import BaseCheckpointSaver
  8. from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, INPUT
  9. from langgraph.channels.base import BaseChannel
  10. from langgraph.channels.last_value import LastValueAfterFinish
  11. from langgraph.constants import END, START
  12. from langgraph.managed.base import ManagedValueSpec
  13. from langgraph.pregel._algo import (
  14. PregelTaskWrites,
  15. apply_writes,
  16. increment,
  17. prepare_next_tasks,
  18. )
  19. from langgraph.pregel._checkpoint import channels_from_checkpoint, empty_checkpoint
  20. from langgraph.pregel._io import map_input
  21. from langgraph.pregel._read import PregelNode
  22. from langgraph.pregel._write import ChannelWrite
  23. from langgraph.types import All, Checkpointer
  24. class Edge(NamedTuple):
  25. source: str
  26. target: str
  27. conditional: bool
  28. data: str | None
  29. class TriggerEdge(NamedTuple):
  30. source: str
  31. conditional: bool
  32. data: str | None
  33. def draw_graph(
  34. config: RunnableConfig,
  35. *,
  36. nodes: dict[str, PregelNode],
  37. specs: dict[str, BaseChannel | ManagedValueSpec],
  38. input_channels: str | Sequence[str],
  39. interrupt_after_nodes: All | Sequence[str],
  40. interrupt_before_nodes: All | Sequence[str],
  41. trigger_to_nodes: Mapping[str, Sequence[str]],
  42. checkpointer: Checkpointer,
  43. subgraphs: dict[str, Graph],
  44. limit: int = 250,
  45. ) -> Graph:
  46. """Get the graph for this Pregel instance.
  47. Args:
  48. config: The configuration to use for the graph.
  49. subgraphs: The subgraphs to include in the graph.
  50. checkpointer: The checkpointer to use for the graph.
  51. Returns:
  52. The graph for this Pregel instance.
  53. """
  54. # (src, dest, is_conditional, label)
  55. edges: set[Edge] = set()
  56. step = -1
  57. checkpoint = empty_checkpoint()
  58. get_next_version = (
  59. checkpointer.get_next_version
  60. if isinstance(checkpointer, BaseCheckpointSaver)
  61. else increment
  62. )
  63. channels, managed = channels_from_checkpoint(
  64. specs,
  65. checkpoint,
  66. )
  67. static_seen: set[Any] = set()
  68. sources: dict[str, set[TriggerEdge]] = {}
  69. step_sources: dict[str, set[TriggerEdge]] = {}
  70. static_declared_writes: dict[str, set[TriggerEdge]] = defaultdict(set)
  71. # remove node mappers
  72. nodes = {
  73. k: v.copy(update={"mapper": None}) if v.mapper is not None else v
  74. for k, v in nodes.items()
  75. }
  76. # apply input writes
  77. input_writes = list(map_input(input_channels, {}))
  78. updated_channels = apply_writes(
  79. checkpoint,
  80. channels,
  81. [
  82. PregelTaskWrites((), INPUT, input_writes, []),
  83. ],
  84. get_next_version,
  85. trigger_to_nodes,
  86. )
  87. # prepare first tasks
  88. tasks = prepare_next_tasks(
  89. checkpoint,
  90. [],
  91. nodes,
  92. channels,
  93. managed,
  94. config,
  95. step,
  96. -1,
  97. for_execution=True,
  98. store=None,
  99. checkpointer=None,
  100. manager=None,
  101. trigger_to_nodes=trigger_to_nodes,
  102. updated_channels=updated_channels,
  103. )
  104. start_tasks = tasks
  105. # run the pregel loop
  106. for step in range(step, limit):
  107. if not tasks:
  108. break
  109. conditionals: dict[tuple[str, str, Any], str | None] = {}
  110. # run task writers
  111. for task in tasks.values():
  112. for w in task.writers:
  113. # apply regular writes
  114. if isinstance(w, ChannelWrite):
  115. empty_input = (
  116. cast(BaseChannel, specs["__root__"]).ValueType()
  117. if "__root__" in specs
  118. else None
  119. )
  120. w.invoke(empty_input, task.config)
  121. # apply conditional writes declared for static analysis, only once
  122. if w not in static_seen:
  123. static_seen.add(w)
  124. # apply static writes
  125. if writes := ChannelWrite.get_static_writes(w):
  126. # END writes are not written, but become edges directly
  127. for t in writes:
  128. if t[0] == END:
  129. edges.add(Edge(task.name, t[0], True, t[2]))
  130. writes = [t for t in writes if t[0] != END]
  131. conditionals.update(
  132. {(task.name, t[0], t[1] or None): t[2] for t in writes}
  133. )
  134. # record static writes for edge creation
  135. for t in writes:
  136. static_declared_writes[task.name].add(
  137. TriggerEdge(t[0], True, t[2])
  138. )
  139. task.config[CONF][CONFIG_KEY_SEND]([t[:2] for t in writes])
  140. # collect sources
  141. step_sources = {}
  142. for task in tasks.values():
  143. task_edges = {
  144. TriggerEdge(
  145. w[0],
  146. (task.name, w[0], w[1] or None) in conditionals,
  147. conditionals.get((task.name, w[0], w[1] or None)),
  148. )
  149. for w in task.writes
  150. }
  151. task_edges |= static_declared_writes.get(task.name, set())
  152. step_sources[task.name] = task_edges
  153. sources.update(step_sources)
  154. # invert triggers
  155. trigger_to_sources: dict[str, set[TriggerEdge]] = defaultdict(set)
  156. for src, triggers in sources.items():
  157. for trigger, cond, label in triggers:
  158. trigger_to_sources[trigger].add(TriggerEdge(src, cond, label))
  159. # apply writes
  160. updated_channels = apply_writes(
  161. checkpoint, channels, tasks.values(), get_next_version, trigger_to_nodes
  162. )
  163. # prepare next tasks
  164. tasks = prepare_next_tasks(
  165. checkpoint,
  166. [],
  167. nodes,
  168. channels,
  169. managed,
  170. config,
  171. step,
  172. limit,
  173. for_execution=True,
  174. store=None,
  175. checkpointer=None,
  176. manager=None,
  177. trigger_to_nodes=trigger_to_nodes,
  178. updated_channels=updated_channels,
  179. )
  180. # collect deferred nodes
  181. deferred_nodes: set[str] = set()
  182. edges_to_deferred_nodes: set[Edge] = set()
  183. for channel, item in channels.items():
  184. if isinstance(item, LastValueAfterFinish):
  185. deferred_node = channel.split(":", 2)[-1]
  186. deferred_nodes.add(deferred_node)
  187. # collect edges
  188. for task in tasks.values():
  189. added = False
  190. for trigger in task.triggers:
  191. for src, cond, label in sorted(trigger_to_sources[trigger]):
  192. # record edge to be reviewed later
  193. if task.name in deferred_nodes:
  194. edges_to_deferred_nodes.add(Edge(src, task.name, cond, label))
  195. edges.add(Edge(src, task.name, cond, label))
  196. # if the edge is from this step, skip adding the implicit edges
  197. if (trigger, cond, label) in step_sources.get(src, set()):
  198. added = True
  199. else:
  200. sources[src].discard(TriggerEdge(trigger, cond, label))
  201. # if no edges from this step, add implicit edges from all previous tasks
  202. if not added:
  203. for src in step_sources:
  204. edges.add(Edge(src, task.name, True, None))
  205. # assemble the graph
  206. graph = Graph()
  207. # add nodes
  208. for name, node in nodes.items():
  209. metadata = dict(node.metadata or {})
  210. if name in deferred_nodes:
  211. metadata["defer"] = True
  212. if name in interrupt_before_nodes and name in interrupt_after_nodes:
  213. metadata["__interrupt"] = "before,after"
  214. elif name in interrupt_before_nodes:
  215. metadata["__interrupt"] = "before"
  216. elif name in interrupt_after_nodes:
  217. metadata["__interrupt"] = "after"
  218. graph.add_node(node.bound, name, metadata=metadata or None)
  219. # add start node
  220. if START not in nodes:
  221. graph.add_node(None, START)
  222. for task in start_tasks.values():
  223. add_edge(graph, START, task.name)
  224. # add discovered edges
  225. for src, dest, is_conditional, label in sorted(edges):
  226. add_edge(
  227. graph,
  228. src,
  229. dest,
  230. data=label if label != dest else None,
  231. conditional=is_conditional,
  232. )
  233. # add end edges
  234. termini = {d for _, d, _, _ in edges if d != END}.difference(
  235. s for s, _, _, _ in edges
  236. )
  237. end_edge_exists = any(d == END for _, d, _, _ in edges)
  238. if termini:
  239. for src in sorted(termini):
  240. add_edge(graph, src, END)
  241. elif len(step_sources) == 1 and not end_edge_exists:
  242. for src in sorted(step_sources):
  243. add_edge(graph, src, END, conditional=True)
  244. # replace subgraphs
  245. for name, subgraph in subgraphs.items():
  246. if (
  247. len(subgraph.nodes) > 1
  248. and name in graph.nodes
  249. and subgraph.first_node()
  250. and subgraph.last_node()
  251. ):
  252. subgraph.trim_first_node()
  253. subgraph.trim_last_node()
  254. # replace the node with the subgraph
  255. graph.nodes.pop(name)
  256. first, last = graph.extend(subgraph, prefix=name)
  257. for idx, edge in enumerate(graph.edges):
  258. if edge.source == name:
  259. edge = edge.copy(source=cast(Node, last).id)
  260. if edge.target == name:
  261. edge = edge.copy(target=cast(Node, first).id)
  262. graph.edges[idx] = edge
  263. return graph
  264. def add_edge(
  265. graph: Graph,
  266. source: str,
  267. target: str,
  268. *,
  269. data: Any | None = None,
  270. conditional: bool = False,
  271. ) -> None:
  272. """Add an edge to the graph."""
  273. for edge in graph.edges:
  274. if edge.source == source and edge.target == target:
  275. return
  276. if target not in graph.nodes and target == END:
  277. graph.add_node(None, END)
  278. graph.add_edge(graph.nodes[source], graph.nodes[target], data, conditional)